From 3ee8aa263ccf495a358e5b2b2e6cd525fd792002 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Thu, 12 Mar 2026 22:42:41 -0700 Subject: [PATCH 1/2] Merge branch 'v2' into 'main' --- scripts/train/assets/validation_ids_v2.txt | 398 ++++ scripts/train/configs/structure.yaml | 2 +- scripts/train/configs/structurev2.yaml | 244 +++ scripts/train/train.py | 14 +- src/boltz/data/crop/boltz.py | 233 ++- src/boltz/data/feature/featurizerv2.py | 17 +- src/boltz/data/feature/featurizerv2_train.py | 1964 ++++++++++++++++++ src/boltz/data/module/trainingv2.py | 839 ++++++-- src/boltz/data/parse/mmcif.py | 4 +- src/boltz/data/sample/cluster.py | 18 +- src/boltz/data/sample/distillation.py | 6 +- src/boltz/data/sample/random.py | 6 +- src/boltz/data/sample/sampler.py | 5 +- src/boltz/data/sample/v2/__init__.py | 0 src/boltz/data/sample/v2/cluster.py | 288 +++ src/boltz/data/sample/v2/distillation.py | 53 + src/boltz/data/sample/v2/random.py | 38 + src/boltz/data/sample/v2/sampler.py | 46 + src/boltz/data/template/__init__.py | 0 src/boltz/data/template/feature.py | 75 + src/boltz/data/tokenize/boltz2.py | 229 +- src/boltz/data/types.py | 46 + src/boltz/main.py | 19 +- src/boltz/model/loss/inference.py | 445 ++++ src/boltz/model/models/boltz2.py | 42 +- src/boltz/model/modules/diffusionv2.py | 57 +- src/boltz/model/validation/rcsb.py | 61 + src/boltz/model/validation/validator.py | 1266 +++++++++++ 28 files changed, 6090 insertions(+), 325 deletions(-) create mode 100644 scripts/train/assets/validation_ids_v2.txt create mode 100644 scripts/train/configs/structurev2.yaml create mode 100644 src/boltz/data/feature/featurizerv2_train.py create mode 100644 src/boltz/data/sample/v2/__init__.py create mode 100644 src/boltz/data/sample/v2/cluster.py create mode 100644 src/boltz/data/sample/v2/distillation.py create mode 100644 src/boltz/data/sample/v2/random.py create mode 100644 src/boltz/data/sample/v2/sampler.py create mode 100644 src/boltz/data/template/__init__.py create mode 100644 src/boltz/data/template/feature.py create mode 100644 src/boltz/model/loss/inference.py create mode 100644 src/boltz/model/validation/rcsb.py create mode 100644 src/boltz/model/validation/validator.py diff --git a/scripts/train/assets/validation_ids_v2.txt b/scripts/train/assets/validation_ids_v2.txt new file mode 100644 index 000000000..f957818e1 --- /dev/null +++ b/scripts/train/assets/validation_ids_v2.txt @@ -0,0 +1,398 @@ +8Q41 +8BH8 +8BH9 +8HIG +8IA3 +8SVD +7ZLQ +8TPK +8Q3Z +8Q40 +8K3D +8SVA +8D9I +8SSU +8PE3 +8GN3 +8GN4 +8K4L +8EB5 +8B4E +8Q43 +8H0L +8TP8 +8SSQ +8F5G +8SQU +8DWJ +8Q42 +8D9E +8Q44 +8ISZ +7XF1 +8HML +8SSR +8B4D +8B4C +7YSF +8J0K +8J0R +8GBA +8GTY +8GBM +7YL4 +8ANU +8K3F +8J7J +8CQM +8CLZ +8HI7 +7YH3 +8OGG +8CR3 +8GBH +8BCK +7YPR +8ILL +8OK7 +8FR5 +8BU0 +8OYY +7XVI +8JWS +8R8A +8Q1K +8IN6 +8JI2 +8GD6 +8JDG +8HHV +8JT9 +8QFN +8GKX +8ACG +8AVZ +8BHU +8GL4 +8CHX +7Z3I +8UPI +8DQ2 +8HVC +8OQH +7Y9G +7YH5 +8B2H +8IQC +7YN2 +8I3X +7YR9 +8CH4 +8GOJ +8FID +7T4W +8D40 +8HFP +8K5K +8GBI +8OW8 +8Q70 +8SDY +8JMR +8IIB +8SW0 +8DQ6 +8B6Q +8H2N +8SUT +8J9Q +8IW5 +8BL6 +8BBR +8JWU +8XBI +8B2G +8ASA +8B2E +8EY3 +8GJ9 +7ZAO +8B2S +8GD8 +8A9N +8TN8 +8I34 +7Y1S +8F9X +8PXC +8B5W +8EC3 +8HFC +8GYG +8IR2 +8CJG +8BUX +7UWU +8GYR +8EHN +7ZCB +7YH1 +8IUB +8BVK +8OYE +8GJ7 +8C4D +8FTV +8GIV +8T8K +8JZU +8A26 +8HC0 +8EO2 +8H72 +8HMM +8ORU +8C05 +8G95 +8BSA +7ZC9 +8IF7 +8QC6 +7QP5 +7YLS +8TN3 +8ANI +7ZET +8AVH +8D2Z +8THR +8K3B +8J1C +7Y4R +8U2O +8BTJ +8HND +8IN1 +8JX6 +7YLR +8GK6 +8B7D +8BGJ +8P0S +8B4L +8JJ5 +8EVX +8H24 +8OF7 +8IK2 +8P5N +8GJA +7YTL +8H4P +8KI5 +8T5J +8GHH +7ZN6 +8JNG +8GR7 +8H7J +8CI3 +8AH9 +8BS3 +8AGA +8U01 +8HN2 +8H5A +7YRO +8AN5 +8HLG +8OKS +8K4X +8D2J +8PKD +7VXT +8H3Z +8TJI +8K76 +8CIL +8HF2 +8P6K +7YW0 +8C26 +8BV8 +8BS0 +8SME +8I8A +7YJI +8EOZ +8OML +8BUY +8OWI +8DTQ +8P26 +8EZT +8FW7 +7Z8Z +8DQA +8K5L +8PR6 +8K1F +8AXR +8JVM +8ANT +8ER5 +8OXU +8H2R +7YPE +8WOJ +8SS1 +8G5S +8G2V +8SAO +8A50 +8BDK +7YIL +8ARD +8KGZ +8HO2 +8ECX +8P32 +8AGY +8SZB +8ONB +8DSP +8BJW +8U1J +8FIH +8KE8 +8SJJ +8E2B +8U1K +8POC +8HNO +8K4W +8UC6 +8AMR +8ON4 +8PHB +8IW0 +8OKW +7Y16 +8PX4 +8BRF +8BZ2 +8HGU +8ERW +8GAQ +8HM5 +7YKH +8SIU +8G0K +8ONQ +7XPC +8BIY +8BNB +8U12 +8FJE +8A39 +8FIT +8SRZ +8HUC +7W91 +7XRB +8R79 +8AU0 +8OFJ +8H1K +8BCS +8TTO +8B3Z +8EBB +8QK8 +8OIJ +8B45 +8EOX +8EK4 +8E1E +8ONF +8HHJ +8ENQ +7YEQ +8TNO +8AN0 +8OK3 +8T4C +8G8K +8I3J +8HDL +8AXJ +8GY4 +8H5S +8CAR +8FEH +8UGC +8GJY +8OYX +8DOT +8D00 +8B6N +8ANK +8ARL +8POF +8HCY +7Y3Z +8KCA +8CGM +7QRL +8GSY +8T0J +8HNA +8OXL +8P7A +8EB9 +8JNM +8I0A +8F8N +8P5P +8OYS +8JIY +7XXE +7YD4 +8FYG +8HZZ +8ES6 +7WU8 +8GJW +8SOT +8B4U +8AY2 +8J1W +8HM4 +8VEH +8TMS +8HX3 +8FJF +8TVL +8BKE +8J69 +8EOV +8OYW +8BJV +7TBO +8OXH +8B74 +8ANJ +8HJJ +8U0X +8OYV +7YJL +8DHJ +8EO7 +8JU8 +8JN0 +7VS2 +8F5D +8ETQ +8J0H +8JNN +8CQ3 \ No newline at end of file diff --git a/scripts/train/configs/structure.yaml b/scripts/train/configs/structure.yaml index 6591f386a..9bd866ddd 100644 --- a/scripts/train/configs/structure.yaml +++ b/scripts/train/configs/structure.yaml @@ -75,7 +75,7 @@ data: compute_constraint_features: false model: - _target_: boltz.model.model.Boltz1 + _target_: boltz.model.models.boltz1.Boltz1 atom_s: 128 atom_z: 16 token_s: 384 diff --git a/scripts/train/configs/structurev2.yaml b/scripts/train/configs/structurev2.yaml new file mode 100644 index 000000000..ea09d18b6 --- /dev/null +++ b/scripts/train/configs/structurev2.yaml @@ -0,0 +1,244 @@ +trainer: + accelerator: cuda + devices: 1 + num_nodes: 1 + precision: bf16-mixed + gradient_clip_val: 10.0 + accumulate_grad_batches: 1 + max_epochs: -1 + +# Optional set wandb here +# wandb: +# name: boltz +# project: boltz +# entity: boltz + +output: # PATH_HERE +pretrained: null +resume: null +disable_checkpoint: false +matmul_precision: null +save_top_k: -1 +v2: true + +data: + datasets: + # RCSB Data + - _target_: boltz.data.module.trainingv2.DatasetConfig + target_dir: #PATH_HERE + msa_dir: #PATH_HERE + template_dir: #PATH_HERE + prob: 0.55 + filters: + - _target_: boltz.data.filter.dynamic.size.SizeFilter + min_chains: 1 + max_chains: 300 + - _target_: boltz.data.filter.dynamic.date.DateFilter + date: "2023-06-01" + ref: released + - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter + resolution: 9.0 + sampler: + _target_: boltz.data.sample.v2.cluster.ClusterSampler + cropper: + _target_: boltz.data.crop.boltz.BoltzCropper + min_neighborhood: 0 + max_neighborhood: 40 + split: ./scripts/train/assets/validation_ids_v2.txt + symmetry_correction: true + val_group: "RCSB" + + # AFDB Distillation Data + - _target_: boltz.data.module.trainingv2.DatasetConfig + target_dir: #PATH_HERE + msa_dir: #PATH_HERE + template_dir: null + prob: 0.45 + filters: + - _target_: boltz.data.filter.dynamic.size.SizeFilter + min_chains: 1 + max_chains: 300 + sampler: + _target_: boltz.data.sample.v2.cluster.ClusterSampler + cropper: + _target_: boltz.data.crop.boltz.BoltzCropper + min_neighborhood: 0 + max_neighborhood: 40 + symmetry_correction: true + override_method: "AFDB" + override_bfactor: true + + checkpoint_monitor_val_group: "val/lddt" # dataset __RCSB is turned to "" # which validation dataset group to use for checkpoint monitoring + tokenizer: + _target_: boltz.data.tokenize.boltz2.Boltz2TrainingTokenizer + featurizer: + _target_: boltz.data.feature.featurizerv2_train.Boltz2Featurizer + + moldir: #PATH_HERE + max_tokens: 384 # 640 + max_atoms: 3456 # 5760 + max_seqs: 8192 + pad_to_max_tokens: true + pad_to_max_atoms: true + pad_to_max_seqs: true + samples_per_epoch: 36096 + batch_size: 1 + num_workers: 2 + random_seed: 42 + pin_memory: false + overfit: null + return_train_symmetries: false + return_val_symmetries: true + train_binder_pocket_conditioned_prop: 0.15 + val_binder_pocket_conditioned_prop: 0.15 + train_contact_conditioned_prop: 0.15 + val_contact_conditioned_prop: 0.15 + binder_pocket_cutoff_val: 6.0 + binder_pocket_cutoff_min: 4.0 + binder_pocket_cutoff_max: 20.0 + binder_pocket_sampling_geometric_p: 0.3 + atoms_per_window_queries: 32 + min_dist: 2.0 + max_dist: 22.0 + num_bins: 64 + num_ensembles_train: 1 + num_ensembles_val: 1 + fix_single_ensemble: false + disto_use_ensemble: true + single_sequence_prop_training: 0.05 + max_templates_train: 4 + max_templates_val: 4 + no_template_prob_train: 0.6 + no_template_prob_val: 1.0 + use_templates: false + msa_sampling_training: true + bfactor_md_correction: true + +model: + _target_: boltz.model.models.boltz2.Boltz2 + atom_s: 128 + atom_z: 16 + token_s: 384 + token_z: 128 + num_bins: 64 + atom_feature_dim: 388 + atoms_per_window_queries: 32 + atoms_per_window_keys: 128 + compile_pairformer: false + compile_templates: false + compile_msa: false + ema: true + ema_decay: 0.999 + exclude_ions_from_lddt: true + fix_sym_check: true + cyclic_pos_enc: true + num_val_datasets: 1 + bond_type_feature: true + conditioning_cutoff_min: ${data.binder_pocket_cutoff_min} + conditioning_cutoff_max: ${data.binder_pocket_cutoff_max} + use_templates: ${data.use_templates} + predict_bfactor: true + checkpoint_diffusion_conditioning: true + + validators: + - _target_: boltz.model.validation.rcsb.RCSBValidator + val_names: ["RCSB"] + confidence_prediction: ${model.confidence_prediction} + + embedder_args: + atom_encoder_depth: 3 + atom_encoder_heads: 4 + add_mol_type_feat: true + add_method_conditioning: true + add_modified_flag: true + add_cyclic_flag: true + + msa_args: + msa_s: 64 + msa_blocks: 4 + msa_dropout: 0.15 + z_dropout: 0.25 + pairwise_head_width: 32 + pairwise_num_heads: 4 + use_paired_feature: true + activation_checkpointing: true + + template_args: + template_dim: 64 + template_blocks: 2 + activation_checkpointing: true + + pairformer_args: + num_blocks: 48 + num_heads: 16 + dropout: 0.25 + post_layer_norm: false + activation_checkpointing: true + v2: true + + score_model_args: + sigma_data: 16 + dim_fourier: 256 + atom_encoder_depth: 3 + atom_encoder_heads: 4 + token_transformer_depth: 24 + token_transformer_heads: 16 + atom_decoder_depth: 3 + atom_decoder_heads: 4 + conditioning_transition_layers: 2 + transformer_post_ln: false + activation_checkpointing: true + + confidence_prediction: false + affinity_prediction: false + structure_prediction_training: true + + training_args: + recycling_steps: 3 + sampling_steps: 20 + diffusion_multiplicity: 32 + diffusion_samples: 1 + affinity_loss_weight: 3e-3 + confidence_loss_weight: 1e-4 + diffusion_loss_weight: 4.0 + distogram_loss_weight: 3e-2 + bfactor_loss_weight: 1e-3 + adam_beta_1: 0.9 + adam_beta_2: 0.95 + adam_eps: 0.00000001 + lr_scheduler: af3 + base_lr: 0.0 + max_lr: 0.0005 + lr_warmup_no_steps: 1000 + lr_start_decay_after_n_steps: 50000 + lr_decay_every_n_steps: 50000 + lr_decay_factor: 0.95 + weight_decay: 0.003 + weight_decay_exclude: true + + validation_args: + recycling_steps: 3 + sampling_steps: 200 + diffusion_samples: 5 + symmetry_correction: true + + diffusion_process_args: + sigma_min: 0.0004 + sigma_max: 160.0 + sigma_data: 16.0 + rho: 7 + P_mean: -1.2 + P_std: 1.5 + gamma_0: 0.8 + gamma_min: 1.0 + noise_scale: 1.0 + step_scale: 1.0 + coordinate_augmentation: true + alignment_reverse_diff: true + synchronize_sigmas: false + + diffusion_loss_args: + add_smooth_lddt_loss: true + nucleotide_loss_weight: 5.0 + ligand_loss_weight: 10.0 + filter_by_plddt: 0.0 \ No newline at end of file diff --git a/scripts/train/train.py b/scripts/train/train.py index f83966bdd..98acafa6e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -19,6 +19,7 @@ from pytorch_lightning.utilities import rank_zero_only from boltz.data.module.training import BoltzTrainingDataModule, DataConfig +from boltz.data.module.trainingv2 import Boltz2TrainingDataModule, DataConfigV2 @dataclass @@ -57,6 +58,8 @@ class TrainConfig: Fail on mismatched checkpoint weights. load_confidence_from_trunk: Optional[bool] Load pre-trained confidence weights from trunk. + v2: bool + Use v2 model. """ @@ -75,6 +78,7 @@ class TrainConfig: debug: bool = False strict_loading: bool = True load_confidence_from_trunk: Optional[bool] = False + v2: bool = False def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR0915 @@ -123,10 +127,14 @@ def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR wandb = None # Create objects - data_config = DataConfig(**cfg.data) - data_module = BoltzTrainingDataModule(data_config) - model_module = cfg.model + if cfg.v2: + data_config = DataConfigV2(**cfg.data) + data_module = Boltz2TrainingDataModule(data_config) + else: + data_config = DataConfig(**cfg.data) + data_module = BoltzTrainingDataModule(data_config) + model_module = cfg.model if cfg.pretrained and not cfg.resume: # Load the pretrained weights into the confidence module if cfg.load_confidence_from_trunk: diff --git a/src/boltz/data/crop/boltz.py b/src/boltz/data/crop/boltz.py index 2e4b31cba..1d2e20e25 100644 --- a/src/boltz/data/crop/boltz.py +++ b/src/boltz/data/crop/boltz.py @@ -6,12 +6,12 @@ from boltz.data import const from boltz.data.crop.cropper import Cropper -from boltz.data.types import Tokenized +from boltz.data.types import TokenizedTraining def pick_random_token( tokens: np.ndarray, - random: np.random.RandomState, + random: np.random.Generator, ) -> np.ndarray: """Pick a random token from the data. @@ -19,7 +19,7 @@ def pick_random_token( ---------- tokens : np.ndarray The token data. - random : np.ndarray + random : np.random.Generator The random state for reproducibility. Returns @@ -28,13 +28,13 @@ def pick_random_token( The selected token. """ - return tokens[random.randint(len(tokens))] + return tokens[random.integers(len(tokens))] def pick_chain_token( tokens: np.ndarray, chain_id: int, - random: np.random.RandomState, + random: np.random.Generator, ) -> np.ndarray: """Pick a random token from a chain. @@ -68,7 +68,7 @@ def pick_chain_token( def pick_interface_token( tokens: np.ndarray, interface: np.ndarray, - random: np.random.RandomState, + random: np.random.Generator, ) -> np.ndarray: """Pick a random token from an interface. @@ -78,7 +78,7 @@ def pick_interface_token( The token data. interface : int The interface ID. - random : np.ndarray + random : np.random.Generator The random state for reproducibility. Returns @@ -124,17 +124,53 @@ def pick_interface_token( return query +def pick_initial_crop_token( + tokens: np.ndarray, + initial_crop: list[int], + random: np.random.Generator, +) -> np.ndarray: + """Pick a random token from the initial crop. + + Parameters + ---------- + tokens : np.ndarray + The token data. + initial_crop : list[int] + The initial crop. + random : np.random.Generator + The random state for reproducibility. + + Returns + ------- + np.ndarray + + """ + # Compute crop centroid + crop_centroid = np.mean(tokens[initial_crop]["center_coords"], axis=0) + + # Compute distances to all tokens + dists = cdist(tokens["center_coords"], crop_centroid[None]) + + # Pick the closest token + return tokens[np.argmin(dists[:, 0])] + + class BoltzCropper(Cropper): """Interpolate between contiguous and spatial crops.""" - def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None: + def __init__( + self, + min_neighborhood: int = 0, + max_neighborhood: int = 40, + dna_double_helix: bool = False, + ) -> None: """Initialize the cropper. Modulates the type of cropping to be performed. Smaller neighborhoods result in more spatial cropping. Larger neighborhoods result in more continuous cropping. A mix can be achieved by - providing a range over which to sample. + providing a list of sizes from which to sample. Parameters ---------- @@ -142,20 +178,24 @@ def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> Non The minimum neighborhood size, by default 0. max_neighborhood : int The maximum neighborhood size, by default 40. + dna_double_helix : bool + Whether to use DNA double helix cropping, by default False. """ - sizes = list(range(min_neighborhood, max_neighborhood + 1, 2)) - self.neighborhood_sizes = sizes + self.neighborhood_sizes = list(range(min_neighborhood, max_neighborhood + 1, 2)) + self.dna_double_helix = dna_double_helix def crop( # noqa: PLR0915 self, - data: Tokenized, + data: TokenizedTraining, max_tokens: int, - random: np.random.RandomState, - max_atoms: Optional[int] = None, + random: np.random.Generator, chain_id: Optional[int] = None, interface_id: Optional[int] = None, - ) -> Tokenized: + max_atoms: Optional[int] = None, + return_indices: bool = False, + initial_crop: Optional[list[int]] = None, + ) -> TokenizedTraining: """Crop the data to a maximum number of tokens. Parameters @@ -164,18 +204,14 @@ def crop( # noqa: PLR0915 The tokenized data. max_tokens : int The maximum number of tokens to crop. - random : np.random.RandomState + random : np.random.Generator The random state for reproducibility. - max_atoms : int, optional + max_atoms : Optional[int] The maximum number of atoms to consider. - chain_id : int, optional - The chain ID to crop. - interface_id : int, optional - The interface ID to crop. Returns ------- - Tokenized + TokenizedTraining The cropped data. """ @@ -211,17 +247,19 @@ def crop( # noqa: PLR0915 raise ValueError(msg) # Pick a random token, chain, or interface - if chain_id is not None: + if initial_crop is not None: + query = pick_initial_crop_token(token_data, initial_crop, random) + elif chain_id is not None: query = pick_chain_token(valid_tokens, chain_id, random) elif interface_id is not None: interface = interfaces[interface_id] query = pick_interface_token(valid_tokens, interface, random) elif valid_interfaces.size: - idx = random.randint(len(valid_interfaces)) + idx = random.integers(len(valid_interfaces)) interface = valid_interfaces[idx] query = pick_interface_token(valid_tokens, interface, random) else: - idx = random.randint(len(valid_chains)) + idx = random.integers(len(valid_chains)) chain_id = valid_chains[idx]["asym_id"] query = pick_chain_token(valid_tokens, chain_id, random) @@ -232,44 +270,101 @@ def crop( # noqa: PLR0915 # Select cropped indices cropped: set[int] = set() total_atoms = 0 + + if initial_crop is not None: + cropped.update(initial_crop) + total_atoms = sum(token_data[idx]["atom_num"] for idx in initial_crop) + for idx in indices: # Get the token token = valid_tokens[idx] - # Get all tokens from this chain - chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]] - - # Pick the whole chain if possible, otherwise select - # a contiguous subset centered at the query token - if len(chain_tokens) <= neighborhood_size: - new_tokens = chain_tokens - else: - # First limit to the maximum set of tokens, with the - # neighborhood on both sides to handle edges. This - # is mostly for efficiency with the while loop below. - min_idx = token["res_idx"] - neighborhood_size - max_idx = token["res_idx"] + neighborhood_size - - max_token_set = chain_tokens - max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx] - max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx] - - # Start by adding just the query token - new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]] - - # Expand the neighborhood until we have enough tokens, one - # by one to handle some edge cases with non-standard chains. - # We switch to the res_idx instead of the token_idx to always - # include all tokens from modified residues or from ligands. - min_idx = max_idx = token["res_idx"] - while new_tokens.size < neighborhood_size: - min_idx = min_idx - 1 - max_idx = max_idx + 1 - new_tokens = max_token_set - new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx] - new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx] + neighborhood_size_to_use = neighborhood_size + center_tokens_to_use = [token] + new_tokens_acc = [] + + # If it is a DNA double helix we may change this + if ( + self.dna_double_helix + and token["mol_type"] == const.chain_type_ids["DNA"] + ): + base_coords = data.structure.atoms["coords"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + base_is_present = data.structure.atoms["is_present"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + base_coords = base_coords[base_is_present] + + best_dist = 1e9 + best_other_token = None + + for other_token in valid_tokens: + if ( + other_token["mol_type"] == const.chain_type_ids["DNA"] + and other_token["asym_id"] != token["asym_id"] + ): + other_base_coords = data.structure.atoms["coords"][ + other_token["atom_idx"] : other_token["atom_idx"] + + other_token["atom_num"] + ] + other_base_is_present = data.structure.atoms["is_present"][ + other_token["atom_idx"] : other_token["atom_idx"] + + other_token["atom_num"] + ] + other_base_coords = other_base_coords[other_base_is_present] + + dist = np.min(cdist(base_coords, other_base_coords)) + if dist < best_dist: + best_dist = dist + best_other_token = other_token + + if best_dist < 3.0: + center_tokens_to_use.append(best_other_token) + neighborhood_size_to_use = neighborhood_size_to_use // 2 + + for center_token in center_tokens_to_use: + # Get all tokens from this chain + chain_tokens = token_data[ + token_data["asym_id"] == center_token["asym_id"] + ] + + # Pick the whole chain if possible, otherwise select + # a contiguous subset centered at the query token + if len(chain_tokens) <= neighborhood_size_to_use: + new_tokens = chain_tokens + else: + # First limit to the maximum set of tokens, with the + # neighboorhood on both sides to handle edges. This + # is mostly for efficiency with the while loop below. + min_idx = center_token["res_idx"] - neighborhood_size_to_use + max_idx = center_token["res_idx"] + neighborhood_size_to_use + + max_token_set = chain_tokens + max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx] + max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx] + + # Start by adding just the query token + new_tokens = max_token_set[ + max_token_set["res_idx"] == center_token["res_idx"] + ] + + # Expand the neighborhood until we have enough tokens, one + # by one to handle some edge cases with non-standard chains. + # We switch to the res_idx instead of the token_idx to always + # include all tokens from modified residues or from ligands. + min_idx = max_idx = center_token["res_idx"] + while new_tokens.size < neighborhood_size_to_use: + min_idx = min_idx - 1 + max_idx = max_idx + 1 + new_tokens = max_token_set + new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx] + new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx] + + new_tokens_acc.append(new_tokens) # Compute new tokens and new atoms + new_tokens = np.concatenate(new_tokens_acc) new_indices = set(new_tokens["token_idx"]) - cropped new_tokens = token_data[list(new_indices)] new_atoms = np.sum(new_tokens["atom_num"]) @@ -293,4 +388,28 @@ def crop( # noqa: PLR0915 token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)] # Return the cropped tokens + if return_indices: + token_ids_mol = set( + token_data[token_data["mol_type"] == 3]["token_idx"].tolist() + ) + return replace(data, tokens=token_data, bonds=token_bonds), sorted( + cropped - token_ids_mol + ) + else: + return replace(data, tokens=token_data, bonds=token_bonds) + + def crop_indices( # noqa: PLR0915 + self, + data: TokenizedTraining, + cropped_indices: list[int], + ) -> TokenizedTraining: + token_data = data.tokens + token_ids_mol = token_data[token_data["mol_type"] == 3]["token_idx"].tolist() # noqa: PLR2004 + cropped_indices = sorted({*token_ids_mol, *cropped_indices}) + token_data = token_data[cropped_indices] + indices = token_data["token_idx"] + token_bonds = data.bonds + token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)] + token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)] + return replace(data, tokens=token_data, bonds=token_bonds) diff --git a/src/boltz/data/feature/featurizerv2.py b/src/boltz/data/feature/featurizerv2.py index 2fcb30713..f8a2231ec 100644 --- a/src/boltz/data/feature/featurizerv2.py +++ b/src/boltz/data/feature/featurizerv2.py @@ -1,10 +1,10 @@ import math -from typing import Optional from collections import deque +from typing import Optional + import numba import numpy as np import numpy.typing as npt -import rdkit.Chem.Descriptors import torch from numba import types from rdkit.Chem import Mol @@ -421,8 +421,9 @@ def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 # Map (chain_id, seq_idx, res_idx) to deletion deletions = numba.typed.Dict.empty( key_type=numba.types.Tuple( - [numba.types.int64, numba.types.int64, numba.types.int64]), - value_type=numba.types.int64 + [numba.types.int64, numba.types.int64, numba.types.int64] + ), + value_type=numba.types.int64, ) for chain_id, chain_msa in msa.items(): chain_deletions = chain_msa.deletions @@ -2335,8 +2336,12 @@ def process( chain_constraint_features = process_chain_feature_constraints(data) contact_constraint_features = process_contact_feature_constraints( data=data, - inference_pocket_constraints=inference_pocket_constraints if inference_pocket_constraints else [], - inference_contact_constraints=inference_contact_constraints if inference_contact_constraints else [], + inference_pocket_constraints=inference_pocket_constraints + if inference_pocket_constraints + else [], + inference_contact_constraints=inference_contact_constraints + if inference_contact_constraints + else [], ) return { diff --git a/src/boltz/data/feature/featurizerv2_train.py b/src/boltz/data/feature/featurizerv2_train.py new file mode 100644 index 000000000..5e08441cf --- /dev/null +++ b/src/boltz/data/feature/featurizerv2_train.py @@ -0,0 +1,1964 @@ +import math +from typing import Dict, List, Optional, Tuple + +import networkx as nx +import numba +import numpy as np +import numpy.typing as npt +import torch +from numba import types +from rdkit.Chem import Mol +from scipy.spatial.distance import cdist +from torch import Tensor, from_numpy +from torch.nn.functional import one_hot + +from boltz.data import const +from boltz.data.mol import ( + get_amino_acids_symmetries, + get_chain_symmetries, + get_ligand_symmetries, + get_symmetries, +) +from boltz.data.pad import pad_dim +from boltz.data.types import ( + MSA, + Input, + MSADeletion, + MSAResidue, + MSASequence, +) +from boltz.model.modules.utils import center_random_augmentation + +#################################################################################################### +# HELPERS +#################################################################################################### + + +def convert_atom_name(name: str) -> Tuple[int, int, int, int]: + """Convert an atom name to a standard format. + + Parameters + ---------- + name : str + The atom name. + + Returns + ------- + Tuple[int, int, int, int] + The converted atom name. + + """ + name = str(name).strip() + name = [ord(c) - 32 for c in name] + name = name + [0] * (4 - len(name)) + return tuple(name) + + +def sample_d( + min_d: float, + max_d: float, + n_samples: int, + random: np.random.Generator, +) -> np.ndarray: + """Generate samples from a 1/d distribution between min_d and max_d. + + Parameters + ---------- + min_d : float + Minimum value of d + max_d : float + Maximum value of d + n_samples : int + Number of samples to generate + random : numpy.random.Generator + Random number generator + + Returns + ------- + numpy.ndarray + Array of samples drawn from the distribution + + Notes + ----- + The probability density function is: + f(d) = 1/(d * ln(max_d/min_d)) for d in [min_d, max_d] + + The inverse CDF transform is: + d = min_d * (max_d/min_d)**u where u ~ Uniform(0,1) + + """ + # Generate n_samples uniform random numbers in [0, 1] + u = random.random(n_samples) + # Transform u using the inverse CDF + return min_d * (max_d / min_d) ** u + + +def compute_frames_nonpolymer( + data: Input, + coords, + resolved_mask, + atom_to_token, + frame_data: List, + resolved_frame_data: List, +) -> Tuple[List, List]: + """Get the frames for non-polymer tokens. + + Parameters + ---------- + data : Input + The input data to the model. + frame_data : List + The frame data. + resolved_frame_data : List + The resolved frame data. + + Returns + ------- + Tuple[List, List] + The frame data and resolved frame data. + + """ + ## assert coords.shape[0] == 1, "No support for ensembles yet" + # print("Warning: Frames are only using first conformer! For debug.") + # coords = coords[0] + + frame_data = np.array(frame_data) + resolved_frame_data = np.array(resolved_frame_data) + asym_id_token = data.tokens["asym_id"] + asym_id_atom = data.tokens["asym_id"][atom_to_token] + token_idx = 0 + atom_idx = 0 + for id in np.unique(data.tokens["asym_id"]): + mask_chain_token = asym_id_token == id + mask_chain_atom = asym_id_atom == id + num_tokens = mask_chain_token.sum() + num_atoms = mask_chain_atom.sum() + if ( + data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"] + or num_atoms < 3 # noqa: PLR2004 + ): + token_idx += num_tokens + atom_idx += num_atoms + continue + dist_mat = ( + ( + coords.reshape(-1, 3)[mask_chain_atom][:, None, :] + - coords.reshape(-1, 3)[mask_chain_atom][None, :, :] + ) + ** 2 + ).sum(-1) ** 0.5 + resolved_pair = 1 - ( + resolved_mask[mask_chain_atom][None, :] + * resolved_mask[mask_chain_atom][:, None] + ).astype(np.float32) + resolved_pair[resolved_pair == 1] = math.inf + indices = np.argsort(dist_mat + resolved_pair, axis=1) + frames = ( + np.concatenate( + [ + indices[:, 1:2], + indices[:, 0:1], + indices[:, 2:3], + ], + axis=1, + ) + + atom_idx + ) + frame_data[token_idx : token_idx + num_atoms, :] = frames + resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[ + frames + ].all(axis=1) + token_idx += num_tokens + atom_idx += num_atoms + frames_expanded = coords.reshape(-1, 3)[frame_data] + + mask_collinear = compute_collinear_mask( + frames_expanded[:, 1] - frames_expanded[:, 0], + frames_expanded[:, 1] - frames_expanded[:, 2], + ) + return frame_data, resolved_frame_data & mask_collinear + + +def compute_collinear_mask(v1, v2): + norm1 = np.linalg.norm(v1, axis=1, keepdims=True) + norm2 = np.linalg.norm(v2, axis=1, keepdims=True) + v1 = v1 / (norm1 + 1e-6) + v2 = v2 / (norm2 + 1e-6) + mask_angle = np.abs(np.sum(v1 * v2, axis=1)) < 0.9063 + mask_overlap1 = norm1.reshape(-1) > 1e-2 + mask_overlap2 = norm2.reshape(-1) > 1e-2 + return mask_angle & mask_overlap1 & mask_overlap2 + + +def dummy_msa(residues: np.ndarray) -> MSA: + """Create a dummy MSA for a chain. + + Parameters + ---------- + residues : np.ndarray + The residues for the chain. + + Returns + ------- + MSA + The dummy MSA. + + """ + residues = [res["res_type"] for res in residues] + deletions = [] + sequences = [(0, -1, 0, len(residues), 0, 0)] + return MSA( + residues=np.array(residues, dtype=MSAResidue), + deletions=np.array(deletions, dtype=MSADeletion), + sequences=np.array(sequences, dtype=MSASequence), + ) + + +def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 + data: Input, + random: np.random.Generator, + max_seqs: int, + max_pairs: int = 8192, + max_total: int = 16384, + random_subset: bool = False, +) -> Tuple[Tensor, Tensor, Tensor]: + """Pair the MSA data. + + Parameters + ---------- + data : Input + The input data to the model. + + Returns + ------- + Tensor + The MSA data. + Tensor + The deletion data. + Tensor + Mask indicating paired sequences. + + """ + # Get unique chains (ensuring monotonicity in the order) + assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0) + chain_ids = np.unique(data.tokens["asym_id"]) + + # Get relevant MSA, and create a dummy for chains without + msa: Dict[int, MSA] = {} + for chain_id in chain_ids: + # Get input sequence + chain = data.structure.chains[chain_id] + res_start = chain["res_idx"] + res_end = res_start + chain["res_num"] + residues = data.structure.residues[res_start:res_end] + + # Check if we have an MSA, and that the + # first sequence matches the input sequence + if chain_id in data.msa: + # Set the MSA + msa[chain_id] = data.msa[chain_id] + + # Run length and residue type checks + first = data.msa[chain_id].sequences[0] + first_start = first["res_start"] + first_end = first["res_end"] + msa_residues = data.msa[chain_id].residues + first_residues = msa_residues[first_start:first_end] + + warning = "Warning: MSA does not match input sequence, creating dummy." + if len(residues) == len(first_residues): + # If there is a mismatch, check if it is between MET & UNK + # If so, replace the first sequence with the input sequence. + # Otherwise, replace with a dummy MSA for this chain. + mismatches = residues["res_type"] != first_residues["res_type"] + if mismatches.sum().item(): + idx = np.where(mismatches)[0] + is_met = residues["res_type"][idx] == const.token_ids["MET"] + is_unk = residues["res_type"][idx] == const.token_ids["UNK"] + is_msa_unk = ( + first_residues["res_type"][idx] == const.token_ids["UNK"] + ) + if (np.all(is_met) and np.all(is_msa_unk)) or np.all(is_unk): + msa_residues[first_start:first_end]["res_type"] = residues[ + "res_type" + ] + else: + print( + warning, + "1", + residues["res_type"], + first_residues["res_type"], + data.record.id, + ) + msa[chain_id] = dummy_msa(residues) + else: + print( + warning, + "2", + residues["res_type"], + first_residues["res_type"], + data.record.id, + ) + msa[chain_id] = dummy_msa(residues) + else: + msa[chain_id] = dummy_msa(residues) + + # Map taxonomies to (chain_id, seq_idx) + taxonomy_map: Dict[str, List] = {} + for chain_id, chain_msa in msa.items(): + sequences = chain_msa.sequences + sequences = sequences[sequences["taxonomy"] != -1] + for sequence in sequences: + seq_idx = sequence["seq_idx"] + taxon = sequence["taxonomy"] + taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx)) + + # Remove taxonomies with only one sequence and sort by the + # number of chain_id present in each of the taxonomies + taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1} + taxonomy_map = sorted( + taxonomy_map.items(), + key=lambda x: len({c for c, _ in x[1]}), + reverse=True, + ) + + # Keep track of the sequences available per chain, keeping the original + # order of the sequences in the MSA to favor the best matching sequences + visited = {(c, s) for c, items in taxonomy_map for s in items} + available = {} + for c in chain_ids: + available[c] = [ + i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited + ] + + # Create sequence pairs + is_paired = [] + pairing = [] + + # Start with the first sequence for each chain + is_paired.append({c: 1 for c in chain_ids}) + pairing.append({c: 0 for c in chain_ids}) + + # Then add up to 8191 paired rows + for _, pairs in taxonomy_map: + # Group occurences by chain_id in case we have multiple + # sequences from the same chain and same taxonomy + chain_occurences = {} + for chain_id, seq_idx in pairs: + chain_occurences.setdefault(chain_id, []).append(seq_idx) + + # We create as many pairings as the maximum number of occurences + max_occurences = max(len(v) for v in chain_occurences.values()) + for i in range(max_occurences): + row_pairing = {} + row_is_paired = {} + + # Add the chains present in the taxonomy + for chain_id, seq_idxs in chain_occurences.items(): + # Roll over the sequence index to maximize diversity + idx = i % len(seq_idxs) + seq_idx = seq_idxs[idx] + + # Add the sequence to the pairing + row_pairing[chain_id] = seq_idx + row_is_paired[chain_id] = 1 + + # Add any missing chains + for chain_id in chain_ids: + if chain_id not in row_pairing: + row_is_paired[chain_id] = 0 + if available[chain_id]: + # Add the next available sequence + seq_idx = available[chain_id].pop(0) + row_pairing[chain_id] = seq_idx + else: + # No more sequences available, we place a gap + row_pairing[chain_id] = -1 + + pairing.append(row_pairing) + is_paired.append(row_is_paired) + + # Break if we have enough pairs + if len(pairing) >= max_pairs: + break + + # Break if we have enough pairs + if len(pairing) >= max_pairs: + break + + # Now add up to 16384 unpaired rows total + max_left = max(len(v) for v in available.values()) + for _ in range(min(max_total - len(pairing), max_left)): + row_pairing = {} + row_is_paired = {} + for chain_id in chain_ids: + row_is_paired[chain_id] = 0 + if available[chain_id]: + # Add the next available sequence + seq_idx = available[chain_id].pop(0) + row_pairing[chain_id] = seq_idx + else: + # No more sequences available, we place a gap + row_pairing[chain_id] = -1 + + pairing.append(row_pairing) + is_paired.append(row_is_paired) + + # Break if we have enough sequences + if len(pairing) >= max_total: + break + + # Randomly sample a subset of the pairs + # ensuring the first row is always present + if random_subset: + num_seqs = len(pairing) + if num_seqs > max_seqs: + indices = random.choice( + np.arange(1, num_seqs), size=max_seqs - 1, replace=False + ) # noqa: NPY002 + pairing = [pairing[0]] + [pairing[i] for i in indices] + is_paired = [is_paired[0]] + [is_paired[i] for i in indices] + else: + # Deterministic downsample to max_seqs + pairing = pairing[:max_seqs] + is_paired = is_paired[:max_seqs] + + # Map (chain_id, seq_idx, res_idx) to deletion + deletions = {} + for chain_id, chain_msa in msa.items(): + chain_deletions = chain_msa.deletions + for sequence in chain_msa.sequences: + del_start = sequence["del_start"] + del_end = sequence["del_end"] + chain_deletions = chain_msa.deletions[del_start:del_end] + for deletion_data in chain_deletions: + seq_idx = sequence["seq_idx"] + res_idx = deletion_data["res_idx"] + deletion = deletion_data["deletion"] + deletions[(chain_id, seq_idx, res_idx)] = deletion + + # Add all the token MSA data + msa_data, del_data, paired_data = prepare_msa_arrays( + data.tokens, pairing, is_paired, deletions, msa + ) + + msa_data = torch.tensor(msa_data, dtype=torch.long) + del_data = torch.tensor(del_data, dtype=torch.float) + paired_data = torch.tensor(paired_data, dtype=torch.float) + + return msa_data, del_data, paired_data + + +def prepare_msa_arrays( + tokens, + pairing: list[dict[int, int]], + is_paired: list[dict[int, int]], + deletions: dict[tuple[int, int, int], int], + msa: dict[int, MSA], +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]: + """Reshape data to play nicely with numba jit.""" + token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64) + token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64) + + chain_ids = sorted(msa.keys()) + + # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25). + # This allows us to look up a chain_id by it's index in the chain_ids list. + chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)} + token_asym_ids_idx_arr = np.array( + [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64 + ) + + pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64) + is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64) + + for i, row_pairing in enumerate(pairing): + for chain_id in chain_ids: + pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id] + + for i, row_is_paired in enumerate(is_paired): + for chain_id in chain_ids: + is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id] + + max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids) + + # we want res_start from sequences + msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64) + for chain_id in chain_ids: + for i, seq in enumerate(msa[chain_id].sequences): + msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"] + + max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids) + msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64) + for chain_id in chain_ids: + residues = msa[chain_id].residues.astype(np.int64) + idxs = np.arange(len(residues)) + chain_idx = chain_id_to_idx[chain_id] + msa_residues[chain_idx, idxs] = residues + + deletions_dict = numba.typed.Dict.empty( + key_type=numba.types.Tuple( + [numba.types.int64, numba.types.int64, numba.types.int64] + ), + value_type=numba.types.int64, + ) + deletions_dict.update(deletions) + + return _prepare_msa_arrays_inner( + token_asym_ids_arr, + token_res_idxs_arr, + token_asym_ids_idx_arr, + pairing_arr, + is_paired_arr, + deletions_dict, + msa_sequences, + msa_residues, + const.token_ids["-"], + ) + + +deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64) + + +@numba.njit( + [ + types.Tuple( + ( + types.int64[:, ::1], # msa_data + types.int64[:, ::1], # del_data + types.int64[:, ::1], # paired_data + ) + )( + types.int64[::1], # token_asym_ids + types.int64[::1], # token_res_idxs + types.int64[::1], # token_asym_ids_idx + types.int64[:, ::1], # pairing + types.int64[:, ::1], # is_paired + deletions_dict_type, # deletions + types.int64[:, ::1], # msa_sequences + types.int64[:, ::1], # msa_residues + types.int64, # gap_token + ) + ], + cache=True, +) +def _prepare_msa_arrays_inner( + token_asym_ids: npt.NDArray[np.int64], + token_res_idxs: npt.NDArray[np.int64], + token_asym_ids_idx: npt.NDArray[np.int64], + pairing: npt.NDArray[np.int64], + is_paired: npt.NDArray[np.int64], + deletions: dict[tuple[int, int, int], int], + msa_sequences: npt.NDArray[np.int64], + msa_residues: npt.NDArray[np.int64], + gap_token: int, +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]: + n_tokens = len(token_asym_ids) + n_pairs = len(pairing) + msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64) + paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64) + del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64) + + # Add all the token MSA data + for token_idx in range(n_tokens): + chain_id_idx = token_asym_ids_idx[token_idx] + chain_id = token_asym_ids[token_idx] + res_idx = token_res_idxs[token_idx] + + for pair_idx in range(n_pairs): + seq_idx = pairing[pair_idx, chain_id_idx] + paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx] + + # Add residue type + if seq_idx != -1: + res_start = msa_sequences[chain_id_idx, seq_idx] + res_type = msa_residues[chain_id_idx, res_start + res_idx] + k = (chain_id, seq_idx, res_idx) + if k in deletions: + del_data[token_idx, pair_idx] = deletions[k] + msa_data[token_idx, pair_idx] = res_type + + return msa_data, del_data, paired_data + + +#################################################################################################### +# FEATURES +#################################################################################################### + + +def select_subset_from_mask(mask, p, random: np.random.Generator) -> np.ndarray: + num_true = np.sum(mask) + v = random.geometric(p) + 1 + k = min(v, num_true) + + true_indices = np.where(mask)[0] + + # Randomly select k indices from the true_indices + selected_indices = random.choice(true_indices, size=k, replace=False) + + new_mask = np.zeros_like(mask) + new_mask[selected_indices] = 1 + + return new_mask + + +def get_range_bin(value: float, range_dict: Dict[Tuple[float, float], int], default=0): + """Get the bin of a value given a range dictionary.""" + value = float(value) + for k, idx in range_dict.items(): + if k == "other": + continue + low, high = k + if low <= value < high: + return idx + return default + + +def process_token_features( # noqa: C901, PLR0915, PLR0912 + data: Input, + random: np.random.Generator, + max_tokens: Optional[int] = None, + binder_pocket_conditioned_prop: Optional[float] = 0.0, + contact_conditioned_prop: Optional[float] = 0.0, + binder_pocket_cutoff_min: Optional[float] = 4.0, + binder_pocket_cutoff_max: Optional[float] = 20.0, + binder_pocket_sampling_geometric_p: Optional[float] = 0.0, + only_ligand_binder_pocket: Optional[bool] = False, + only_pp_contact: Optional[bool] = False, + maximum_bond_distance: Optional[int] = 0, + override_method: Optional[str] = None, +) -> Dict[str, Tensor]: + """Get the token features. + + Parameters + ---------- + data : Input + The input data to the model. + max_tokens : int + The maximum number of tokens. + + Returns + ------- + Dict[str, Tensor] + The token features. + + """ + # Token data + token_data = data.tokens + token_bonds = data.bonds + + # Token core features + token_index = torch.arange(len(token_data), dtype=torch.long) + residue_index = from_numpy(token_data["res_idx"]).long() + asym_id = from_numpy(token_data["asym_id"]).long() + entity_id = from_numpy(token_data["entity_id"]).long() + sym_id = from_numpy(token_data["sym_id"]).long() + mol_type = from_numpy(token_data["mol_type"]).long() + res_type = from_numpy(token_data["res_type"]).long() + res_type = one_hot(res_type, num_classes=const.num_tokens) + disto_center = from_numpy(token_data["disto_coords"]) + modified = from_numpy(token_data["modified"]).long() # float() + + ## Conditioning features ## + method = ( + np.zeros(len(token_data)) + + const.method_types_ids[ + ("other" if override_method is None else override_method.lower()) + ] + ) + default_T = const.temperature_bins_ids["other"] + default_pH = const.ph_bins_ids["other"] + temp_feature = np.zeros(len(token_data)) + default_T + ph_feature = np.zeros(len(token_data)) + default_pH + if data.record is not None: + if ( + override_method is None + and data.record.structure.method is not None + and data.record.structure.method.lower() in const.method_types_ids + ): + method = (method * 0) + const.method_types_ids[ + data.record.structure.method.lower() + ] + + if data.record.md is not None: + if data.record.md.temperature is not None: + T = data.record.md.temperature + temp_feature = (temp_feature * 0) + get_range_bin( + T, const.temperature_bins_ids, default=default_T + ) + if data.record.md.pH is not None: + pH = data.record.md.pH + ph_feature = (ph_feature * 0) + get_range_bin( + pH, const.ph_bins_ids, default=default_pH + ) + else: + if data.record.structure.temperature is not None: + T = data.record.structure.temperature + temp_feature = (temp_feature * 0) + get_range_bin( + T, const.temperature_bins_ids, default=default_T + ) + if data.record.structure.pH is not None: + pH = data.record.structure.pH + ph_feature = (ph_feature * 0) + get_range_bin( + pH, const.ph_bins_ids, default=default_pH + ) + + method_feature = from_numpy(method).long() + temp_feature = from_numpy(temp_feature).long() + ph_feature = from_numpy(ph_feature).long() + + # Token mask features + pad_mask = torch.ones(len(token_data), dtype=torch.float) + resolved_mask = from_numpy(token_data["resolved_mask"]).float() + disto_mask = from_numpy(token_data["disto_mask"]).float() + + # Token bond features + if max_tokens is not None: + pad_len = max_tokens - len(token_data) + num_tokens = max_tokens if pad_len > 0 else len(token_data) + else: + num_tokens = len(token_data) + + tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)} + bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float) + bonds_type = torch.zeros(num_tokens, num_tokens, dtype=torch.long) + for token_bond in token_bonds: + token_1 = tok_to_idx[token_bond["token_1"]] + token_2 = tok_to_idx[token_bond["token_2"]] + bonds[token_1, token_2] = 1 + bonds[token_2, token_1] = 1 + bond_type = token_bond["type"] + bonds_type[token_1, token_2] = bond_type + bonds_type[token_2, token_1] = bond_type + + if maximum_bond_distance > 1: + G = nx.from_numpy_array(bonds.numpy()) + shortest_path = nx.floyd_warshall_numpy(G) + shortest_path = np.where( + shortest_path > shortest_path.shape[0], + maximum_bond_distance + 1, + np.minimum(shortest_path, maximum_bond_distance), + ) + bonds = one_hot( + torch.from_numpy(shortest_path).long(), + num_classes=maximum_bond_distance + 2, + ) + else: + bonds = bonds.unsqueeze(-1) + + # Pocket conditioned feature + contact_conditioning = ( + np.zeros((len(token_data), len(token_data))) + + const.contact_conditioning_info["UNSELECTED"] + ) + contact_threshold = np.zeros((len(token_data), len(token_data))) + + if binder_pocket_conditioned_prop > 0.0: + # choose as binder a random ligand in the crop, if there are no ligands select a protein chain + binder_asym_ids = np.unique( + token_data["asym_id"][ + token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"] + ] + ) + + if len(binder_asym_ids) == 0: + if not only_ligand_binder_pocket: + binder_asym_ids = np.unique(token_data["asym_id"]) + + while random.random() < binder_pocket_conditioned_prop: + if len(binder_asym_ids) == 0: + break + + pocket_asym_id = random.choice(binder_asym_ids) + binder_asym_ids = binder_asym_ids[binder_asym_ids != pocket_asym_id] + + binder_pocket_cutoff = sample_d( + min_d=binder_pocket_cutoff_min, + max_d=binder_pocket_cutoff_max, + n_samples=1, + random=random, + ) + + binder_mask = token_data["asym_id"] == pocket_asym_id + + binder_coords = [] + for token in token_data: + if token["asym_id"] == pocket_asym_id: + _coords = data.structure.atoms["coords"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + _is_present = data.structure.atoms["is_present"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + binder_coords.append(_coords[_is_present]) + binder_coords = np.concatenate(binder_coords, axis=0) + + # find the tokens in the pocket + token_dist = np.zeros(len(token_data)) + 1000 + for i, token in enumerate(token_data): + if ( + token["mol_type"] != const.chain_type_ids["NONPOLYMER"] + and token["asym_id"] != pocket_asym_id + and token["resolved_mask"] == 1 + ): + token_coords = data.structure.atoms["coords"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + token_is_present = data.structure.atoms["is_present"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + token_coords = token_coords[token_is_present] + + # find chain and apply chain transformation + for chain in data.structure.chains: + if chain["asym_id"] == token["asym_id"]: + break + + token_dist[i] = np.min( + np.linalg.norm( + token_coords[:, None, :] - binder_coords[None, :, :], + axis=-1, + ) + ) + + pocket_mask = token_dist < binder_pocket_cutoff + + if np.sum(pocket_mask) > 0: + if binder_pocket_sampling_geometric_p > 0.0: + # select a subset of the pocket, according + # to a geometric distribution with one as minimum + pocket_mask = select_subset_from_mask( + pocket_mask, + binder_pocket_sampling_geometric_p, + random, + ) + + contact_conditioning[np.ix_(binder_mask, pocket_mask)] = ( + const.contact_conditioning_info["BINDER>POCKET"] + ) + contact_conditioning[np.ix_(pocket_mask, binder_mask)] = ( + const.contact_conditioning_info["POCKET>BINDER"] + ) + contact_threshold[np.ix_(binder_mask, pocket_mask)] = ( + binder_pocket_cutoff + ) + contact_threshold[np.ix_(pocket_mask, binder_mask)] = ( + binder_pocket_cutoff + ) + + # Contact conditioning feature + if contact_conditioned_prop > 0.0: + while random.random() < contact_conditioned_prop: + contact_cutoff = sample_d( + min_d=binder_pocket_cutoff_min, + max_d=binder_pocket_cutoff_max, + n_samples=1, + random=random, + ) + if only_pp_contact: + chain_asym_ids = np.unique( + token_data["asym_id"][ + token_data["mol_type"] == const.chain_type_ids["PROTEIN"] + ] + ) + else: + chain_asym_ids = np.unique(token_data["asym_id"]) + + if len(chain_asym_ids) > 1: + chain_asym_id = random.choice(chain_asym_ids) + + chain_coords = [] + for token in token_data: + if token["asym_id"] == chain_asym_id: + _coords = data.structure.atoms["coords"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + _is_present = data.structure.atoms["is_present"][ + token["atom_idx"] : token["atom_idx"] + token["atom_num"] + ] + chain_coords.append(_coords[_is_present]) + chain_coords = np.concatenate(chain_coords, axis=0) + + # find contacts in other chains + possible_other_chains = [] + for other_chain_id in chain_asym_ids[chain_asym_ids != chain_asym_id]: + for token in token_data: + if token["asym_id"] == other_chain_id: + _coords = data.structure.atoms["coords"][ + token["atom_idx"] : token["atom_idx"] + + token["atom_num"] + ] + _is_present = data.structure.atoms["is_present"][ + token["atom_idx"] : token["atom_idx"] + + token["atom_num"] + ] + if _is_present.sum() == 0: + continue + token_coords = _coords[_is_present] + + # check minimum distance + if ( + np.min(cdist(chain_coords, token_coords)) + < contact_cutoff + ): + possible_other_chains.append(other_chain_id) + break + + if len(possible_other_chains) > 0: + other_chain_id = random.choice(possible_other_chains) + + pairs = [] + for token_1 in token_data: + if token_1["asym_id"] == chain_asym_id: + _coords = data.structure.atoms["coords"][ + token_1["atom_idx"] : token_1["atom_idx"] + + token_1["atom_num"] + ] + _is_present = data.structure.atoms["is_present"][ + token_1["atom_idx"] : token_1["atom_idx"] + + token_1["atom_num"] + ] + if _is_present.sum() == 0: + continue + token_1_coords = _coords[_is_present] + + for token_2 in token_data: + if token_2["asym_id"] == other_chain_id: + _coords = data.structure.atoms["coords"][ + token_2["atom_idx"] : token_2["atom_idx"] + + token_2["atom_num"] + ] + _is_present = data.structure.atoms["is_present"][ + token_2["atom_idx"] : token_2["atom_idx"] + + token_2["atom_num"] + ] + if _is_present.sum() == 0: + continue + token_2_coords = _coords[_is_present] + + if ( + np.min(cdist(token_1_coords, token_2_coords)) + < contact_cutoff + ): + pairs.append( + (token_1["token_idx"], token_2["token_idx"]) + ) + + assert len(pairs) > 0 + + pair = random.choice(pairs) + token_1_mask = token_data["token_idx"] == pair[0] + token_2_mask = token_data["token_idx"] == pair[1] + + contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = ( + const.contact_conditioning_info["CONTACT"] + ) + contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = ( + const.contact_conditioning_info["CONTACT"] + ) + + elif not only_pp_contact: + # only one chain, find contacts within the chain with minimum residue distance + pairs = [] + for token_1 in token_data: + _coords = data.structure.atoms["coords"][ + token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"] + ] + _is_present = data.structure.atoms["is_present"][ + token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"] + ] + if _is_present.sum() == 0: + continue + token_1_coords = _coords[_is_present] + + for token_2 in token_data: + if np.abs(token_1["res_idx"] - token_2["res_idx"]) <= 8: + continue + + _coords = data.structure.atoms["coords"][ + token_2["atom_idx"] : token_2["atom_idx"] + + token_2["atom_num"] + ] + _is_present = data.structure.atoms["is_present"][ + token_2["atom_idx"] : token_2["atom_idx"] + + token_2["atom_num"] + ] + if _is_present.sum() == 0: + continue + token_2_coords = _coords[_is_present] + + if ( + np.min(cdist(token_1_coords, token_2_coords)) + < contact_cutoff + ): + pairs.append((token_1["token_idx"], token_2["token_idx"])) + + if len(pairs) > 0: + pair = random.choice(pairs) + token_1_mask = token_data["token_idx"] == pair[0] + token_2_mask = token_data["token_idx"] == pair[1] + + contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = ( + const.contact_conditioning_info["CONTACT"] + ) + contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = ( + const.contact_conditioning_info["CONTACT"] + ) + + if np.all(contact_conditioning == const.contact_conditioning_info["UNSELECTED"]): + contact_conditioning = ( + contact_conditioning + - const.contact_conditioning_info["UNSELECTED"] + + const.contact_conditioning_info["UNSPECIFIED"] + ) + contact_conditioning = from_numpy(contact_conditioning).long() + contact_conditioning = one_hot( + contact_conditioning, num_classes=len(const.contact_conditioning_info) + ) + contact_threshold = from_numpy(contact_threshold).float() + + # compute cyclic polymer mask + cyclic_ids = {} + for idx_chain, asym_id_iter in enumerate(data.structure.chains["asym_id"]): + for connection in data.structure.bonds: + if ( + idx_chain == connection["chain_1"] == connection["chain_2"] + and data.structure.chains[connection["chain_1"]]["res_num"] > 2 + and connection["res_1"] + != connection["res_2"] # Avoid same residue bonds! + ): + if ( + data.structure.chains[connection["chain_1"]]["res_num"] + == (connection["res_2"] + 1) + and connection["res_1"] == 0 + ) or ( + data.structure.chains[connection["chain_1"]]["res_num"] + == (connection["res_1"] + 1) + and connection["res_2"] == 0 + ): + cyclic_ids[asym_id_iter] = data.structure.chains[ + connection["chain_1"] + ]["res_num"] + cyclic = from_numpy( + np.array( + [ + (cyclic_ids[asym_id_iter] if asym_id_iter in cyclic_ids else 0) + for asym_id_iter in token_data["asym_id"] + ] + ) + ).float() + + # Pad to max tokens if given + if max_tokens is not None: + pad_len = max_tokens - len(token_data) + if pad_len > 0: + token_index = pad_dim(token_index, 0, pad_len) + residue_index = pad_dim(residue_index, 0, pad_len) + asym_id = pad_dim(asym_id, 0, pad_len) + entity_id = pad_dim(entity_id, 0, pad_len) + sym_id = pad_dim(sym_id, 0, pad_len) + mol_type = pad_dim(mol_type, 0, pad_len) + res_type = pad_dim(res_type, 0, pad_len) + disto_center = pad_dim(disto_center, 0, pad_len) + pad_mask = pad_dim(pad_mask, 0, pad_len) + resolved_mask = pad_dim(resolved_mask, 0, pad_len) + disto_mask = pad_dim(disto_mask, 0, pad_len) + contact_conditioning = pad_dim(contact_conditioning, 0, pad_len) + contact_conditioning = pad_dim(contact_conditioning, 1, pad_len) + contact_threshold = pad_dim(contact_threshold, 0, pad_len) + contact_threshold = pad_dim(contact_threshold, 1, pad_len) + method_feature = pad_dim(method_feature, 0, pad_len) + temp_feature = pad_dim(temp_feature, 0, pad_len) + ph_feature = pad_dim(ph_feature, 0, pad_len) + modified = pad_dim(modified, 0, pad_len) + cyclic = pad_dim(cyclic, 0, pad_len) + + token_features = { + "token_index": token_index, + "residue_index": residue_index, + "asym_id": asym_id, + "entity_id": entity_id, + "sym_id": sym_id, + "mol_type": mol_type, + "res_type": res_type, + "disto_center": disto_center, + "token_bonds": bonds, + "type_bonds": bonds_type, + "token_pad_mask": pad_mask, + "token_resolved_mask": resolved_mask, + "token_disto_mask": disto_mask, + "contact_conditioning": contact_conditioning, + "contact_threshold": contact_threshold, + "method_feature": method_feature, + "temp_feature": temp_feature, + "ph_feature": ph_feature, + "modified": modified, + "cyclic_period": cyclic, + } + + return token_features + + +def process_atom_features( + data: Input, + random: np.random.Generator, + ensemble_features: Dict, + molecules: Dict[str, Mol], + atoms_per_window_queries: int = 32, + min_dist: float = 2.0, + max_dist: float = 22.0, + num_bins: int = 64, + max_atoms: Optional[int] = None, + max_tokens: Optional[int] = None, + disto_use_ensemble: Optional[bool] = False, + override_bfactor: bool = False, + compute_frames: bool = False, + override_coords: Optional[Tensor] = None, + bfactor_md_correction: bool = False, +) -> Dict[str, Tensor]: + """Get the atom features. + + Parameters + ---------- + data : Input + The input to the model. + max_atoms : int, optional + The maximum number of atoms. + + Returns + ------- + Dict[str, Tensor] + The atom features. + + """ + # Filter to tokens' atoms + atom_data = [] + atom_name = [] + atom_element = [] + atom_charge = [] + atom_conformer = [] + atom_chirality = [] + ref_space_uid = [] + coord_data = [] + if compute_frames: + frame_data = [] + resolved_frame_data = [] + atom_to_token = [] + token_to_rep_atom = [] # index on cropped atom table + r_set_to_rep_atom = [] + disto_coords_ensemble = [] + backbone_feat_index = [] + + e_offsets = data.structure.ensemble["atom_coord_idx"] + atom_idx = 0 + + # Start atom idx in full atom table for structures chosen. Up to num_ensembles points. + ensemble_atom_starts = [ + data.structure.ensemble[idx]["atom_coord_idx"] + for idx in ensemble_features["ensemble_ref_idxs"] + ] + + # Set unk chirality id + unk_chirality = const.chirality_type_ids[const.unk_chirality_type] + + chain_res_ids = {} + res_index_to_conf_id = {} + for token_id, token in enumerate(data.tokens): + # Get the chain residue ids + chain_idx, res_id = token["asym_id"], token["res_idx"] + chain = data.structure.chains[chain_idx] + + if (chain_idx, res_id) not in chain_res_ids: + new_idx = len(chain_res_ids) + chain_res_ids[(chain_idx, res_id)] = new_idx + else: + new_idx = chain_res_ids[(chain_idx, res_id)] + + # Get the molecule and conformer + mol = molecules[token["res_name"]] + atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()} + + # Sample a random conformer + if (chain_idx, res_id) not in res_index_to_conf_id: + conf_ids = [int(conf.GetId()) for conf in mol.GetConformers()] + conf_id = int(random.choice(conf_ids)) + res_index_to_conf_id[(chain_idx, res_id)] = conf_id + + conf_id = res_index_to_conf_id[(chain_idx, res_id)] + conformer = mol.GetConformer(conf_id) + + # Map atoms to token indices + ref_space_uid.extend([new_idx] * token["atom_num"]) + atom_to_token.extend([token_id] * token["atom_num"]) + + # Add atom data + start = token["atom_idx"] + end = token["atom_idx"] + token["atom_num"] + token_atoms = data.structure.atoms[start:end] + + # Add atom ref data + # element, charge, conformer, chirality + token_atom_name = np.array([convert_atom_name(a["name"]) for a in token_atoms]) + token_atoms_ref = np.array([atom_name_to_ref[a["name"]] for a in token_atoms]) + token_atoms_element = np.array([a.GetAtomicNum() for a in token_atoms_ref]) + token_atoms_charge = np.array([a.GetFormalCharge() for a in token_atoms_ref]) + token_atoms_conformer = np.array( + [ + ( + conformer.GetAtomPosition(a.GetIdx()).x, + conformer.GetAtomPosition(a.GetIdx()).y, + conformer.GetAtomPosition(a.GetIdx()).z, + ) + for a in token_atoms_ref + ] + ) + token_atoms_chirality = np.array( + [ + const.chirality_type_ids.get(a.GetChiralTag().name, unk_chirality) + for a in token_atoms_ref + ] + ) + + # Map token to representative atom + token_to_rep_atom.append(atom_idx + token["disto_idx"] - start) + if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[ + "resolved_mask" + ]: + r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start) + + if chain["mol_type"] == const.chain_type_ids["PROTEIN"]: + backbone_index = [ + ( + const.protein_backbone_atom_index[atom_name] + 1 + if atom_name in const.protein_backbone_atom_index + else 0 + ) + for atom_name in token_atoms["name"] + ] + elif ( + chain["mol_type"] == const.chain_type_ids["DNA"] + or chain["mol_type"] == const.chain_type_ids["RNA"] + ): + backbone_index = [ + ( + const.nucleic_backbone_atom_index[atom_name] + + 1 + + len(const.protein_backbone_atom_index) + if atom_name in const.nucleic_backbone_atom_index + else 0 + ) + for atom_name in token_atoms["name"] + ] + else: + backbone_index = [0] * token["atom_num"] + backbone_feat_index.extend(backbone_index) + + # Get token coordinates across sampled ensembles and apply transforms + token_coords = np.array( + [ + data.structure.coords[ + ensemble_atom_start + start : ensemble_atom_start + end + ]["coords"] + for ensemble_atom_start in ensemble_atom_starts + ] + ) + coord_data.append(token_coords) + + if compute_frames: + # Get frame data + res_type = const.tokens[token["res_type"]] + res_name = str(token["res_name"]) + + if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]: + idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0 + mask_frame = False + elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and ( + res_name in const.ref_atoms + ): + idx_frame_a, idx_frame_b, idx_frame_c = ( + const.ref_atoms[res_name].index("N"), + const.ref_atoms[res_name].index("CA"), + const.ref_atoms[res_name].index("C"), + ) + mask_frame = ( + token_atoms["is_present"][idx_frame_a] + and token_atoms["is_present"][idx_frame_b] + and token_atoms["is_present"][idx_frame_c] + ) + elif ( + token["mol_type"] == const.chain_type_ids["DNA"] + or token["mol_type"] == const.chain_type_ids["RNA"] + ) and (res_name in const.ref_atoms): + idx_frame_a, idx_frame_b, idx_frame_c = ( + const.ref_atoms[res_name].index("C1'"), + const.ref_atoms[res_name].index("C3'"), + const.ref_atoms[res_name].index("C4'"), + ) + mask_frame = ( + token_atoms["is_present"][idx_frame_a] + and token_atoms["is_present"][idx_frame_b] + and token_atoms["is_present"][idx_frame_c] + ) + elif token["mol_type"] == const.chain_type_ids["PROTEIN"]: + # Try to look for the atom nams in the modified residue + is_ca = token_atoms["name"] == "CA" + idx_frame_a = is_ca.argmax() + ca_present = ( + token_atoms[idx_frame_a]["is_present"] if is_ca.any() else False + ) + + is_n = token_atoms["name"] == "N" + idx_frame_b = is_n.argmax() + n_present = ( + token_atoms[idx_frame_b]["is_present"] if is_n.any() else False + ) + + is_c = token_atoms["name"] == "C" + idx_frame_c = is_c.argmax() + c_present = ( + token_atoms[idx_frame_c]["is_present"] if is_c.any() else False + ) + mask_frame = ca_present and n_present and c_present + + elif (token["mol_type"] == const.chain_type_ids["DNA"]) or ( + token["mol_type"] == const.chain_type_ids["RNA"] + ): + # Try to look for the atom nams in the modified residue + is_c1 = token_atoms["name"] == "C1'" + idx_frame_a = is_c1.argmax() + c1_present = ( + token_atoms[idx_frame_a]["is_present"] if is_c1.any() else False + ) + + is_c3 = token_atoms["name"] == "C3'" + idx_frame_b = is_c3.argmax() + c3_present = ( + token_atoms[idx_frame_b]["is_present"] if is_c3.any() else False + ) + + is_c4 = token_atoms["name"] == "C4'" + idx_frame_c = is_c4.argmax() + c4_present = ( + token_atoms[idx_frame_c]["is_present"] if is_c4.any() else False + ) + mask_frame = c1_present and c3_present and c4_present + else: + idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0 + mask_frame = False + frame_data.append( + [ + idx_frame_a + atom_idx, + idx_frame_b + atom_idx, + idx_frame_c + atom_idx, + ] + ) + resolved_frame_data.append(mask_frame) + + # Get distogram coordinates + disto_coords_ensemble_tok = data.structure.coords[ + e_offsets + token["disto_idx"] + ]["coords"] + disto_coords_ensemble.append(disto_coords_ensemble_tok) + + # Update atom data. This is technically never used again (we rely on coord_data), + # but we update for consistency and to make sure the Atom object has valid, transformed coordinates. + token_atoms = token_atoms.copy() + token_atoms["coords"] = token_coords[ + 0 + ] # atom has a copy of first coords in ensemble + atom_data.append(token_atoms) + atom_name.append(token_atom_name) + atom_element.append(token_atoms_element) + atom_charge.append(token_atoms_charge) + atom_conformer.append(token_atoms_conformer) + atom_chirality.append(token_atoms_chirality) + atom_idx += len(token_atoms) + + disto_coords_ensemble = np.array(disto_coords_ensemble) # (N_TOK, N_ENS, 3) + + # Compute ensemble distogram + L = len(data.tokens) + + if disto_use_ensemble: + # Use all available structures to create distogram + idx_list = range(disto_coords_ensemble.shape[1]) + else: + # Only use a sampled structures to create distogram + idx_list = ensemble_features["ensemble_ref_idxs"] + + # Save a numpy array of the distogram to a path + # pdb_id = data.record.id + # with open(f"/afs/csail.mit.edu/u/m/mreveiz/rbg/temp_while_cp_rsg/temp/disto_outs_atlas10ns/{pdb_id}_disto_coords_ensemble.npy", "wb") as f: + # np.save(f, disto_coords_ensemble) + + # Create distogram + disto_target = torch.zeros(L, L, len(idx_list), num_bins) # TODO1 + + # disto_target = torch.zeros(L, L, num_bins) + for i, e_idx in enumerate(idx_list): + t_center = torch.Tensor(disto_coords_ensemble[:, e_idx, :]) + t_dists = torch.cdist(t_center, t_center) + boundaries = torch.linspace(min_dist, max_dist, num_bins - 1) + distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long() + # disto_target += one_hot(distogram, num_classes=num_bins) + disto_target[:, :, i, :] = one_hot(distogram, num_classes=num_bins) # TODO1 + + # Normalize distogram + # disto_target = disto_target / disto_target.sum(-1)[..., None] # remove TODO1 + atom_data = np.concatenate(atom_data) + atom_name = np.concatenate(atom_name) + atom_element = np.concatenate(atom_element) + atom_charge = np.concatenate(atom_charge) + atom_conformer = np.concatenate(atom_conformer) + atom_chirality = np.concatenate(atom_chirality) + coord_data = np.concatenate(coord_data, axis=1) + ref_space_uid = np.array(ref_space_uid) + + # Compute features + disto_coords_ensemble = from_numpy(disto_coords_ensemble) + disto_coords_ensemble = disto_coords_ensemble[ + :, ensemble_features["ensemble_ref_idxs"] + ].permute(1, 0, 2) + backbone_feat_index = from_numpy(np.asarray(backbone_feat_index)).long() + ref_atom_name_chars = from_numpy(atom_name).long() + ref_element = from_numpy(atom_element).long() + ref_charge = from_numpy(atom_charge).float() + ref_pos = from_numpy(atom_conformer).float() + ref_space_uid = from_numpy(ref_space_uid) + ref_chirality = from_numpy(atom_chirality).long() + coords = from_numpy(coord_data.copy()) + resolved_mask = from_numpy(atom_data["is_present"]) + pad_mask = torch.ones(len(atom_data), dtype=torch.float) + atom_to_token = torch.tensor(atom_to_token, dtype=torch.long) + token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long) + r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long) + bfactor = from_numpy(atom_data["bfactor"].copy()) + plddt = from_numpy(atom_data["plddt"].copy()) + if override_bfactor: + bfactor = bfactor * 0.0 + + if bfactor_md_correction and data.record.structure.method.lower() == "md": + # MD bfactor was computed as RMSF + # Convert to b-factor + bfactor = 8 * (np.pi**2) * (bfactor**2) + + # We compute frames within ensemble + if compute_frames: + frames = [] + frame_resolved_mask = [] + for i in range(coord_data.shape[0]): + frame_data_, resolved_frame_data_ = compute_frames_nonpolymer( + data, + coord_data[i], + atom_data["is_present"], + atom_to_token, + frame_data, + resolved_frame_data, + ) # Compute frames for NONPOLYMER tokens + frames.append(frame_data_.copy()) + frame_resolved_mask.append(resolved_frame_data_.copy()) + frames = from_numpy(np.stack(frames)) # (N_ENS, N_TOK, 3) + frame_resolved_mask = from_numpy(np.stack(frame_resolved_mask)) + + # Convert to one-hot + backbone_feat_index = one_hot( + backbone_feat_index, + num_classes=1 + + len(const.protein_backbone_atom_index) + + len(const.nucleic_backbone_atom_index), + ) + ref_atom_name_chars = one_hot(ref_atom_name_chars, num_classes=64) + ref_element = one_hot(ref_element, num_classes=const.num_elements) + atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1) + token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data)) + r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data)) + + # Center the ground truth coordinates + center = (coords * resolved_mask[None, :, None]).sum(dim=1) + center = center / resolved_mask.sum().clamp(min=1) + coords = coords - center[:, None] + + if isinstance(override_coords, Tensor): + coords = override_coords.unsqueeze(0) + + # Apply random roto-translation to the input conformers + for i in range(torch.max(ref_space_uid)): + included = ref_space_uid == i + if torch.sum(included) > 0: + ref_pos[included] = center_random_augmentation( + ref_pos[included][None], + torch.ones_like(resolved_mask[included][None]), + centering=True, + )[0] + + # Compute padding and apply + if max_atoms is not None: + assert max_atoms % atoms_per_window_queries == 0 + pad_len = max_atoms - len(atom_data) + else: + pad_len = ( + (len(atom_data) - 1) // atoms_per_window_queries + 1 + ) * atoms_per_window_queries - len(atom_data) + + if pad_len > 0: + pad_mask = pad_dim(pad_mask, 0, pad_len) + ref_pos = pad_dim(ref_pos, 0, pad_len) + resolved_mask = pad_dim(resolved_mask, 0, pad_len) + ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len) + ref_element = pad_dim(ref_element, 0, pad_len) + ref_charge = pad_dim(ref_charge, 0, pad_len) + ref_chirality = pad_dim(ref_chirality, 0, pad_len) + backbone_feat_index = pad_dim(backbone_feat_index, 0, pad_len) + ref_space_uid = pad_dim(ref_space_uid, 0, pad_len) + coords = pad_dim(coords, 1, pad_len) + atom_to_token = pad_dim(atom_to_token, 0, pad_len) + token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len) + r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len) + bfactor = pad_dim(bfactor, 0, pad_len) + plddt = pad_dim(plddt, 0, pad_len) + + if max_tokens is not None: + pad_len = max_tokens - token_to_rep_atom.shape[0] + if pad_len > 0: + atom_to_token = pad_dim(atom_to_token, 1, pad_len) + token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len) + r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len) + disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len) + disto_coords_ensemble = pad_dim(disto_coords_ensemble, 1, pad_len) + + if compute_frames: + frames = pad_dim(frames, 1, pad_len) + frame_resolved_mask = pad_dim(frame_resolved_mask, 1, pad_len) + + atom_features = { + "ref_pos": ref_pos, + "atom_resolved_mask": resolved_mask, + "ref_atom_name_chars": ref_atom_name_chars, + "ref_element": ref_element, + "ref_charge": ref_charge, + "ref_chirality": ref_chirality, + "atom_backbone_feat": backbone_feat_index, + "ref_space_uid": ref_space_uid, + "coords": coords, + "atom_pad_mask": pad_mask, + "atom_to_token": atom_to_token, + "token_to_rep_atom": token_to_rep_atom, + "r_set_to_rep_atom": r_set_to_rep_atom, + "disto_target": disto_target, + "disto_coords_ensemble": disto_coords_ensemble, + "bfactor": bfactor, + "plddt": plddt, + } + + if compute_frames: + atom_features["frames_idx"] = frames + atom_features["frame_resolved_mask"] = frame_resolved_mask + + return atom_features + + +def process_msa_features( + data: Input, + random: np.random.Generator, + max_seqs_batch: int, + max_seqs: int, + max_tokens: Optional[int] = None, + pad_to_max_seqs: bool = False, + msa_sampling: bool = False, +) -> Dict[str, Tensor]: + """Get the MSA features. + + Parameters + ---------- + data : Input + The input to the model. + random : np.random.Generator + The random number generator. + max_seqs : int + The maximum number of MSA sequences. + max_tokens : int + The maximum number of tokens. + pad_to_max_seqs : bool + Whether to pad to the maximum number of sequences. + msa_sampling : bool + Whether to sample the MSA. + + Returns + ------- + Dict[str, Tensor] + The MSA features. + + """ + # Created paired MSA + msa, deletion, paired = construct_paired_msa( + data=data, + random=random, + max_seqs=max_seqs_batch, + random_subset=msa_sampling, + ) + msa, deletion, paired = ( + msa.transpose(1, 0), + deletion.transpose(1, 0), + paired.transpose(1, 0), + ) # (N_MSA, N_RES, N_AA) + + # Prepare features + assert torch.all(msa >= 0) and torch.all(msa < const.num_tokens) + msa_one_hot = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens) + msa_mask = torch.ones_like(msa) + profile = msa_one_hot.float().mean(dim=0) + has_deletion = deletion > 0 + deletion = np.pi / 2 * np.arctan(deletion / 3) + deletion_mean = deletion.mean(axis=0) + + # Pad in the MSA dimension (dim=0) + if pad_to_max_seqs: + pad_len = max_seqs - msa.shape[0] + if pad_len > 0: + msa = pad_dim(msa, 0, pad_len, const.token_ids["-"]) + paired = pad_dim(paired, 0, pad_len) + msa_mask = pad_dim(msa_mask, 0, pad_len) + has_deletion = pad_dim(has_deletion, 0, pad_len) + deletion = pad_dim(deletion, 0, pad_len) + + # Pad in the token dimension (dim=1) + if max_tokens is not None: + pad_len = max_tokens - msa.shape[1] + if pad_len > 0: + msa = pad_dim(msa, 1, pad_len, const.token_ids["-"]) + paired = pad_dim(paired, 1, pad_len) + msa_mask = pad_dim(msa_mask, 1, pad_len) + has_deletion = pad_dim(has_deletion, 1, pad_len) + deletion = pad_dim(deletion, 1, pad_len) + profile = pad_dim(profile, 0, pad_len) + deletion_mean = pad_dim(deletion_mean, 0, pad_len) + + return { + "msa": msa, + "msa_paired": paired, + "deletion_value": deletion, + "has_deletion": has_deletion, + "deletion_mean": deletion_mean, + "profile": profile, + "msa_mask": msa_mask, + } + + +def process_template_features( + data: Input, + max_tokens: int, + max_templates: int, + pad_to_max_templates: bool, +) -> Dict[str, Tensor]: + """Get the template features. + + Parameters + ---------- + data : Input + The input to the model. + max_tokens : int + The maximum number of tokens. + max_templates : int + The maximum number of templates. + pad_to_max_templates : bool + Whether to pad to the maximum number of templates. + + Returns + ------- + Dict[str, Tensor] + The template features. + + """ + # Get feature dimensions + tdim = ( + max([1] + [len(v) for v in data.templates.values()]) + if data.templates is not None + else 1 + ) + tdim = tdim if not pad_to_max_templates else max_templates + num_tokens = len(data.tokens) if max_tokens is None else max_tokens + + # Allocate features + res_type = np.zeros((tdim, num_tokens), dtype=np.int64) + frame_rot = np.zeros((tdim, num_tokens, 3, 3), dtype=np.float32) + frame_t = np.zeros((tdim, num_tokens, 3), dtype=np.float32) + cb_coords = np.zeros((tdim, num_tokens, 3), dtype=np.float32) + ca_coords = np.zeros((tdim, num_tokens, 3), dtype=np.float32) + frame_mask = np.zeros((tdim, num_tokens), dtype=np.float32) + cb_mask = np.zeros((tdim, num_tokens), dtype=np.float32) + template_mask = np.zeros((tdim, num_tokens), dtype=np.float32) + + # Now create features per token + if data.templates is not None: + for tok_idx, token in enumerate(data.tokens): + # Check if chain has templates + chain_id = int(token["asym_id"]) + if chain_id not in data.templates: + continue + + # Add per template features + for temp_idx, template in enumerate(data.templates[chain_id]): + t_data = template.coordinates + idx = np.where(t_data["res_idx"] == token["res_idx"])[0] + if len(idx) == 0: + continue + + # Add template features + res_type[temp_idx, tok_idx] = t_data["res_type"][idx[0]] + frame_rot[temp_idx, tok_idx] = t_data["frame_rot"][idx[0]].reshape(3, 3) + frame_t[temp_idx, tok_idx] = t_data["frame_t"][idx[0]] + cb_coords[temp_idx, tok_idx] = t_data["coords_cb"][idx[0]] + ca_coords[temp_idx, tok_idx] = t_data["coords_ca"][idx[0]] + cb_mask[temp_idx, tok_idx] = t_data["mask_cb"][idx[0]] + frame_mask[temp_idx, tok_idx] = t_data["mask_frame"][idx[0]] + template_mask[temp_idx, tok_idx] = 1 + + # Convert to one-hot + res_type = torch.from_numpy(res_type) + res_type = one_hot(res_type, num_classes=const.num_tokens) + + return { + "template_restype": res_type, + "template_frame_rot": torch.from_numpy(frame_rot), + "template_frame_t": torch.from_numpy(frame_t), + "template_cb": torch.from_numpy(cb_coords), + "template_ca": torch.from_numpy(ca_coords), + "template_mask_cb": torch.from_numpy(cb_mask), + "template_mask_frame": torch.from_numpy(frame_mask), + "template_mask": torch.from_numpy(template_mask), + } + + +def process_symmetry_features(cropped: Input, symmetries: Dict) -> Dict[str, Tensor]: + """Get the symmetry features. + + Parameters + ---------- + data : Input + The input to the model. + + Returns + ------- + Dict[str, Tensor] + The symmetry features. + + """ + # TODO this does not work with multiple conformers + features = get_chain_symmetries(cropped) + features.update(get_amino_acids_symmetries(cropped)) + features.update(get_ligand_symmetries(cropped, symmetries)) + + return features + + +def process_ensemble_features( + data: Input, + random: np.random.Generator, + num_ensembles: int, + ensemble_sample_replacement: bool, + fix_single_ensemble: bool, +) -> Dict[str, Tensor]: + """Get the ensemble features. + + Parameters + ---------- + data : Input + The input to the model. + random : np.random.Generator + The random number generator. + num_ensembles : int + The maximum number of ensembles to sample. + ensemble_sample_replacement : bool + Whether to sample with replacement. + + Returns + ------- + Dict[str, Tensor] + The ensemble features. + + """ + assert num_ensembles > 0, "Number of conformers sampled must be greater than 0." + + # Number of available conformers in the structure + # s_ensemble_num = min(len(cropped.structure.ensemble), 24) # Limit to 24 conformers DEBUG: TODO: remove ! + s_ensemble_num = len(data.structure.ensemble) + + if fix_single_ensemble: + # Always take the first conformer for train and validation + assert num_ensembles == 1, ( + "Number of conformers sampled must be 1 with fix_single_ensemble=True." + ) + ensemble_ref_idxs = np.array([0]) + else: + if ensemble_sample_replacement: + # Used in training + ensemble_ref_idxs = random.integers(0, s_ensemble_num, (num_ensembles,)) + else: + # Used in validation + if s_ensemble_num < num_ensembles: + # Take all available conformers + ensemble_ref_idxs = np.arange(0, s_ensemble_num) + else: + # Sample without replacement + ensemble_ref_idxs = random.choice( + s_ensemble_num, num_ensembles, replace=False + ) + + ensemble_features = { + "ensemble_ref_idxs": torch.Tensor(ensemble_ref_idxs).long(), + } + + return ensemble_features + + +class Boltz2Featurizer: + """Boltz2 featurizer for model training.""" + + def process( + self, + data: Input, + random: np.random.Generator, + molecules: Dict[str, Mol], + training: bool, + max_seqs: int, + atoms_per_window_queries: int = 32, + min_dist: float = 2.0, + max_dist: float = 22.0, + num_bins: int = 64, + num_ensembles: int = 1, + ensemble_sample_replacement: bool = False, + disto_use_ensemble: Optional[bool] = False, + fix_single_ensemble: Optional[bool] = True, + max_tokens: Optional[int] = None, + max_atoms: Optional[int] = None, + pad_to_max_seqs: bool = False, + compute_symmetries: bool = False, + binder_pocket_conditioned_prop: Optional[float] = 0.0, + contact_conditioned_prop: Optional[float] = 0.0, + binder_pocket_cutoff_min: Optional[float] = 4.0, + binder_pocket_cutoff_max: Optional[float] = 20.0, + binder_pocket_sampling_geometric_p: Optional[float] = 0.0, + only_ligand_binder_pocket: Optional[bool] = False, + only_pp_contact: Optional[bool] = False, + maximum_bond_distance: Optional[int] = 0, + single_sequence_prop: Optional[float] = 0.0, + msa_sampling: bool = False, + use_templates: bool = False, + max_templates: int = 4, + pad_to_max_templates: bool = False, + override_bfactor: float = False, + override_method: Optional[str] = None, + compute_frames: bool = False, + override_coords: Optional[Tensor] = None, + bfactor_md_correction: bool = False, + ) -> Dict[str, Tensor]: + """Compute features. + + Parameters + ---------- + data : Input + The input to the model. + training : bool + Whether the model is in training mode. + max_tokens : int, optional + The maximum number of tokens. + max_atoms : int, optional + The maximum number of atoms + max_seqs : int, optional + The maximum number of sequences. + + Returns + ------- + Dict[str, Tensor] + The features for model training. + + """ + # Compute random number of sequences + if training and max_seqs is not None: + if random.random() > single_sequence_prop: + max_seqs_batch = random.integers(1, max_seqs + 1) + else: + max_seqs_batch = 1 + else: + max_seqs_batch = max_seqs + + # Compute ensemble features + ensemble_features = process_ensemble_features( + data=data, + random=random, + num_ensembles=num_ensembles, + ensemble_sample_replacement=ensemble_sample_replacement, + fix_single_ensemble=fix_single_ensemble, + ) + + # Compute token features + token_features = process_token_features( + data=data, + random=random, + max_tokens=max_tokens, + binder_pocket_conditioned_prop=binder_pocket_conditioned_prop, + contact_conditioned_prop=contact_conditioned_prop, + binder_pocket_cutoff_min=binder_pocket_cutoff_min, + binder_pocket_cutoff_max=binder_pocket_cutoff_max, + binder_pocket_sampling_geometric_p=binder_pocket_sampling_geometric_p, + only_ligand_binder_pocket=only_ligand_binder_pocket, + only_pp_contact=only_pp_contact, + maximum_bond_distance=maximum_bond_distance, + override_method=override_method, + ) + + # Compute atom features + atom_features = process_atom_features( + data=data, + random=random, + molecules=molecules, + ensemble_features=ensemble_features, + atoms_per_window_queries=atoms_per_window_queries, + min_dist=min_dist, + max_dist=max_dist, + num_bins=num_bins, + max_atoms=max_atoms, + max_tokens=max_tokens, + disto_use_ensemble=disto_use_ensemble, + override_bfactor=override_bfactor, + compute_frames=compute_frames, + override_coords=override_coords, + bfactor_md_correction=bfactor_md_correction, + ) + + # Compute MSA features + msa_features = process_msa_features( + data=data, + random=random, + max_seqs_batch=max_seqs_batch, + max_seqs=max_seqs, + max_tokens=max_tokens, + pad_to_max_seqs=pad_to_max_seqs, + msa_sampling=training and msa_sampling, + ) + + # Compute template features + if use_templates: + template_features = process_template_features( + data=data, + max_tokens=max_tokens, + max_templates=max_templates, + pad_to_max_templates=pad_to_max_templates, + ) + else: + template_features = {} + + # Compute symmetry features + symmetry_features = {} + if compute_symmetries: + symmetries = get_symmetries(molecules) + symmetry_features = process_symmetry_features(data, symmetries) + + return { + **token_features, + **atom_features, + **msa_features, + **template_features, + **symmetry_features, + **ensemble_features, + } diff --git a/src/boltz/data/module/trainingv2.py b/src/boltz/data/module/trainingv2.py index 2141db172..99891fe47 100644 --- a/src/boltz/data/module/trainingv2.py +++ b/src/boltz/data/module/trainingv2.py @@ -1,21 +1,32 @@ +import json +from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import Optional import numpy as np +import pandas as pd import pytorch_lightning as pl import torch +from rdkit.Chem import Mol from torch import Tensor from torch.utils.data import DataLoader from boltz.data.crop.cropper import Cropper from boltz.data.feature.featurizer import BoltzFeaturizer -from boltz.data.feature.symmetry import get_symmetries from boltz.data.filter.dynamic.filter import DynamicFilter +from boltz.data.mol import load_canonicals, load_molecules from boltz.data.pad import pad_to_max -from boltz.data.sample.sampler import Sample, Sampler +from boltz.data.sample.v2.sampler import Sample, Sampler from boltz.data.tokenize.tokenizer import Tokenizer -from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure +from boltz.data.types import ( + MSA, + InputTraining, + Manifest, + Record, + StructureV2, + Template, +) @dataclass @@ -24,20 +35,25 @@ class DatasetConfig: target_dir: str msa_dir: str - prob: float + prob: Optional[float] sampler: Sampler cropper: Cropper - filters: Optional[list] = None + template_dir: Optional[str] = None + filters: Optional[list[DynamicFilter]] = None split: Optional[str] = None - manifest_path: Optional[str] = None + symmetry_correction: bool = True + val_group: Optional[str] = "RCSB" + use_train_subset: Optional[float] = None + moldir: Optional[str] = None + override_bfactor: Optional[bool] = False + override_method: Optional[str] = None @dataclass -class DataConfig: +class DataConfigV2: """Data configuration.""" datasets: list[DatasetConfig] - filters: list[DynamicFilter] featurizer: BoltzFeaturizer tokenizer: Tokenizer max_atoms: int @@ -48,78 +64,205 @@ class DataConfig: num_workers: int random_seed: int pin_memory: bool - symmetries: str atoms_per_window_queries: int min_dist: float max_dist: float num_bins: int + checkpoint_monitor_val_group: str + num_ensembles_train: int = 1 + num_ensembles_val: int = 1 + disto_use_ensemble: Optional[bool] = False + fix_single_ensemble: Optional[bool] = True overfit: Optional[int] = None pad_to_max_tokens: bool = False pad_to_max_atoms: bool = False pad_to_max_seqs: bool = False - crop_validation: bool = False return_train_symmetries: bool = False return_val_symmetries: bool = True train_binder_pocket_conditioned_prop: float = 0.0 val_binder_pocket_conditioned_prop: float = 0.0 - binder_pocket_cutoff: float = 6.0 + train_contact_conditioned_prop: float = 0.0 + val_contact_conditioned_prop: float = 0.0 + binder_pocket_cutoff_min: float = 4.0 + binder_pocket_cutoff_max: float = 20.0 + binder_pocket_cutoff_val: float = 6.0 binder_pocket_sampling_geometric_p: float = 0.0 val_batch_size: int = 1 + single_sequence_prop_training: float = 0.0 + msa_sampling_training: bool = False + use_templates: bool = False + max_templates_train: int = 4 + max_templates_val: int = 4 + no_template_prob_train: float = 1.0 + no_template_prob_val: float = 1.0 + moldir: Optional[str] = None + compute_frames: bool = False + bfactor_md_correction: Optional[bool] = False @dataclass class Dataset: """Data holder.""" - target_dir: Path + samples: pd.DataFrame + struct_dir: Path msa_dir: Path - manifest: Manifest + record_dir: Path + template_dir: Path prob: float - sampler: Sampler cropper: Cropper tokenizer: Tokenizer featurizer: BoltzFeaturizer + val_group: str + symmetry_correction: bool = True + moldir: Optional[str] = None + override_bfactor: Optional[bool] = False + override_method: Optional[str] = None + + +def load_record(record_id: str, record_dir: Path) -> Record: + """Load the given record. + + Parameters + ---------- + record_id : str + The record id to load. + record_dir : Path + The path to the record directory. + + Returns + ------- + Record + The loaded record. + """ + return Record.load(record_dir / f"{record_id}.json") -def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: +def load_structure(record: Record, struct_dir: Path) -> StructureV2: """Load the given input data. Parameters ---------- - record : Record + record : str The record to load. target_dir : Path The path to the data directory. - msa_dir : Path - The path to msa directory. Returns ------- - Input + InputTraining The loaded input. """ - # Load the structure - structure = np.load(target_dir / "structures" / f"{record.id}.npz") - structure = Structure( - atoms=structure["atoms"], - bonds=structure["bonds"], - residues=structure["residues"], - chains=structure["chains"], - connections=structure["connections"].astype(Connection), - interfaces=structure["interfaces"], - mask=structure["mask"], - ) + if (struct_dir / f"{record.id}.npz").exists(): + structure_path = struct_dir / f"{record.id}.npz" + else: + structure_path = struct_dir / f"{record.id}" / f"{record.id}_model_0.npz" + return StructureV2.load(structure_path) + + +def load_msas(chain_ids: set[int], record: Record, msa_dir: Path) -> InputTraining: + """Load the given input data. + Parameters + ---------- + chain_ids : set[int] + The chain ids to load. + record : Record + The record to load. + msa_dir : Path + The path to the MSA directory. + + Returns + ------- + InputTraining + The loaded input. + + """ msas = {} for chain in record.chains: + if chain.chain_id not in chain_ids: + continue + msa_id = chain.msa_id - # Load the MSA for this chain, if any - if msa_id != -1 and msa_id != "": - msa = np.load(msa_dir / f"{msa_id}.npz") - msas[chain.chain_id] = MSA(**msa) + if msa_id != -1: + msa_path = msa_dir / f"{msa_id}.npz" + msa = MSA.load(msa_path) + msas[chain.chain_id] = msa + + return msas + + +def load_templates( + chain_ids: set[int], + record: Record, + template_dir: Path, + max_templates: int, + no_template_prob: float, + training: bool, + random: np.random.Generator, +) -> dict[str, list[Template]]: + """Load the given input data. - return Input(structure, msas) + Parameters + ---------- + record : str + The record to load. + target_dir : Path + The path to the data directory. + msa_dir : Path + The path to the MSA directory. + template_dir : Path + The path to the template directory. + max_templates : int + The maximum number of templates to load. + no_template_prob : float + The probability of not loading any templates. + training : bool + Whether the data is for training. + random : np.random.Generator + The random number generator. + + Returns + ------- + dict[str, list[Template]] + The loaded templates. + + """ + templates = {} + for chain in record.chains: + if chain.chain_id not in chain_ids: + continue + + # Check if chain has templates, skipping non proteins + template_ids = chain.template_ids + if template_ids is None: + continue + + # Pick how many templates to sample + max_chain_templates = min(max_templates, len(template_ids)) + + # If 0, skips + if (max_chain_templates == 0) or (random.random() < no_template_prob): + continue + + # Sample for training, pick firsts for validation + if training: + max_chain_templates = random.integers(1, max_chain_templates + 1) + template_indices = torch.randperm(len(template_ids)) + template_indices = template_indices[:max_chain_templates] + template_ids = [template_ids[idx.item()] for idx in template_indices] + else: + template_ids = template_ids[:max_chain_templates] + + # Load templates + templates[chain.chain_id] = [] + for template_name in template_ids: + template_path = template_dir / f"{template_name}.npz" + template = Template.load(template_path) + templates[chain.chain_id].append(template) + + return templates def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: @@ -127,12 +270,12 @@ def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: Parameters ---------- - data : list[dict[str, Tensor]] + data : List[Dict[str, Tensor]] The data to collate. Returns ------- - dict[str, Tensor] + Dict[str, Tensor] The collated data. """ @@ -149,8 +292,31 @@ def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]: "all_resolved_mask", "crop_to_all_atom_map", "chain_symmetries", + "chain_swaps", "amino_acids_symmetries", "ligand_symmetries", + "activity_name", + "activity_qualifier", + "sid", + "cid", + "normalized_protein_accession", + "pair_id", + "ligand_edge_index", + "ligand_edge_lower_bounds", + "ligand_edge_upper_bounds", + "ligand_edge_bond_mask", + "ligand_edge_angle_mask", + "connections_edge_index", + "ligand_chiral_atom_index", + "ligand_chiral_check_mask", + "ligand_chiral_atom_orientations", + "ligand_stereo_bond_index", + "ligand_stereo_check_mask", + "ligand_stereo_bond_orientations", + "ligand_aromatic_5_ring_index", + "ligand_aromatic_6_ring_index", + "ligand_planar_double_bond_index", + "pdb_id", ]: # Check if all have the same shape shape = values[0].shape @@ -171,8 +337,9 @@ class TrainingDataset(torch.utils.data.Dataset): def __init__( self, datasets: list[Dataset], + canonicals: dict[str, Mol], + moldir: str, samples_per_epoch: int, - symmetries: dict, max_atoms: int, max_tokens: int, max_seqs: int, @@ -183,18 +350,48 @@ def __init__( min_dist: float = 2.0, max_dist: float = 22.0, num_bins: int = 64, + num_ensembles: int = 1, + ensemble_sample_replacement: Optional[bool] = True, + disto_use_ensemble: Optional[bool] = False, + fix_single_ensemble: Optional[bool] = True, overfit: Optional[int] = None, binder_pocket_conditioned_prop: Optional[float] = 0.0, - binder_pocket_cutoff: Optional[float] = 6.0, + contact_conditioned_prop: Optional[float] = 0.0, + binder_pocket_cutoff_min: Optional[float] = 4.0, + binder_pocket_cutoff_max: Optional[float] = 20.0, binder_pocket_sampling_geometric_p: Optional[float] = 0.0, return_symmetries: Optional[bool] = False, + use_templates: bool = False, + max_templates: int = 4, + no_template_prob: float = 0.6, + single_sequence_prop: Optional[float] = 0.0, + msa_sampling: bool = False, + compute_frames: bool = False, + bfactor_md_correction: bool = False, ) -> None: - """Initialize the training dataset.""" + """Initialize the training dataset. + + Parameters + ---------- + datasets : List[Dataset] + The datasets to sample from. + samplers : List[Sampler] + The samplers to sample from each dataset. + probs : List[float] + The probabilities to sample from each dataset. + samples_per_epoch : int + The number of samples per epoch. + max_tokens : int + The maximum number of tokens. + + """ super().__init__() + self.datasets = datasets + self.canonicals = canonicals + self.moldir = moldir self.probs = [d.prob for d in datasets] self.samples_per_epoch = samples_per_epoch - self.symmetries = symmetries self.max_tokens = max_tokens self.max_seqs = max_seqs self.max_atoms = max_atoms @@ -205,56 +402,84 @@ def __init__( self.min_dist = min_dist self.max_dist = max_dist self.num_bins = num_bins + self.num_ensembles = num_ensembles + self.ensemble_sample_replacement = ensemble_sample_replacement + self.disto_use_ensemble = disto_use_ensemble + self.fix_single_ensemble = fix_single_ensemble self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop - self.binder_pocket_cutoff = binder_pocket_cutoff + self.contact_conditioned_prop = contact_conditioned_prop + self.binder_pocket_cutoff_min = binder_pocket_cutoff_min + self.binder_pocket_cutoff_max = binder_pocket_cutoff_max self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p self.return_symmetries = return_symmetries - self.samples = [] - for dataset in datasets: - records = dataset.manifest.records - if overfit is not None: - records = records[:overfit] - iterator = dataset.sampler.sample(records, np.random) - self.samples.append(iterator) + self.single_sequence_prop = single_sequence_prop + self.msa_sampling = msa_sampling + self.use_templates = use_templates + self.max_templates = max_templates + self.no_template_prob = no_template_prob + self.overfit = overfit + self.compute_frames = compute_frames + self.bfactor_md_correction = bfactor_md_correction + + self.samples: list[pd.DataFrame] = [] + for d in datasets: + self.samples.append(d.samples[:overfit] if overfit else d.samples) def __getitem__(self, idx: int) -> dict[str, Tensor]: """Get an item from the dataset. - Parameters - ---------- - idx : int - The data index. - Returns ------- - dict[str, Tensor] + Dict[str, Tensor] The sampled data features. """ + # Set a random state + random = np.random.default_rng() + # Pick a random dataset - dataset_idx = np.random.choice( - len(self.datasets), - p=self.probs, - ) + dataset_idx = random.choice(len(self.datasets), p=self.probs) dataset = self.datasets[dataset_idx] # Get a sample from the dataset - sample: Sample = next(self.samples[dataset_idx]) + samples = self.samples[dataset_idx] + sample_idx = random.choice( + len(samples), + p=( + samples["weight"] / np.sum(samples["weight"]) + if self.overfit + else samples["weight"] + ), + ) + sample = samples.iloc[sample_idx].to_dict() + sample: Sample = Sample( + record_id=str(sample["record_id"]), + chain_id=( + int(sample["chain_id"]) if sample["chain_id"] is not None else None + ), + interface_id=( + int(sample["interface_id"]) + if sample["interface_id"] is not None + else None + ), + weight=float(sample["weight"]), + ) + + # Load record + record = load_record(sample.record_id, dataset.record_dir) # Get the structure try: - input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir) - except Exception as e: - print( - f"Failed to load input for {sample.record.id} with error {e}. Skipping." - ) + structure = load_structure(record, dataset.struct_dir) + except Exception as e: # noqa: BLE001 + print(f"Failed to load input for {record.id} with error {e}. Skipping.") return self.__getitem__(idx) # Tokenize structure try: - tokenized = dataset.tokenizer.tokenize(input_data) - except Exception as e: - print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.") + tokenized = dataset.tokenizer.tokenize(structure) + except Exception as e: # noqa: BLE001 + print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") return self.__getitem__(idx) # Compute crop @@ -264,42 +489,121 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: tokenized, max_atoms=self.max_atoms, max_tokens=self.max_tokens, - random=np.random, chain_id=sample.chain_id, interface_id=sample.interface_id, + random=random, ) - except Exception as e: - print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.") + if len(tokenized.tokens) == 0: + msg = "No tokens in cropped structure." + raise ValueError(msg) # noqa: TRY301 + except Exception as e: # noqa: BLE001 + print(f"Cropper failed on {record.id} with error {e}. Skipping.") return self.__getitem__(idx) - # Check if there are tokens - if len(tokenized.tokens) == 0: - msg = "No tokens in cropped structure." - raise ValueError(msg) + # Get unique chain ids + chain_ids = set(tokenized.tokens["asym_id"]) + + # Load msas and templates + try: + msas = load_msas( + chain_ids=chain_ids, + record=record, + msa_dir=dataset.msa_dir, + ) + except Exception as e: # noqa: BLE001 + print(f"MSA loading failed for {record.id} with error {e}. Skipping.") + return self.__getitem__(0) + + # Load templates + templates = FileNotFoundError + if self.use_templates and dataset.template_dir is not None: + try: + templates = load_templates( + chain_ids=chain_ids, + record=record, + template_dir=dataset.template_dir, + max_templates=self.max_templates, + no_template_prob=self.no_template_prob, + training=True, + random=random, + ) + except Exception as e: # noqa: BLE001 + print( + f"Template loading failed for {record.id} with error {e}. Using no templates." + ) + templates = None + else: + templates = None + + # Load molecules + try: + # Try to find molecules in the dataset moldir if provided + # Find missing ones in global moldir and check if all found + molecules = {} + molecules.update(self.canonicals) + mol_names = set(tokenized.tokens["res_name"].tolist()) + mol_names = mol_names - set(self.canonicals.keys()) + if dataset.moldir is not None: + molecules.update(load_molecules(dataset.moldir, mol_names)) + + mol_names = mol_names - set(molecules.keys()) + molecules.update(load_molecules(self.moldir, mol_names)) + except Exception as e: # noqa: BLE001 + print(f"Molecule loading failed for {record.id} with error {e}. Skipping.") + return self.__getitem__(0) + + # Finalize input data + input_data = InputTraining( + tokens=tokenized.tokens, + bonds=tokenized.bonds, + structure=structure, + msa=msas, + templates=templates, + record=record, + ) # Compute features try: - features = dataset.featurizer.process( - tokenized, + features: dict = dataset.featurizer.process( + input_data, + molecules=molecules, + random=random, training=True, max_atoms=self.max_atoms if self.pad_to_max_atoms else None, max_tokens=self.max_tokens if self.pad_to_max_tokens else None, max_seqs=self.max_seqs, pad_to_max_seqs=self.pad_to_max_seqs, - symmetries=self.symmetries, atoms_per_window_queries=self.atoms_per_window_queries, min_dist=self.min_dist, max_dist=self.max_dist, num_bins=self.num_bins, + num_ensembles=self.num_ensembles, + ensemble_sample_replacement=self.ensemble_sample_replacement, + disto_use_ensemble=self.disto_use_ensemble, + fix_single_ensemble=self.fix_single_ensemble, compute_symmetries=self.return_symmetries, binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop, - binder_pocket_cutoff=self.binder_pocket_cutoff, + contact_conditioned_prop=self.contact_conditioned_prop, + binder_pocket_cutoff_min=self.binder_pocket_cutoff_min, + binder_pocket_cutoff_max=self.binder_pocket_cutoff_max, binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p, + single_sequence_prop=self.single_sequence_prop, + msa_sampling=self.msa_sampling, + use_templates=self.use_templates, + max_templates=self.max_templates, + override_bfactor=dataset.override_bfactor, + override_method=dataset.override_method, + compute_frames=self.compute_frames, + bfactor_md_correction=self.bfactor_md_correction, ) - except Exception as e: - print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.") + except Exception as e: # noqa: BLE001 + print(f"Featurizer failed on {record.id} with error {e}. Skipping.") + import traceback + + traceback.print_exc() return self.__getitem__(idx) + features["pdb_id"] = record.id return features def __len__(self) -> int: @@ -320,8 +624,9 @@ class ValidationDataset(torch.utils.data.Dataset): def __init__( self, datasets: list[Dataset], + canonicals: dict[str, Mol], + moldir: str, seed: int, - symmetries: dict, max_atoms: Optional[int] = None, max_tokens: Optional[int] = None, max_seqs: Optional[int] = None, @@ -332,51 +637,81 @@ def __init__( min_dist: float = 2.0, max_dist: float = 22.0, num_bins: int = 64, + num_ensembles: int = 1, + ensemble_sample_replacement: Optional[bool] = False, + disto_use_ensemble: Optional[bool] = False, + fix_single_ensemble: Optional[bool] = True, overfit: Optional[int] = None, - crop_validation: bool = False, return_symmetries: Optional[bool] = False, binder_pocket_conditioned_prop: Optional[float] = 0.0, + contact_conditioned_prop: Optional[float] = 0.0, binder_pocket_cutoff: Optional[float] = 6.0, + use_templates: bool = False, + max_templates: int = 4, + no_template_prob: float = 0.0, + compute_frames: bool = False, + bfactor_md_correction: bool = False, ) -> None: - """Initialize the validation dataset.""" + """Initialize the training dataset. + + Parameters + ---------- + datasets : List[Dataset] + The datasets to sample from. + seed : int + The random seed. + max_tokens : int + The maximum number of tokens. + overfit : bool + Whether to overfit the dataset + + """ super().__init__() self.datasets = datasets + self.canonicals = canonicals + self.moldir = moldir self.max_atoms = max_atoms self.max_tokens = max_tokens self.max_seqs = max_seqs self.seed = seed - self.symmetries = symmetries - self.random = np.random if overfit else np.random.RandomState(self.seed) self.pad_to_max_tokens = pad_to_max_tokens self.pad_to_max_atoms = pad_to_max_atoms self.pad_to_max_seqs = pad_to_max_seqs self.overfit = overfit - self.crop_validation = crop_validation self.atoms_per_window_queries = atoms_per_window_queries self.min_dist = min_dist self.max_dist = max_dist self.num_bins = num_bins + self.num_ensembles = num_ensembles + self.ensemble_sample_replacement = ensemble_sample_replacement + self.disto_use_ensemble = disto_use_ensemble + self.fix_single_ensemble = fix_single_ensemble self.return_symmetries = return_symmetries self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop + self.contact_conditioned_prop = contact_conditioned_prop self.binder_pocket_cutoff = binder_pocket_cutoff + self.use_templates = use_templates + self.max_templates = max_templates + self.no_template_prob = no_template_prob + self.compute_frames = compute_frames + self.bfactor_md_correction = bfactor_md_correction - def __getitem__(self, idx: int) -> dict[str, Tensor]: + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """Get an item from the dataset. - Parameters - ---------- - idx : int - The data index. - Returns ------- - dict[str, Tensor] + Dict[str, Tensor] The sampled data features. """ + # Set random state + seed = self.seed if self.overfit is None else None + random = np.random.default_rng(seed) + # Pick dataset based on idx - for dataset in self.datasets: - size = len(dataset.manifest.records) + for idx_dataset, dataset in enumerate(self.datasets): # noqa: B007 + size = len(dataset.samples) if self.overfit is not None: size = min(size, self.overfit) if idx < size: @@ -384,67 +719,120 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: idx -= size # Get a sample from the dataset - record = dataset.manifest.records[idx] + sample = Sample(**dataset.samples.iloc[idx].to_dict()) + record = load_record(sample.record_id, dataset.record_dir) # Get the structure try: - input_data = load_input(record, dataset.target_dir, dataset.msa_dir) - except Exception as e: + structure = load_structure(record, dataset.struct_dir) + except Exception as e: # noqa: BLE001 print(f"Failed to load input for {record.id} with error {e}. Skipping.") return self.__getitem__(0) # Tokenize structure try: - tokenized = dataset.tokenizer.tokenize(input_data) - except Exception as e: + tokenized = dataset.tokenizer.tokenize(structure) + except Exception as e: # noqa: BLE001 print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") return self.__getitem__(0) - # Compute crop + # Get unique chains + chain_ids = set(np.unique(tokenized.tokens["asym_id"]).tolist()) + + # Load msas and templates try: - if self.crop_validation and (self.max_tokens is not None): - tokenized = dataset.cropper.crop( - tokenized, - max_tokens=self.max_tokens, - random=self.random, - max_atoms=self.max_atoms, + msas = load_msas(chain_ids, record, dataset.msa_dir) + except Exception as e: # noqa: BLE001 + print(f"MSA loading failed for {record.id} with error {e}. Skipping.") + return self.__getitem__(0) + + # Load templates + if self.use_templates and dataset.template_dir is not None: + try: + templates = load_templates( + chain_ids=chain_ids, + record=record, + template_dir=dataset.template_dir, + max_templates=self.max_templates, + no_template_prob=self.no_template_prob, + training=False, + random=random, ) - except Exception as e: - print(f"Cropper failed on {record.id} with error {e}. Skipping.") + except Exception as e: # noqa: BLE001 + print( + f"Template loading failed for {record.id} with error {e}. Using no templates." + ) + templates = None + else: + templates = None + try: + # Try to find molecules in the dataset moldir if provided + # Find missing ones in global moldir and check if all found + molecules = {} + molecules.update(self.canonicals) + mol_names = set(tokenized.tokens["res_name"].tolist()) + mol_names = mol_names - set(self.canonicals.keys()) + if dataset.moldir is not None: + molecules.update(load_molecules(dataset.moldir, mol_names)) + + mol_names = mol_names - set(molecules.keys()) + molecules.update(load_molecules(self.moldir, mol_names)) + except Exception as e: # noqa: BLE001 + print(f"Molecule loading failed for {record.id} with error {e}. Skipping.") return self.__getitem__(0) - # Check if there are tokens - if len(tokenized.tokens) == 0: - msg = "No tokens in cropped structure." - raise ValueError(msg) + # Finalize input data + input_data = InputTraining( + tokens=tokenized.tokens, + bonds=tokenized.bonds, + structure=structure, + msa=msas, + templates=templates, + record=record, + ) # Compute features try: - pad_atoms = self.crop_validation and self.pad_to_max_atoms - pad_tokens = self.crop_validation and self.pad_to_max_tokens - - features = dataset.featurizer.process( - tokenized, + features: dict = dataset.featurizer.process( + input_data, + molecules=molecules, + random=random, training=False, - max_atoms=self.max_atoms if pad_atoms else None, - max_tokens=self.max_tokens if pad_tokens else None, + max_atoms=None, + max_tokens=None, max_seqs=self.max_seqs, pad_to_max_seqs=self.pad_to_max_seqs, - symmetries=self.symmetries, atoms_per_window_queries=self.atoms_per_window_queries, min_dist=self.min_dist, max_dist=self.max_dist, num_bins=self.num_bins, + num_ensembles=self.num_ensembles, + ensemble_sample_replacement=self.ensemble_sample_replacement, + disto_use_ensemble=self.disto_use_ensemble, + fix_single_ensemble=self.fix_single_ensemble, compute_symmetries=self.return_symmetries, binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop, - binder_pocket_cutoff=self.binder_pocket_cutoff, + contact_conditioned_prop=self.contact_conditioned_prop, + binder_pocket_cutoff_min=self.binder_pocket_cutoff, + binder_pocket_cutoff_max=self.binder_pocket_cutoff, binder_pocket_sampling_geometric_p=1.0, # this will only sample a single pocket token only_ligand_binder_pocket=True, + only_pp_contact=True, + single_sequence_prop=0.0, + use_templates=self.use_templates, + max_templates=self.max_templates, + override_method=dataset.override_method, + compute_frames=self.compute_frames, + bfactor_md_correction=self.bfactor_md_correction, ) - except Exception as e: + + except Exception as e: # noqa: BLE001 print(f"Featurizer failed on {record.id} with error {e}. Skipping.") return self.__getitem__(0) + # Add dataset idx + idx_dataset = torch.tensor([idx_dataset], dtype=torch.long) + features.update({"idx_dataset": idx_dataset}) return features def __len__(self) -> int: @@ -453,26 +841,26 @@ def __len__(self) -> int: Returns ------- int - The length of the dataset. + The length of the dataaset. """ if self.overfit is not None: - length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets) + length = sum(len(d.samples[: self.overfit]) for d in self.datasets) else: - length = sum(len(d.manifest.records) for d in self.datasets) + length = sum(len(d.samples) for d in self.datasets) return length -class BoltzTrainingDataModule(pl.LightningDataModule): - """DataModule for boltz.""" +class Boltz2TrainingDataModule(pl.LightningDataModule): + """DataModule for Boltz2.""" - def __init__(self, cfg: DataConfig) -> None: + def __init__(self, cfg: DataConfigV2) -> None: """Initialize the DataModule. Parameters ---------- - config : DataConfig + config : DataConfigV2 The data configuration. """ @@ -480,27 +868,29 @@ def __init__(self, cfg: DataConfig) -> None: self.cfg = cfg assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1." - - # Load symmetries - symmetries = get_symmetries(cfg.symmetries) - # Load datasets train: list[Dataset] = [] val: list[Dataset] = [] for data_config in cfg.datasets: - # Set target_dir - target_dir = Path(data_config.target_dir) + # Get relevant directories + manifest_path = Path(data_config.target_dir) / "manifest.json" + struct_dir = Path(data_config.target_dir) / "structures" + record_dir = Path(data_config.target_dir) / "records" msa_dir = Path(data_config.msa_dir) - # Load manifest - if data_config.manifest_path is not None: - path = Path(data_config.manifest_path) - else: - path = target_dir / "manifest.json" - manifest: Manifest = Manifest.load(path) + # Get template_dir, if any + template_dir = data_config.template_dir + template_dir = Path(template_dir) if template_dir is not None else None - # Split records if given + # Get moldir, if any + moldir = data_config.moldir + moldir = Path(moldir) if moldir is not None else None + + # Load all records + manifest: Manifest = Manifest.load(manifest_path) + + # Split records if givens if data_config.split is not None: with Path(data_config.split).open("r") as f: split = {x.lower() for x in f.read().splitlines()} @@ -514,15 +904,15 @@ def __init__(self, cfg: DataConfig) -> None: train_records.append(record) else: train_records = manifest.records - val_records = [] + if cfg.overfit is None: + val_records = [] + else: + print("Warning: modified overfit val behavior.") + val_records = manifest.records[: cfg.overfit] + + print("train_records before filter", len(train_records)) - # Filter training records - train_records = [ - record - for record in train_records - if all(f.filter(record) for f in cfg.filters) - ] - # Filter training records + # Apply dataset-specific filters if data_config.filters is not None: train_records = [ record @@ -530,49 +920,108 @@ def __init__(self, cfg: DataConfig) -> None: if all(f.filter(record) for f in data_config.filters) ] + # Train with subset of data + if data_config.use_train_subset is not None: + # Shuffle train_records list + assert 0 < data_config.use_train_subset < 1.0 + rng = np.random.default_rng(cfg.random_seed) + rng.shuffle(train_records) + train_records = train_records[ + 0 : int(len(train_records) * data_config.use_train_subset) + ] + print("train_records after filter", len(train_records)) + print("val_records after filter", len(val_records)) + + # Get samples + train_samples: list[Sample] = data_config.sampler.sample(train_records) + val_samples: list[Sample] = [Sample(r.id) for r in val_records] + del manifest, train_records, val_records + + # Convert samples to pandas dataframe to avoid copy-on-write behavior + train_samples = pd.DataFrame( + [ + ( + r.record_id, + r.chain_id, + r.interface_id, + r.weight, + ) + for r in train_samples + ], + columns=["record_id", "chain_id", "interface_id", "weight"], + ) + val_samples = pd.DataFrame( + [s.record_id for s in val_samples], columns=["record_id"] + ) + + # Use appropriate string type + train_samples = train_samples.replace({np.nan: None}) + val_samples = val_samples.replace({np.nan: None}) + train_samples["record_id"] = train_samples["record_id"].astype("string") + val_samples["record_id"] = val_samples["record_id"].astype("string") + # Create train dataset - train_manifest = Manifest(train_records) - train.append( - Dataset( - target_dir, - msa_dir, - train_manifest, - data_config.prob, - data_config.sampler, - data_config.cropper, - cfg.tokenizer, - cfg.featurizer, + if data_config.prob > 0: + train.append( + Dataset( + samples=train_samples, + record_dir=record_dir, + struct_dir=struct_dir, + msa_dir=msa_dir, + template_dir=template_dir, + moldir=moldir, + prob=data_config.prob, + cropper=data_config.cropper, + tokenizer=cfg.tokenizer, + featurizer=cfg.featurizer, + val_group=data_config.val_group, + symmetry_correction=data_config.symmetry_correction, + override_bfactor=data_config.override_bfactor, + override_method=data_config.override_method, + ) ) - ) # Create validation dataset - if val_records: - val_manifest = Manifest(val_records) + if len(val_samples) > 0: val.append( Dataset( - target_dir, - msa_dir, - val_manifest, - data_config.prob, - data_config.sampler, - data_config.cropper, - cfg.tokenizer, - cfg.featurizer, + samples=val_samples, + record_dir=record_dir, + struct_dir=struct_dir, + msa_dir=msa_dir, + template_dir=template_dir, + moldir=moldir, + prob=data_config.prob, + cropper=data_config.cropper, + tokenizer=cfg.tokenizer, + featurizer=cfg.featurizer, + val_group=data_config.val_group, + symmetry_correction=data_config.symmetry_correction, ) ) # Print dataset sizes for dataset in train: dataset: Dataset - print(f"Training dataset size: {len(dataset.manifest.records)}") + print(f"Training dataset size: {len(dataset.samples)}") - for dataset in val: + self.val_group_mapper = defaultdict(dict) + for i, dataset in enumerate(train if cfg.overfit is not None else val): dataset: Dataset - print(f"Validation dataset size: {len(dataset.manifest.records)}") + print(f"Validation dataset size: {len(dataset.samples)}") + self.val_group_mapper[i]["label"] = dataset.val_group + self.val_group_mapper[i]["symmetry_correction"] = ( + dataset.symmetry_correction + ) + + # Load canonical molecules + canonicals = load_canonicals(cfg.moldir) # Create wrapper datasets self._train_set = TrainingDataset( datasets=train, + canonicals=canonicals, + moldir=cfg.moldir, samples_per_epoch=cfg.samples_per_epoch, max_atoms=cfg.max_atoms, max_tokens=cfg.max_tokens, @@ -580,19 +1029,33 @@ def __init__(self, cfg: DataConfig) -> None: pad_to_max_atoms=cfg.pad_to_max_atoms, pad_to_max_tokens=cfg.pad_to_max_tokens, pad_to_max_seqs=cfg.pad_to_max_seqs, - symmetries=symmetries, atoms_per_window_queries=cfg.atoms_per_window_queries, min_dist=cfg.min_dist, max_dist=cfg.max_dist, num_bins=cfg.num_bins, + num_ensembles=cfg.num_ensembles_train, + ensemble_sample_replacement=True, + disto_use_ensemble=cfg.disto_use_ensemble, + fix_single_ensemble=cfg.fix_single_ensemble, overfit=cfg.overfit, binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop, - binder_pocket_cutoff=cfg.binder_pocket_cutoff, + contact_conditioned_prop=cfg.train_contact_conditioned_prop, + binder_pocket_cutoff_min=cfg.binder_pocket_cutoff_min, + binder_pocket_cutoff_max=cfg.binder_pocket_cutoff_max, binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p, return_symmetries=cfg.return_train_symmetries, + single_sequence_prop=cfg.single_sequence_prop_training, + msa_sampling=cfg.msa_sampling_training, + use_templates=cfg.use_templates, + max_templates=cfg.max_templates_train, + no_template_prob=cfg.no_template_prob_train, + compute_frames=cfg.compute_frames, + bfactor_md_correction=cfg.bfactor_md_correction, ) self._val_set = ValidationDataset( datasets=train if cfg.overfit is not None else val, + canonicals=canonicals, + moldir=cfg.moldir, seed=cfg.random_seed, max_atoms=cfg.max_atoms, max_tokens=cfg.max_tokens, @@ -600,19 +1063,27 @@ def __init__(self, cfg: DataConfig) -> None: pad_to_max_atoms=cfg.pad_to_max_atoms, pad_to_max_tokens=cfg.pad_to_max_tokens, pad_to_max_seqs=cfg.pad_to_max_seqs, - symmetries=symmetries, atoms_per_window_queries=cfg.atoms_per_window_queries, min_dist=cfg.min_dist, max_dist=cfg.max_dist, num_bins=cfg.num_bins, + num_ensembles=cfg.num_ensembles_val, + ensemble_sample_replacement=False, + disto_use_ensemble=cfg.disto_use_ensemble, + fix_single_ensemble=cfg.fix_single_ensemble, overfit=cfg.overfit, - crop_validation=cfg.crop_validation, return_symmetries=cfg.return_val_symmetries, binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop, - binder_pocket_cutoff=cfg.binder_pocket_cutoff, + contact_conditioned_prop=cfg.val_contact_conditioned_prop, + binder_pocket_cutoff=cfg.binder_pocket_cutoff_val, + use_templates=cfg.use_templates, + max_templates=cfg.max_templates_val, + no_template_prob=cfg.no_template_prob_val, + compute_frames=cfg.compute_frames, + bfactor_md_correction=cfg.bfactor_md_correction, ) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None) -> None: # noqa: ARG002 (unused) """Run the setup for the DataModule. Parameters @@ -647,7 +1118,7 @@ def val_dataloader(self) -> DataLoader: Returns ------- DataLoader - The validation dataloader. + The validation dataloader.s """ return DataLoader( diff --git a/src/boltz/data/parse/mmcif.py b/src/boltz/data/parse/mmcif.py index cade3c2db..ff68528a3 100644 --- a/src/boltz/data/parse/mmcif.py +++ b/src/boltz/data/parse/mmcif.py @@ -316,7 +316,7 @@ def compute_covalent_ligands( return covalent_chain_ids -def compute_interfaces(atom_data: np.ndarray, chain_data: np.ndarray) -> np.ndarray: +def find_interfaces(atom_data: np.ndarray, chain_data: np.ndarray) -> np.ndarray: """Compute the chain-chain interfaces from a gemmi structure. Parameters @@ -1204,7 +1204,7 @@ def parse_mmcif( # noqa: C901, PLR0915, PLR0912 # Compute interface chains (find chains with a heavy atom within 5A) if compute_interfaces: - interfaces = compute_interfaces(atoms, chains) + interfaces = find_interfaces(atoms, chains) else: interfaces = np.array([], dtype=Interface) diff --git a/src/boltz/data/sample/cluster.py b/src/boltz/data/sample/cluster.py index fb5c2e66b..e6c657e63 100644 --- a/src/boltz/data/sample/cluster.py +++ b/src/boltz/data/sample/cluster.py @@ -1,11 +1,11 @@ -from typing import Dict, Iterator, List +from collections.abc import Iterator import numpy as np from numpy.random import RandomState from boltz.data import const -from boltz.data.types import ChainInfo, InterfaceInfo, Record from boltz.data.sample.sampler import Sample, Sampler +from boltz.data.types import ChainInfo, InterfaceInfo, Record def get_chain_cluster(chain: ChainInfo, record: Record) -> str: # noqa: ARG001 @@ -58,7 +58,7 @@ def get_interface_cluster(interface: InterfaceInfo, record: Record) -> str: def get_chain_weight( chain: ChainInfo, record: Record, # noqa: ARG001 - clusters: Dict[str, int], + clusters: dict[str, int], beta_chain: float, alpha_prot: float, alpha_nucl: float, @@ -108,7 +108,7 @@ def get_chain_weight( def get_interface_weight( interface: InterfaceInfo, record: Record, - clusters: Dict[str, int], + clusters: dict[str, int], beta_interface: float, alpha_prot: float, alpha_nucl: float, @@ -201,7 +201,11 @@ def __init__( self.beta_chain = beta_chain self.beta_interface = beta_interface - def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: # noqa: C901, PLR0912 + def sample( # noqa: C901, PLR0912 + self, + records: list[Record], + random: RandomState, + ) -> Iterator[Sample]: """Sample a structure from the dataset infinitely. Parameters @@ -218,7 +222,7 @@ def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample] """ # Compute chain cluster sizes - chain_clusters: Dict[str, int] = {} + chain_clusters: dict[str, int] = {} for record in records: for chain in record.chains: if not chain.valid: @@ -229,7 +233,7 @@ def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample] chain_clusters[cluster_id] += 1 # Compute interface clusters sizes - interface_clusters: Dict[str, int] = {} + interface_clusters: dict[str, int] = {} for record in records: for interface in record.interfaces: if not interface.valid: diff --git a/src/boltz/data/sample/distillation.py b/src/boltz/data/sample/distillation.py index 9314f5109..f24823d99 100644 --- a/src/boltz/data/sample/distillation.py +++ b/src/boltz/data/sample/distillation.py @@ -1,9 +1,9 @@ -from typing import Iterator, List +from collections.abc import Iterator from numpy.random import RandomState -from boltz.data.types import Record from boltz.data.sample.sampler import Sample, Sampler +from boltz.data.types import Record class DistillationSampler(Sampler): @@ -23,7 +23,7 @@ def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None: self._size = small_size self._prob = small_prob - def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: + def sample(self, records: list[Record], random: RandomState) -> Iterator[Sample]: """Sample a structure from the dataset infinitely. Parameters diff --git a/src/boltz/data/sample/random.py b/src/boltz/data/sample/random.py index e2ee2314e..11770c326 100644 --- a/src/boltz/data/sample/random.py +++ b/src/boltz/data/sample/random.py @@ -1,16 +1,16 @@ +from collections.abc import Iterator from dataclasses import replace -from typing import Iterator, List from numpy.random import RandomState -from boltz.data.types import Record from boltz.data.sample.sampler import Sample, Sampler +from boltz.data.types import Record class RandomSampler(Sampler): """A simple random sampler with replacement.""" - def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: + def sample(self, records: list[Record], random: RandomState) -> Iterator[Sample]: """Sample a structure from the dataset infinitely. Parameters diff --git a/src/boltz/data/sample/sampler.py b/src/boltz/data/sample/sampler.py index 6c6ab6dd1..738a371b4 100644 --- a/src/boltz/data/sample/sampler.py +++ b/src/boltz/data/sample/sampler.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod +from collections.abc import Iterator from dataclasses import dataclass -from typing import Iterator, List, Optional +from typing import Optional from numpy.random import RandomState @@ -30,7 +31,7 @@ class Sampler(ABC): """Abstract base class for samplers.""" @abstractmethod - def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: + def sample(self, records: list[Record], random: RandomState) -> Iterator[Sample]: """Sample a structure from the dataset infinitely. Parameters diff --git a/src/boltz/data/sample/v2/__init__.py b/src/boltz/data/sample/v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/boltz/data/sample/v2/cluster.py b/src/boltz/data/sample/v2/cluster.py new file mode 100644 index 000000000..6ee600687 --- /dev/null +++ b/src/boltz/data/sample/v2/cluster.py @@ -0,0 +1,288 @@ +import numpy as np + +from boltz.data import const +from boltz.data.sample.v2.sampler import Sample, Sampler +from boltz.data.types import ChainInfo, InterfaceInfo, Record + + +def get_chain_cluster(chain: ChainInfo, record: Record) -> str: # noqa: ARG001 + """Get the cluster id for a chain. + + Parameters + ---------- + chain : ChainInfo + The chain id to get the cluster id for. + record : Record + The record the interface is part of. + + Returns + ------- + str + The cluster id of the chain. + + """ + return chain.cluster_id + + +def get_interface_cluster(interface: InterfaceInfo, record: Record) -> str: + """Get the cluster id for an interface. + + Parameters + ---------- + interface : InterfaceInfo + The interface to get the cluster id for. + record : Record + The record the interface is part of. + + Returns + ------- + str + The cluster id of the interface. + + """ + chain1 = record.chains[interface.chain_1] + chain2 = record.chains[interface.chain_2] + + cluster_1 = str(chain1.cluster_id) + cluster_2 = str(chain2.cluster_id) + + cluster_id = (cluster_1, cluster_2) + cluster_id = tuple(sorted(cluster_id)) + + return cluster_id + + +def get_chain_weight( + chain: ChainInfo, + record: Record, # noqa: ARG001 + clusters: dict[str, int], + beta_chain: float, + alpha_prot: float, + alpha_nucl: float, + alpha_ligand: float, +) -> float: + """Get the weight of a chain. + + Parameters + ---------- + chain : ChainInfo + The chain to get the weight for. + record : Record + The record the chain is part of. + clusters : Dict[str, int] + The cluster sizes. + beta_chain : float + The beta value for chains. + alpha_prot : float + The alpha value for proteins. + alpha_nucl : float + The alpha value for nucleic acids. + alpha_ligand : float + The alpha value for ligands. + + Returns + ------- + float + The weight of the chain. + + """ + prot_id = const.chain_type_ids["PROTEIN"] + rna_id = const.chain_type_ids["RNA"] + dna_id = const.chain_type_ids["DNA"] + ligand_id = const.chain_type_ids["NONPOLYMER"] + + weight = beta_chain / clusters[chain.cluster_id] + if chain.mol_type == prot_id: + weight *= alpha_prot + elif chain.mol_type in [rna_id, dna_id]: + weight *= alpha_nucl + elif chain.mol_type == ligand_id: + weight *= alpha_ligand + + return weight + + +def get_interface_weight( + interface: InterfaceInfo, + record: Record, + clusters: dict[str, int], + beta_interface: float, + alpha_prot: float, + alpha_nucl: float, + alpha_ligand: float, +) -> float: + """Get the weight of an interface. + + Parameters + ---------- + interface : InterfaceInfo + The interface to get the weight for. + record : Record + The record the interface is part of. + clusters : Dict[str, int] + The cluster sizes. + beta_interface : float + The beta value for interfaces. + alpha_prot : float + The alpha value for proteins. + alpha_nucl : float + The alpha value for nucleic acids. + alpha_ligand : float + The alpha value for ligands. + + Returns + ------- + float + The weight of the interface. + + """ + prot_id = const.chain_type_ids["PROTEIN"] + rna_id = const.chain_type_ids["RNA"] + dna_id = const.chain_type_ids["DNA"] + ligand_id = const.chain_type_ids["NONPOLYMER"] + + chain1 = record.chains[interface.chain_1] + chain2 = record.chains[interface.chain_2] + + n_prot = (chain1.mol_type) == prot_id + n_nuc = chain1.mol_type in [rna_id, dna_id] + n_ligand = chain1.mol_type == ligand_id + + n_prot += chain2.mol_type == prot_id + n_nuc += chain2.mol_type in [rna_id, dna_id] + n_ligand += chain2.mol_type == ligand_id + + weight = beta_interface / clusters[get_interface_cluster(interface, record)] + weight *= alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand + return weight + + +class ClusterSampler(Sampler): + """The weighted sampling approach, as described in AF3. + + Each chain / interface is given a weight according + to the following formula, and sampled accordingly: + + w = b / n_clust *(a_prot * n_prot + a_nuc * n_nuc + + a_ligand * n_ligand) + + """ + + def __init__( + self, + alpha_prot: float = 3.0, + alpha_nucl: float = 3.0, + alpha_ligand: float = 1.0, + beta_chain: float = 0.5, + beta_interface: float = 1.0, + ) -> None: + """Initialize the sampler. + + Parameters + ---------- + alpha_prot : float, optional + The alpha value for proteins. + alpha_nucl : float, optional + The alpha value for nucleic acids. + alpha_ligand : float, optional + The alpha value for ligands. + beta_chain : float, optional + The beta value for chains. + beta_interface : float, optional + The beta value for interfaces. + + """ + self.alpha_prot = alpha_prot + self.alpha_nucl = alpha_nucl + self.alpha_ligand = alpha_ligand + self.beta_chain = beta_chain + self.beta_interface = beta_interface + + def sample(self, records: list[Record]) -> list[Sample]: + """Sample a structure from the dataset infinitely. + + Parameters + ---------- + records : List[Record] + The records to sample from. + + Returns + ------- + List[Sample] + The samples. + + """ + # Compute chain cluster sizes + chain_clusters: dict[str, int] = {} + for record in records: + for chain in record.chains: + if not chain.valid: + continue + cluster_id = get_chain_cluster(chain, record) + if cluster_id not in chain_clusters: + chain_clusters[cluster_id] = 0 + chain_clusters[cluster_id] += 1 + + # Compute interface clusters sizes + interface_clusters: dict[str, int] = {} + for record in records: + for interface in record.interfaces: + if not interface.valid: + continue + cluster_id = get_interface_cluster(interface, record) + if cluster_id not in interface_clusters: + interface_clusters[cluster_id] = 0 + interface_clusters[cluster_id] += 1 + + # Compute weights + chain_samples, chain_weights = [], [] + int_samples, int_weights = [], [] + + for record in records: + for chain_id, chain in enumerate(record.chains): + if not chain.valid: + continue + weight = get_chain_weight( + chain, + record, + chain_clusters, + self.beta_chain, + self.alpha_prot, + self.alpha_nucl, + self.alpha_ligand, + ) + chain_samples.append((record.id, chain_id)) + chain_weights.append(weight) + + for int_id, interface in enumerate(record.interfaces): + if not interface.valid: + continue + weight = get_interface_weight( + interface, + record, + interface_clusters, + self.beta_interface, + self.alpha_prot, + self.alpha_nucl, + self.alpha_ligand, + ) + int_samples.append((record.id, int_id)) + int_weights.append(weight) + + # Normalize weights + weights_sum = np.sum(chain_weights) + np.sum(int_weights) + chain_weights = np.array(chain_weights) / weights_sum + int_weights = np.array(int_weights) / weights_sum + + # Create samples + chain_samples = [ + Sample(record_id=s[0], chain_id=s[1], weight=w) + for s, w in zip(chain_samples, chain_weights) + ] + int_samples = [ + Sample(record_id=s[0], interface_id=s[1], weight=w) + for s, w in zip(int_samples, int_weights) + ] + + samples = chain_samples + int_samples + return samples diff --git a/src/boltz/data/sample/v2/distillation.py b/src/boltz/data/sample/v2/distillation.py new file mode 100644 index 000000000..b4912d6d2 --- /dev/null +++ b/src/boltz/data/sample/v2/distillation.py @@ -0,0 +1,53 @@ +from boltz.data.sample.v2.sampler import Sample, Sampler +from boltz.data.types import Record + + +class DistillationSampler(Sampler): + """A sampler for monomer distillation data.""" + + def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None: + """Initialize the sampler. + + Parameters + ---------- + small_size : int, optional + The maximum size to be considered small. + small_prob : float, optional + The probability of sampling a small item. + + """ + self._size = small_size + self._prob = small_prob + + def sample(self, records: list[Record]) -> list[Sample]: + """Sample a structure from the dataset infinitely. + + Parameters + ---------- + records : List[Record] + The records to sample from. + + Returns + ------- + List[Sample] + The samples. + + """ + # Remove records with invalid chains + records = [r for r in records if r.chains[0].valid] + + # Split in small and large proteins. We assume that there is only + # one chain per record, as is the case for monomer distillation + small = [r for r in records if r.chains[0].num_residues <= self._size] + large = [r for r in records if r.chains[0].num_residues > self._size] + + # Assign uniform weights to the proteins, with prob amount of small + weights = [self._prob / len(small)] * len(small) + weights += [(1 - self._prob) / len(large)] * len(large) + + # Create samples + samples = [ + Sample(record_id=r.id, chain_id=0, weight=w) + for r, w in zip(records, weights) + ] + return samples diff --git a/src/boltz/data/sample/v2/random.py b/src/boltz/data/sample/v2/random.py new file mode 100644 index 000000000..25e7c105e --- /dev/null +++ b/src/boltz/data/sample/v2/random.py @@ -0,0 +1,38 @@ +from collections.abc import Iterator +from typing import Optional + +import numpy as np +from numpy.random import RandomState + +from boltz.data.sample.v2.sampler import Sampler +from boltz.data.types import Record + + +class RandomSampler(Sampler): + """A simple random sampler with replacement.""" + + def sample( + self, records: list[Record], random: Optional[RandomState] = None + ) -> Iterator[Record]: + """Sample a structure from the dataset infinitely. + + Parameters + ---------- + records : List[Record] + The records to sample from. + + random : Optional[RandomState] + Random state for reproducibility. If None, uses np.random. + + Returns + ------- + List[Sample] + The samples. + + """ + if random is None: + random = np.random.default_rng() + + while True: + item_idx = random.choice(len(records)) + yield records[item_idx] diff --git a/src/boltz/data/sample/v2/sampler.py b/src/boltz/data/sample/v2/sampler.py new file mode 100644 index 000000000..0b85320f3 --- /dev/null +++ b/src/boltz/data/sample/v2/sampler.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +from boltz.data.types import Record + + +@dataclass +class Sample: + """A sample with optional chain and interface IDs. + + Attributes + ---------- + record : Record + The record. + chain_id : Optional[int] + The chain ID. + interface_id : Optional[int] + The interface ID. + """ + + record_id: str + chain_id: Optional[int] = None + interface_id: Optional[int] = None + weight: Optional[float] = None + + +class Sampler(ABC): + """Abstract base class for samplers.""" + + @abstractmethod + def sample(self, records: list[Record]) -> list[Sample]: + """Sample a structure from the dataset infinitely. + + Parameters + ---------- + records : List[Record] + The records to sample from. + + Returns + ------- + List[Sample] + The samples. + + """ + raise NotImplementedError diff --git a/src/boltz/data/template/__init__.py b/src/boltz/data/template/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/boltz/data/template/feature.py b/src/boltz/data/template/feature.py new file mode 100644 index 000000000..b67148f65 --- /dev/null +++ b/src/boltz/data/template/feature.py @@ -0,0 +1,75 @@ +import numpy as np +import torch +from torch.nn.functional import one_hot + +from boltz.data import const +from boltz.data.types import Tokenized + + +def compute_template_features( + query_tokens: Tokenized, + tmpl_tokens: list[dict], + num_tokens: int, +) -> dict: + """Compute the template features.""" + # Allocate features + res_type = np.zeros((num_tokens,), dtype=np.int64) + frame_rot = np.zeros((num_tokens, 3, 3), dtype=np.float32) + frame_t = np.zeros((num_tokens, 3), dtype=np.float32) + cb_coords = np.zeros((num_tokens, 3), dtype=np.float32) + ca_coords = np.zeros((num_tokens, 3), dtype=np.float32) + frame_mask = np.zeros((num_tokens,), dtype=np.float32) + cb_mask = np.zeros((num_tokens,), dtype=np.float32) + template_mask = np.zeros((num_tokens,), dtype=np.float32) + query_to_template = np.zeros((num_tokens,), dtype=np.int64) + visibility_ids = np.zeros((num_tokens,), dtype=np.float32) + + # Now create features per token + asym_id_to_pdb_id = {} + + for token_dict in tmpl_tokens: + idx = token_dict["q_idx"] + monomeric = token_dict["is_monomeric"] + pdb_id = token_dict["pdb_id"] + token = token_dict["token"] + query_token = query_tokens.tokens[idx] + if not monomeric: + asym_id_to_pdb_id[query_token["asym_id"]] = pdb_id + + res_type[idx] = token["res_type"] + frame_rot[idx] = token["frame_rot"].reshape(3, 3) + frame_t[idx] = token["frame_t"] + cb_coords[idx] = token["disto_coords"] + ca_coords[idx] = token["center_coords"] + cb_mask[idx] = token["disto_mask"] + frame_mask[idx] = token["frame_mask"] + template_mask[idx] = 1.0 + + # Set visibility_id for templated chains + for asym_id, pdb_id in asym_id_to_pdb_id.items(): + indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero() + visibility_ids[indices] = pdb_id + + # Set visibility for non templated chain + olygomerics + for asym_id in np.unique(query_tokens.structure.chains["asym_id"]): + if asym_id not in asym_id_to_pdb_id: + # We hack the chain id to be negative to not overlap with the above + indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero() + visibility_ids[indices] = -1 - asym_id + + # Convert to one-hot + res_type = torch.from_numpy(res_type) + res_type = one_hot(res_type, num_classes=const.num_tokens) + + return { + "template_restype": res_type, + "template_frame_rot": torch.from_numpy(frame_rot), + "template_frame_t": torch.from_numpy(frame_t), + "template_cb": torch.from_numpy(cb_coords), + "template_ca": torch.from_numpy(ca_coords), + "template_mask_cb": torch.from_numpy(cb_mask), + "template_mask_frame": torch.from_numpy(frame_mask), + "template_mask": torch.from_numpy(template_mask), + "query_to_template": torch.from_numpy(query_to_template), + "visibility_ids": torch.from_numpy(visibility_ids), + } diff --git a/src/boltz/data/tokenize/boltz2.py b/src/boltz/data/tokenize/boltz2.py index 7371d6ae7..55eece57c 100644 --- a/src/boltz/data/tokenize/boltz2.py +++ b/src/boltz/data/tokenize/boltz2.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import astuple, dataclass from typing import Optional import numpy as np @@ -11,6 +11,7 @@ StructureV2, TokenBondV2, Tokenized, + TokenizedTraining, TokenV2, ) @@ -129,7 +130,7 @@ def get_unk_token(chain: np.ndarray) -> int: return res_id -def tokenize_structure( # noqa: C901, PLR0915 +def tokenize_structure( # noqa: PLR0915 struct: StructureV2, affinity: Optional[AffinityInfo] = None, ) -> tuple[np.ndarray, np.ndarray]: @@ -424,3 +425,227 @@ def tokenize(self, data: Input) -> Tokenized: extra_mols=data.extra_mols, ) return tokenized + + +class Boltz2TrainingTokenizer(Tokenizer): + """Tokenize an input structure for training.""" + + def __init__(self, atomize_modified_residues: bool = False) -> None: + """Initialize the AF3Tokenizer. + + Parameters + ---------- + atomize_modified_residues : bool + Whether to atomize modified residues. + map_to_closest_residue : bool + Whether to map modified residues to the closest residue. + + """ + self.atomize_modified_residues = atomize_modified_residues + + def tokenize( + self, + struct: StructureV2, + training: bool = False, + res_name_to_overwrite: Optional[dict] = None, + ) -> Tokenized: # noqa: C901, PLR0915 + """Tokenize the input data. + + Parameters + ---------- + struct : Structure + The input structure. + training: bool + Whether we are at training or inference time + + + Returns + ------- + Tokenized + The tokenized data. + + """ + # Create token data + token_data = [] + + # Keep track of atom_idx to token_idx + token_idx = 0 + atom_to_token = {} + + # Filter to valid chains only + chains = struct.chains[struct.mask] + + # Ensemble atom id start in coords table. + # For cropper and other operations, harcoded to 0th conformer. + offset = struct.ensemble[0]["atom_coord_idx"] + + for chain in chains: + # Get residue indices + res_start = chain["res_idx"] + res_end = chain["res_idx"] + chain["res_num"] + is_protein = chain["mol_type"] == const.chain_type_ids["PROTEIN"] + + for res in struct.residues[res_start:res_end]: + # Get atom indices + atom_start = res["atom_idx"] + atom_end = res["atom_idx"] + res["atom_num"] + + # Standard residues are tokens + if res["is_standard"]: + # Get center and disto atoms + center = struct.atoms[res["atom_center"]] + disto = struct.atoms[res["atom_disto"]] + + # Token is present if centers are + is_present = res["is_present"] & center["is_present"] + is_disto_present = res["is_present"] & disto["is_present"] + + # Apply chain transformation + c_coords = struct.coords[offset + res["atom_center"]]["coords"] + d_coords = struct.coords[offset + res["atom_disto"]]["coords"] + + # If protein, compute frame, only used for templates + frame_rot = np.eye(3).flatten() + frame_t = np.zeros(3) + frame_mask = False + + if is_protein: + # Get frame atoms + atom_st = res["atom_idx"] + atom_en = res["atom_idx"] + res["atom_num"] + atoms = struct.atoms[atom_st:atom_en] + + # Atoms are always in the order N, CA, C + atom_n = atoms[0] + atom_ca = atoms[1] + atom_c = atoms[2] + + # Compute frame and mask + frame_mask = atom_ca["is_present"] + frame_mask &= atom_c["is_present"] + frame_mask &= atom_n["is_present"] + frame_mask = bool(frame_mask) + if frame_mask: + frame_rot, frame_t = compute_frame( + atom_n["coords"], + atom_ca["coords"], + atom_c["coords"], + ) + frame_rot = frame_rot.flatten() + + # Create token + token = TokenData( + token_idx=token_idx, + atom_idx=res["atom_idx"], + atom_num=res["atom_num"], + res_idx=res["res_idx"], + res_type=res["res_type"], + res_name=( + res_name_to_overwrite.get(res["name"], res["name"]) + if res_name_to_overwrite + else res["name"] + ), + sym_id=chain["sym_id"], + asym_id=chain["asym_id"], + entity_id=chain["entity_id"], + mol_type=chain["mol_type"], + center_idx=res["atom_center"], + disto_idx=res["atom_disto"], + center_coords=c_coords, + disto_coords=d_coords, + resolved_mask=is_present, + disto_mask=is_disto_present, + modified=False, + frame_rot=frame_rot, + frame_t=frame_t, + frame_mask=frame_mask, + cyclic_period=0, + ) + token_data.append(astuple(token)) + + # Update atom_idx to token_idx + for atom_idx in range(atom_start, atom_end): + atom_to_token[atom_idx] = token_idx + + token_idx += 1 + + # Non-standard are tokenized per atom + elif ( + chain["mol_type"] == const.chain_type_ids["NONPOLYMER"] + or self.atomize_modified_residues + ): + # We use the unk protein token as res_type + unk_token = const.unk_token["PROTEIN"] + unk_id = const.token_ids[unk_token] + + # Get atom coordinates + atom_data = struct.atoms[atom_start:atom_end] + atom_coords = struct.coords[ + offset + atom_start : offset + atom_end + ]["coords"] + + # Tokenize each atom + for i, atom in enumerate(atom_data): + # Token is present if atom is + is_present = res["is_present"] & atom["is_present"] + index = atom_start + i + + # Create token + token = TokenData( + token_idx=token_idx, + atom_idx=index, + atom_num=1, + res_idx=res["res_idx"], + res_type=unk_id, + res_name=( + res_name_to_overwrite.get(res["name"], res["name"]) + if res_name_to_overwrite + else res["name"] + ), + sym_id=chain["sym_id"], + asym_id=chain["asym_id"], + entity_id=chain["entity_id"], + mol_type=chain["mol_type"], + center_idx=index, + disto_idx=index, + center_coords=atom_coords[i], + disto_coords=atom_coords[i], + resolved_mask=is_present, + disto_mask=is_present, + modified=chain["mol_type"] + != const.chain_type_ids["NONPOLYMER"], + frame_rot=np.eye(3).flatten(), + frame_t=np.zeros(3), + frame_mask=False, + cyclic_period=0, + ) + token_data.append(astuple(token)) + + # Update atom_idx to token_idx + atom_to_token[index] = token_idx + token_idx += 1 + + # Create token bonds + token_bonds = [] + + # Add bonds for ligands + for bond in struct.bonds: + if ( + bond["atom_1"] not in atom_to_token + or bond["atom_2"] not in atom_to_token + ): + continue + token_bond = ( + atom_to_token[bond["atom_1"]], + atom_to_token[bond["atom_2"]], + bond["type"] + 1, + ) + token_bonds.append(token_bond) + + # Consider adding missing bond for modified residues to standard? + # I'm not sure it's necessary because the bond is probably always + # the same and the model can use the residue indices to infer it + token_data = np.array(token_data, dtype=TokenV2) + token_bonds = np.array(token_bonds, dtype=TokenBondV2) + tokenized = TokenizedTraining(token_data, token_bonds, struct) + return tokenized diff --git a/src/boltz/data/types.py b/src/boltz/data/types.py index 1ce26b558..ca330a718 100644 --- a/src/boltz/data/types.py +++ b/src/boltz/data/types.py @@ -506,6 +506,7 @@ class ChainInfo: num_residues: int valid: bool = True entity_id: Optional[Union[str, int]] = None + template_ids: Optional[list[Union[str, int]]] = None @dataclass(frozen=True) @@ -690,6 +691,30 @@ def load(cls: "JSONSerializable", path: Path) -> "JSONSerializable": return manifest +#################################################################################################### +# TEMPLATE +#################################################################################################### + +TemplateCoordinates = [ + ("res_idx", np.dtype("i4")), + ("res_type", np.dtype("i1")), + ("frame_rot", np.dtype("9f4")), + ("frame_t", np.dtype("3f4")), + ("coords_cb", np.dtype("3f4")), + ("coords_ca", np.dtype("3f4")), + ("mask_frame", np.dtype("?")), + ("mask_cb", np.dtype("?")), + ("mask_ca", np.dtype("?")), +] + + +@dataclass(frozen=True, slots=True) +class Template(NumpySerializable): + """Template datatype.""" + + coordinates: np.ndarray + + #################################################################################################### # INPUT #################################################################################################### @@ -782,3 +807,24 @@ class Tokenized: template_tokens: Optional[dict[str, np.ndarray]] = None template_bonds: Optional[dict[str, np.ndarray]] = None extra_mols: Optional[dict[str, Mol]] = None + + +@dataclass(frozen=True) +class TokenizedTraining: + """Tokenized datatype.""" + + tokens: np.ndarray + bonds: np.ndarray + structure: Structure + + +@dataclass(frozen=True, slots=True) +class InputTraining: + """Input datatype.""" + + tokens: np.ndarray + bonds: np.ndarray + structure: Structure + msa: dict[str, MSA] + templates: dict[str, list[Template]] + record: Optional[Record] = None diff --git a/src/boltz/main.py b/src/boltz/main.py index 4a3750fec..ba28220fe 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -1,7 +1,6 @@ import multiprocessing import os import pickle -import platform import tarfile import urllib.request import warnings @@ -450,22 +449,19 @@ def compute_msa( click.echo(f"Calling MSA server for target {target_id} with {len(data)} sequences") click.echo(f"MSA server URL: {msa_server_url}") click.echo(f"MSA pairing strategy: {msa_pairing_strategy}") - + # Construct auth headers if API key header/value is provided auth_headers = None if api_key_value: key = api_key_header if api_key_header else "X-API-Key" value = api_key_value - auth_headers = { - "Content-Type": "application/json", - key: value - } + auth_headers = {"Content-Type": "application/json", key: value} click.echo(f"Using API key authentication for MSA server (header: {key})") elif msa_server_username and msa_server_password: click.echo("Using basic authentication for MSA server") else: click.echo("No authentication provided for MSA server") - + if len(data) > 1: paired_msas = run_mmseqs2( list(data.values()), @@ -714,7 +710,7 @@ def process_inputs( # Validate mutually exclusive authentication methods has_basic_auth = msa_server_username and msa_server_password has_api_key = api_key_value is not None - + if has_basic_auth and has_api_key: raise ValueError( "Cannot use both basic authentication (--msa_server_username/--msa_server_password) " @@ -1119,7 +1115,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 msa_server_password = os.environ.get("BOLTZ_MSA_PASSWORD") if api_key_value is None: api_key_value = os.environ.get("MSA_API_KEY_VALUE") - + click.echo(f"MSA server enabled: {msa_server_url}") if api_key_value: click.echo("MSA server authentication: using API key header") @@ -1212,8 +1208,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 if (isinstance(devices, int) and devices > 1) or ( isinstance(devices, list) and len(devices) > 1 ): - start_method = "fork" if platform.system() != "win32" and platform.system() != "Windows" else "spawn" - strategy = DDPStrategy(start_method=start_method) + strategy = DDPStrategy() if len(filtered_manifest.records) < devices: msg = ( "Number of requested devices is greater " @@ -1387,7 +1382,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 steering_args.fk_steering = False steering_args.physical_guidance_update = False steering_args.contact_guidance_update = False - + model_module = Boltz2.load_from_checkpoint( affinity_checkpoint, strict=True, diff --git a/src/boltz/model/loss/inference.py b/src/boltz/model/loss/inference.py new file mode 100644 index 000000000..252c7587a --- /dev/null +++ b/src/boltz/model/loss/inference.py @@ -0,0 +1,445 @@ +import torch + +from boltz.data import const + + +def compute_chain_clashes(pred_atom_coords, feats, clash_buffer=0.4): + chain_id = feats["asym_id"] + with torch.autocast("cuda", enabled=False): + atom_chain_id = ( + torch.bmm(feats["atom_to_token"].float(), chain_id.unsqueeze(-1).float()) + .squeeze(-1) + .long() + ) + + vdw_radii = torch.zeros( + const.num_elements, dtype=torch.float32, device=pred_atom_coords.device + ) + vdw_radii[1:119] = torch.tensor( + const.vdw_radii, dtype=torch.float32, device=pred_atom_coords.device + ) + atom_vdw_radii = (feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)).squeeze( + -1 + ) + + dists = torch.cdist(pred_atom_coords, pred_atom_coords) + clashes = ( + dists + < (atom_vdw_radii.unsqueeze(-1) + atom_vdw_radii.unsqueeze(-2)) - clash_buffer + ) + + multiplicity = pred_atom_coords.shape[0] + num_clashes, num_pairs = {}, {} + for key in const.out_single_types: + num_clashes["sym_" + key] = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_pairs["sym_" + key] = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + for key in const.clash_types: + num_clashes["asym_" + key] = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_pairs["asym_" + key] = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + + for batch_idx in range(feats["atom_pad_mask"].shape[0]): # TODO: Batch size > 1 + pair_pad_mask = ( + feats["atom_pad_mask"][batch_idx, :, None] + * feats["atom_pad_mask"][batch_idx, None, :] + ).bool() + if feats["connections_edge_index"][batch_idx].shape[1] > 0: + pair_pad_mask[ + feats["connections_edge_index"][batch_idx][0], + feats["connections_edge_index"][batch_idx][1], + ] = False + pair_pad_mask[ + feats["connections_edge_index"][batch_idx][1], + feats["connections_edge_index"][batch_idx][0], + ] = False + chain_symmetries = feats["chain_symmetries"][batch_idx] + chain_id_to_symmetry = {} + chain_id_to_type = {} + for idx, symmetry in enumerate(chain_symmetries): + for chain in symmetry: + chain_id_to_symmetry[chain[0]] = idx + chain_id_to_type[chain[0]] = chain[4] + for i in chain_id_to_symmetry: + for j in chain_id_to_symmetry: + type1, type2 = ( + const.chain_types[chain_id_to_type[i]], + const.chain_types[chain_id_to_type[j]], + ) + if i >= j: + continue + chain_pair_mask = ( + pair_pad_mask + * (atom_chain_id[batch_idx] == i).unsqueeze(-1) + * (atom_chain_id[batch_idx] == j).unsqueeze(-2) + ) + chain_pair_clashes = clashes[:, chain_pair_mask].any(dim=-1) + if chain_id_to_symmetry[i] == chain_id_to_symmetry[j]: + num_clashes[ + "sym_" + const.chain_type_to_out_single_type[type1] + ] += chain_pair_clashes.float() + num_pairs["sym_" + const.chain_type_to_out_single_type[type1]] += 1 + else: + num_clashes[ + "asym_" + + const.chain_types_to_clash_type[frozenset((type1, type2))] + ] += chain_pair_clashes.float() + num_pairs[ + "asym_" + + const.chain_types_to_clash_type[frozenset((type1, type2))] + ] += 1 + + for key in num_clashes: + if num_pairs[key].sum() > 0: + num_clashes[key] /= num_pairs[key] + + return num_clashes, num_pairs + + +def compute_pb_geometry_metrics( + pred_atom_coords, feats, bond_buffer=0.25, angle_buffer=0.25, clash_buffer=0.3 +): + with torch.autocast("cuda", enabled=False): + chain_id = feats["asym_id"] + atom_chain_id = ( + torch.bmm(feats["atom_to_token"].float(), chain_id.unsqueeze(-1).float()) + .squeeze(-1) + .long() + ) + is_ligand_mask = ( + torch.bmm( + feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float() + ) + .squeeze(-1) + .long() + == const.chain_type_ids["NONPOLYMER"] + ).float() + + multiplicity = pred_atom_coords.shape[0] + num_bond_length_failures = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_bond_angle_failures = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_internal_clash_failures = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_ligands = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + + for index_batch in range(len(feats["ligand_edge_index"])): + if feats["ligand_edge_index"][index_batch].shape[1] == 0: + continue + dists = torch.linalg.norm( + pred_atom_coords[:, feats["ligand_edge_index"][index_batch][0]] + - pred_atom_coords[:, feats["ligand_edge_index"][index_batch][1]], + dim=-1, + ) + + bond_length_violations = ( + ( + dists + < feats["ligand_edge_lower_bounds"][index_batch] * (1.0 - bond_buffer) + ) + + ( + dists + > feats["ligand_edge_upper_bounds"][index_batch] * (1.0 + bond_buffer) + ) + )[:, feats["ligand_edge_bond_mask"][index_batch]].float() + + bond_angle_violations = ( + ( + dists + < feats["ligand_edge_lower_bounds"][index_batch] * (1.0 - angle_buffer) + ) + + ( + dists + > feats["ligand_edge_upper_bounds"][index_batch] * (1.0 + angle_buffer) + ) + )[:, feats["ligand_edge_angle_mask"][index_batch]].float() + + internal_clash_violations = ( + dists + < feats["ligand_edge_lower_bounds"][index_batch] * (1.0 - clash_buffer) + )[ + :, + ~( + feats["ligand_edge_bond_mask"][index_batch] + + feats["ligand_edge_angle_mask"][index_batch] + ), + ].float() + + edge_chain_ids = atom_chain_id[index_batch][ + feats["ligand_edge_index"][index_batch][0] + ] + bond_chain_ids = edge_chain_ids[feats["ligand_edge_bond_mask"][index_batch]] + angle_chain_ids = edge_chain_ids[feats["ligand_edge_angle_mask"][index_batch]] + internal_clash_chain_ids = edge_chain_ids[ + ~( + feats["ligand_edge_bond_mask"][index_batch] + + feats["ligand_edge_angle_mask"][index_batch] + ) + ] + + num_bond_length_failures += ( + torch.zeros( + (multiplicity, chain_id.max().item() + 1), + dtype=torch.float32, + device=dists.device, + ) + .scatter_reduce( + 1, + bond_chain_ids.expand((multiplicity, -1)), + bond_length_violations, + reduce="amax", + ) + .sum(dim=-1) + ) + + num_bond_angle_failures += ( + torch.zeros( + (multiplicity, chain_id.max().item() + 1), + dtype=torch.float32, + device=dists.device, + ) + .scatter_reduce( + 1, + angle_chain_ids.expand((multiplicity, -1)), + bond_angle_violations, + reduce="amax", + ) + .sum(dim=-1) + ) + + num_internal_clash_failures += ( + torch.zeros( + (multiplicity, chain_id.max().item() + 1), + dtype=torch.float32, + device=dists.device, + ) + .scatter_reduce( + 1, + internal_clash_chain_ids.expand((multiplicity, -1)), + internal_clash_violations, + reduce="amax", + ) + .sum(dim=-1) + ) + + num_ligands += ( + torch.zeros( + (chain_id.max().item() + 1,), dtype=torch.float32, device=dists.device + ) + .scatter_reduce( + 0, + edge_chain_ids, + torch.ones( + edge_chain_ids.shape, dtype=torch.float32, device=dists.device + ), + reduce="amax", + ) + .sum() + ) + + num_bond_length_failures[num_ligands > 0] /= num_ligands[num_ligands > 0] + num_bond_angle_failures[num_ligands > 0] /= num_ligands[num_ligands > 0] + num_internal_clash_failures[num_ligands > 0] /= num_ligands[num_ligands > 0] + + return ( + num_bond_length_failures, + num_bond_angle_failures, + num_internal_clash_failures, + num_ligands, + ) + + +def compute_torsion_angles(coords, torsion_index): + r_ij = coords.index_select(-2, torsion_index[0]) - coords.index_select( + -2, torsion_index[1] + ) + r_kj = coords.index_select(-2, torsion_index[2]) - coords.index_select( + -2, torsion_index[1] + ) + r_kl = coords.index_select(-2, torsion_index[2]) - coords.index_select( + -2, torsion_index[3] + ) + + n_ijk = torch.cross(r_ij, r_kj, dim=-1) + n_jkl = torch.cross(r_kj, r_kl, dim=-1) + + r_kj_norm = torch.linalg.norm(r_kj, dim=-1) + n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1) + n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1) + + sign_phi = torch.sign( + r_kj.unsqueeze(-2) @ torch.cross(n_ijk, n_jkl, dim=-1).unsqueeze(-1) + ).squeeze(-1, -2) + phi = sign_phi * torch.arccos( + torch.clamp( + (n_ijk.unsqueeze(-2) @ n_jkl.unsqueeze(-1)).squeeze(-1, -2) + / (n_ijk_norm * n_jkl_norm), + -1 + 1e-8, + 1 - 1e-8, + ) + ) + return phi + + +def compute_stereo_metrics(pred_atom_coords, feats): + multiplicity = pred_atom_coords.shape[0] + num_chiral_atom_violations = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_chiral_atoms = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_stereo_bond_violations = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_stereo_bonds = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + + for index_batch in range(len(feats["ligand_edge_index"])): + if feats["ligand_chiral_atom_index"][index_batch].shape[1] > 0: + pred_chiral_torsion_angles = compute_torsion_angles( + pred_atom_coords, + feats["ligand_chiral_atom_index"][index_batch][ + :, feats["ligand_chiral_check_mask"][index_batch].bool() + ], + ) + pred_chiral_atom_orientations = pred_chiral_torsion_angles > 0 + true_chiral_atom_orientations = feats["ligand_chiral_atom_orientations"][ + index_batch + ][feats["ligand_chiral_check_mask"][index_batch].bool()] + num_chiral_atom_violations += ( + pred_chiral_atom_orientations != true_chiral_atom_orientations + ).sum(dim=-1) + num_chiral_atoms += true_chiral_atom_orientations.shape[0] + + if feats["ligand_stereo_bond_index"][index_batch].shape[1] > 0: + pred_stereo_torsion_angles = compute_torsion_angles( + pred_atom_coords, + feats["ligand_stereo_bond_index"][index_batch][ + :, feats["ligand_stereo_check_mask"][index_batch].bool() + ], + ) + pred_stereo_bond_orientations = ( + torch.abs(pred_stereo_torsion_angles) > torch.pi / 2 + ) + true_stereo_bond_orientations = feats["ligand_stereo_bond_orientations"][ + index_batch + ][feats["ligand_stereo_check_mask"][index_batch].bool()] + num_stereo_bond_violations += ( + pred_stereo_bond_orientations != true_stereo_bond_orientations + ).sum(dim=-1) + num_stereo_bonds += true_stereo_bond_orientations.shape[0] + + num_chiral_atom_violations[num_chiral_atoms > 0] /= num_chiral_atoms[ + num_chiral_atoms > 0 + ] + num_stereo_bond_violations[num_stereo_bonds > 0] /= num_stereo_bonds[ + num_stereo_bonds > 0 + ] + return ( + num_chiral_atom_violations, + num_chiral_atoms, + num_stereo_bond_violations, + num_stereo_bonds, + ) + + +def compute_pb_flatness_metrics(pred_atom_coords, feats, buffer=0.25): + multiplicity = pred_atom_coords.shape[0] + num_aromatic_5_violations = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_aromatic_5_rings = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_aromatic_6_violations = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_aromatic_6_rings = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_double_bond_violations = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + num_double_bonds = torch.zeros( + multiplicity, dtype=torch.float32, device=pred_atom_coords.device + ) + + for index_batch in range(len(feats["ligand_aromatic_5_ring_index"])): + ring_5_index = feats["ligand_aromatic_5_ring_index"][index_batch].T + ring_6_index = feats["ligand_aromatic_6_ring_index"][index_batch].T + double_bond_index = feats["ligand_planar_double_bond_index"][index_batch].T + + ring_5_coords = pred_atom_coords[..., ring_5_index, :] + ring_6_coords = pred_atom_coords[..., ring_6_index, :] + double_bond_coords = pred_atom_coords[..., double_bond_index, :] + + centered_ring_5_coords = ring_5_coords - ring_5_coords.mean( + dim=-2, keepdims=True + ) + ring_5_vecs = torch.linalg.svd(centered_ring_5_coords)[2][..., -1, :, None] + ring_5_dists = torch.abs( + (centered_ring_5_coords @ ring_5_vecs).squeeze(dim=(-1, -2)) + ) + num_aromatic_5_violations += torch.any(ring_5_dists > buffer, dim=-1).sum( + dim=-1 + ) + num_aromatic_5_rings += ring_5_index.shape[0] + + centered_ring_6_coords = ring_6_coords - ring_6_coords.mean( + dim=-2, keepdims=True + ) + ring_6_vecs = torch.linalg.svd(centered_ring_6_coords)[2][..., -1, :, None] + ring_6_dists = torch.abs( + (centered_ring_6_coords @ ring_6_vecs).squeeze(dim=(-1, -2)) + ) + num_aromatic_6_violations += torch.any(ring_6_dists > buffer, dim=-1).sum( + dim=-1 + ) + num_aromatic_6_rings += ring_6_index.shape[0] + + centered_double_bond_coords = double_bond_coords - double_bond_coords.mean( + dim=-2, keepdims=True + ) + double_bond_vecs = torch.linalg.svd(centered_double_bond_coords)[2][ + ..., -1, :, None + ] + double_bond_dists = torch.abs( + (centered_double_bond_coords @ double_bond_vecs).squeeze(dim=(-1, -2)) + ) + num_double_bond_violations += torch.any(double_bond_dists > buffer, dim=-1).sum( + dim=-1 + ) + num_double_bonds += double_bond_index.shape[0] + + num_aromatic_5_violations[num_aromatic_5_rings > 0] /= num_aromatic_5_rings[ + num_aromatic_5_rings > 0 + ] + num_aromatic_6_violations[num_aromatic_6_rings > 0] /= num_aromatic_6_rings[ + num_aromatic_6_rings > 0 + ] + num_double_bond_violations[num_double_bonds > 0] /= num_double_bonds[ + num_double_bonds > 0 + ] + + return ( + num_aromatic_5_violations, + num_aromatic_5_rings, + num_aromatic_6_violations, + num_aromatic_6_rings, + num_double_bond_violations, + num_double_bonds, + ) diff --git a/src/boltz/model/models/boltz2.py b/src/boltz/model/models/boltz2.py index d42f3400c..9d36bb3a3 100644 --- a/src/boltz/model/models/boltz2.py +++ b/src/boltz/model/models/boltz2.py @@ -31,7 +31,6 @@ InputEmbedder, MSAModule, TemplateModule, - TemplateV2Module, ) from boltz.model.optim.ema import EMA from boltz.model.optim.scheduler import AlphaFoldLRScheduler @@ -102,7 +101,6 @@ def __init__( predict_bfactor: bool = False, log_loss_every_steps: int = 50, checkpoint_diffusion_conditioning: bool = False, - use_templates_v2: bool = False, use_kernels: bool = False, ) -> None: super().__init__() @@ -133,6 +131,7 @@ def __init__( self.diffusion_loss_args = diffusion_loss_args self.predict_args = predict_args self.steering_args = steering_args + self.validate_structure = validate_structure # Training metrics if validate_structure: @@ -152,6 +151,7 @@ def __init__( self.num_bins = num_bins self.min_dist = min_dist self.max_dist = max_dist + self.num_distograms = 1 self.aggregate_distogram = aggregate_distogram # Trunk @@ -215,10 +215,7 @@ def __init__( # Pairwise stack self.use_templates = use_templates if use_templates: - if use_templates_v2: - self.template_module = TemplateV2Module(token_z, **template_args) - else: - self.template_module = TemplateModule(token_z, **template_args) + self.template_module = TemplateModule(token_z, **template_args) if compile_templates: self.is_template_compiled = True self.template_module = torch.compile( @@ -495,11 +492,7 @@ def forward( "z": z, } - if ( - self.run_trunk_and_structure - and ((not self.training) or self.confidence_prediction) - and (not self.skip_run_structure) - ): + if self.run_trunk_and_structure and (not self.skip_run_structure): if self.checkpoint_diffusion_conditioning and self.training: # TODO decide whether this should be with bf16 or not q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = ( @@ -529,19 +522,20 @@ def forward( "token_trans_bias": token_trans_bias, } - with torch.autocast("cuda", enabled=False): - struct_out = self.structure_module.sample( - s_trunk=s.float(), - s_inputs=s_inputs.float(), - feats=feats, - num_sampling_steps=num_sampling_steps, - atom_mask=feats["atom_pad_mask"].float(), - multiplicity=diffusion_samples, - max_parallel_samples=max_parallel_samples, - steering_args=self.steering_args, - diffusion_conditioning=diffusion_conditioning, - ) - dict_out.update(struct_out) + if (not self.training) or self.confidence_prediction: + with torch.autocast("cuda", enabled=False): + struct_out = self.structure_module.sample( + s_trunk=s.float(), + s_inputs=s_inputs.float(), + feats=feats, + num_sampling_steps=num_sampling_steps, + atom_mask=feats["atom_pad_mask"].float(), + multiplicity=diffusion_samples, + max_parallel_samples=max_parallel_samples, + steering_args=self.steering_args, + diffusion_conditioning=diffusion_conditioning, + ) + dict_out.update(struct_out) if self.predict_bfactor: pbfactor = self.bfactor_module(s) diff --git a/src/boltz/model/modules/diffusionv2.py b/src/boltz/model/modules/diffusionv2.py index fd1af56d3..7876ee181 100644 --- a/src/boltz/model/modules/diffusionv2.py +++ b/src/boltz/model/modules/diffusionv2.py @@ -308,13 +308,13 @@ def sample( ): potentials = get_potentials(steering_args, boltz2=True) - if steering_args["fk_steering"]: + if steering_args is not None and steering_args["fk_steering"]: multiplicity = multiplicity * steering_args["num_particles"] energy_traj = torch.empty((multiplicity, 0), device=self.device) resample_weights = torch.ones(multiplicity, device=self.device).reshape( -1, steering_args["num_particles"] ) - if ( + if steering_args is not None and ( steering_args["physical_guidance_update"] or steering_args["contact_guidance_update"] ): @@ -362,8 +362,11 @@ def sample( + random_tr ) if ( - steering_args["physical_guidance_update"] - or steering_args["contact_guidance_update"] + steering_args is not None + and ( + steering_args["physical_guidance_update"] + or steering_args["contact_guidance_update"] + ) ) and scaled_guidance_update is not None: scaled_guidance_update = torch.einsum( "bmd,bds->bms", scaled_guidance_update, random_R @@ -395,12 +398,16 @@ def sample( ) atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk - if steering_args["fk_steering"] and ( - ( - step_idx % steering_args["fk_resampling_interval"] == 0 - and noise_var > 0 + if ( + steering_args is not None + and steering_args["fk_steering"] + and ( + ( + step_idx % steering_args["fk_resampling_interval"] == 0 + and noise_var > 0 + ) + or step_idx == num_sampling_steps - 1 ) - or step_idx == num_sampling_steps - 1 ): # Compute energy of x_0 prediction energy = torch.zeros(multiplicity, device=self.device) @@ -423,8 +430,11 @@ def sample( # Compute ll difference between guided and unguided transition distribution if ( - steering_args["physical_guidance_update"] - or steering_args["contact_guidance_update"] + steering_args is not None + and ( + steering_args["physical_guidance_update"] + or steering_args["contact_guidance_update"] + ) ) and noise_var > 0: ll_difference = ( eps**2 - (eps + scaled_guidance_update) ** 2 @@ -442,8 +452,11 @@ def sample( # Compute guidance update to x_0 prediction if ( - steering_args["physical_guidance_update"] - or steering_args["contact_guidance_update"] + steering_args is not None + and ( + steering_args["physical_guidance_update"] + or steering_args["contact_guidance_update"] + ) ) and step_idx < num_sampling_steps - 1: guidance_update = torch.zeros_like(atom_coords_denoised) for guidance_step in range(steering_args["num_gd_steps"]): @@ -472,12 +485,16 @@ def sample( / t_hat ) - if steering_args["fk_steering"] and ( - ( - step_idx % steering_args["fk_resampling_interval"] == 0 - and noise_var > 0 + if ( + steering_args is not None + and steering_args["fk_steering"] + and ( + ( + step_idx % steering_args["fk_resampling_interval"] == 0 + and noise_var > 0 + ) + or step_idx == num_sampling_steps - 1 ) - or step_idx == num_sampling_steps - 1 ): resample_indices = ( torch.multinomial( @@ -614,7 +631,9 @@ def compute_loss( multiplicity, 0 ) - align_weights = denoised_atom_coords.new_ones(denoised_atom_coords.shape[:2]) + align_weights = denoised_atom_coords.new_ones( + denoised_atom_coords.shape[:2] + ) atom_type = ( torch.bmm( feats["atom_to_token"].float(), diff --git a/src/boltz/model/validation/rcsb.py b/src/boltz/model/validation/rcsb.py new file mode 100644 index 000000000..69a0a8319 --- /dev/null +++ b/src/boltz/model/validation/rcsb.py @@ -0,0 +1,61 @@ +from typing import Optional + +import torch +from pytorch_lightning import LightningModule + +from boltz.model.validation.validator import Validator + + +class RCSBValidator(Validator): + """Validation step implementation for RCSB.""" + + def __init__( + self, + val_names: list[str], + confidence_prediction: bool = False, + override_val_method: Optional[str] = None, + ) -> None: + super().__init__( + val_names=val_names, + confidence_prediction=confidence_prediction, + override_val_method=override_val_method, + ) + + def process( + self, + model: LightningModule, + batch: dict[str, torch.Tensor], + out: dict[str, torch.Tensor], + idx_dataset: int, + ) -> None: + """Compute features. + + Parameters + ---------- + model : LightningModule + The LightningModule model. + batch : Dict[str, torch.Tensor] + The batch input. + out : Dict[str, torch.Tensor] + The output of the model. + + """ + symmetry_correction = model.val_group_mapper[idx_dataset]["symmetry_correction"] + expand_to_diffusion_samples = ( + symmetry_correction # True # TODO Mateo why is this set to sym correction? + ) + + # For now all was dumped into the common operation in the parent Validator class + self.common_val_step( + model, + batch, + out, + idx_dataset, + expand_to_diffusion_samples=expand_to_diffusion_samples, + ) + + # TODO: Implement the RCSB specific validation step + + def on_epoch_end(self, model: LightningModule) -> None: + # For now all was dumped into the common operation in the parent Validator class + self.common_on_epoch_end(model) diff --git a/src/boltz/model/validation/validator.py b/src/boltz/model/validation/validator.py new file mode 100644 index 000000000..393e146ea --- /dev/null +++ b/src/boltz/model/validation/validator.py @@ -0,0 +1,1266 @@ +from collections import defaultdict +from typing import Optional + +import torch +import torch._dynamo +from pytorch_lightning import LightningModule +from torch import nn +from torchmetrics import MeanMetric + +from boltz.data import const +from boltz.model.loss.distogramv2 import distogram_loss +from boltz.model.loss.inference import ( + compute_chain_clashes, + compute_pb_flatness_metrics, + compute_pb_geometry_metrics, + compute_stereo_metrics, +) +from boltz.model.loss.validation import ( + compute_pae_mae, + compute_pde_mae, + compute_plddt_mae, + factored_lddt_loss, + factored_token_lddt_dist_loss, +) + + +class Validator(nn.Module): + """Compute validation step and aggregation.""" + + def __init__( + self, + val_names: list[str], + confidence_prediction: bool = False, + physicalism_metrics: bool = False, + override_val_method: Optional[str] = None, + ) -> None: + super().__init__() + self.val_names = val_names + + self.override_val_method = override_val_method + if override_val_method is not None: + override_val_method = override_val_method.lower() + assert override_val_method in const.method_types_ids, "Invalid method type." + self.override_val_method = const.method_types_ids[override_val_method] + + self.num_val_datasets = num_val_datasets = len(val_names) + + msg = "Only one dataset supported for now per validator. Define multiple validators for multiple datasets." + assert num_val_datasets == 1, msg + + # Folding metrics + folding_metric_labels = [ + "lddt", + "disto_lddt", + "complex_lddt", + # "rmsd", + "disto_loss", + ] + self.folding_metrics = nn.ModuleDict( + { + k: nn.ModuleList([nn.ModuleDict() for _ in range(num_val_datasets)]) + for k in folding_metric_labels + } + ) + + self.physicalism_metrics = physicalism_metrics + if physicalism_metrics: + # Physical realism metrics + physicalism_metric_labels = ["clash", "pb"] + pb_metric_labels = [ + "bond_length", + "bond_angle", + "internal_clash", + "atom_chirality", + "bond_stereochemistry", + "ring_5_flatness", + "ring_6_flatness", + "double_bond_flatness", + ] + self.physicalism_metrics = nn.ModuleDict( + { + k: nn.ModuleList([nn.ModuleDict() for _ in range(num_val_datasets)]) + for k in physicalism_metric_labels + } + ) + + # Confidence metrics + confidence_metric_prefixes = [ + "top1", + "iplddt_top1", + "ipde_top1", + "pde_top1", + "ptm_top1", + "iptm_top1", + "ligand_iptm_top1", + "protein_iptm_top1", + "avg", + ] + mae_metric_labels = ["plddt_mae", "pde_mae", "pae_mae"] + lddt_confidence_metric_labels = [ + prefix + "_lddt" for prefix in confidence_metric_prefixes + ] + if physicalism_metrics: + clash_confidence_metric_labels = [ + prefix + "_clash" for prefix in confidence_metric_prefixes + ] + pb_confidence_metric_labels = [ + prefix + "_pb" for prefix in confidence_metric_prefixes + ] + else: + clash_confidence_metric_labels, pb_confidence_metric_labels = [], [] + + if confidence_prediction: + self.confidence_metrics = nn.ModuleDict( + { + k: nn.ModuleList([nn.ModuleDict() for _ in range(num_val_datasets)]) + for k in lddt_confidence_metric_labels + + mae_metric_labels + + clash_confidence_metric_labels + + pb_confidence_metric_labels + } + ) + + # Initialize metrics for datasets + for val_idx in range(num_val_datasets): + for m_ in [ + *const.out_types, + "pocket_ligand_protein", + "contact_protein_protein", + ]: + self.folding_metrics["disto_lddt"][val_idx][m_] = MeanMetric() + + for m in const.out_single_types: + if confidence_prediction: + self.confidence_metrics["plddt_mae"][val_idx][m] = MeanMetric() + + for m in ["disto_loss"]: + self.folding_metrics["disto_loss"][val_idx][m] = MeanMetric() + + if self.physicalism_metrics: + for m_ in const.out_single_types: + m = "sym_" + m_ + self.physicalism_metrics["clash"][val_idx][m] = MeanMetric() + if confidence_prediction: + for k in clash_confidence_metric_labels: + self.confidence_metrics[k][val_idx][m] = MeanMetric() + + for m_ in const.clash_types: + m = "asym_" + m_ + self.physicalism_metrics["clash"][val_idx][m] = MeanMetric() + if confidence_prediction: + for k in clash_confidence_metric_labels: + self.confidence_metrics[k][val_idx][m] = MeanMetric() + + for m in pb_metric_labels: + self.physicalism_metrics["pb"][val_idx][m] = MeanMetric() + if confidence_prediction: + for k in pb_confidence_metric_labels: + self.confidence_metrics[k][val_idx][m] = MeanMetric() + + def run_model( + self, model: LightningModule, batch: dict[str, torch.Tensor], idx_dataset: int + ) -> dict[str, torch.Tensor]: + """Compute the forward pass.""" + if self.override_val_method is not None: + new_feature = batch["method_feature"] * 0 + self.override_val_method + batch["method_feature"] = new_feature + + out = model( + batch, + recycling_steps=model.validation_args.recycling_steps, + num_sampling_steps=model.validation_args.sampling_steps, + diffusion_samples=model.validation_args.diffusion_samples, + run_confidence_sequentially=model.validation_args.get( + "run_confidence_sequentially", False + ), + ) + + return out + + # @abstractmethod + def process( + self, + model: LightningModule, + batch: dict[str, torch.Tensor], + out: dict[str, torch.Tensor], + idx_dataset: int, + n_samples: int, + ) -> None: + """Run a validation step. + + Parameters + ---------- + model : LightningModule + The LightningModule model. + batch : Dict[str, torch.Tensor] + The batch input. + out : Dict[str, torch.Tensor] + The output of the model. + + """ + raise NotImplementedError + + def get_local_val_index(self, model: LightningModule, idx_dataset: int) -> int: + """Get the local validation index. + + Parameters + ---------- + idx_dataset : int + The dataset index. + + Returns + ------- + int + The local validation index. + """ + val_name = model.val_group_mapper[idx_dataset]["label"] + return self.val_names.index(val_name) + + def compute_disto_loss( + self, + model: LightningModule, + out: dict[str, torch.Tensor], + batch: dict[str, torch.Tensor], + idx_dataset: int, + ) -> None: + """Compute distogram loss.""" + # Compute validation disto loss + val_disto_loss, _ = distogram_loss( + out, batch, aggregate_distogram=model.aggregate_distogram + ) + + return val_disto_loss + + def compute_disto_lddt(self, model, batch, out, idx_dataset) -> tuple[dict, dict]: + """Compute distogram lddt.""" + boundaries = torch.linspace(model.min_dist, model.max_dist, model.num_bins - 1) + lower = torch.tensor([1.0]) + upper = torch.tensor([model.max_dist + 5.0]) + exp_boundaries = torch.cat((lower, boundaries, upper)) + mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to( + out["pdistogram"] + ) + + # Compute true distogram + K = batch["coords"].shape[1] + true_center = batch["disto_coords_ensemble"].reshape(K, -1, 3) # (K, L, 3) + + batch["token_disto_mask"] = batch["token_disto_mask"] + + # Compute distogram lddt by looping over predicted distograms + disto_lddt_dict = defaultdict( + lambda: torch.zeros(K, model.num_distograms).to(model.device) + ) + disto_total_dict = defaultdict( + lambda: torch.zeros(K, model.num_distograms).to(model.device) + ) + for i in range(model.num_distograms): + # Compute predicted dists + preds = out["pdistogram"][:, :, :, i] + pred_softmax = torch.softmax(preds, dim=-1) + pred_softmax = pred_softmax.argmax(dim=-1) + pred_softmax = torch.nn.functional.one_hot( + pred_softmax, num_classes=preds.shape[-1] + ) + pred_dist_i = (pred_softmax * mid_points).sum(dim=-1) # (B, L, L) + del pred_softmax + + # Compute true distances for each conformer + # Implemented in a loop to avoid memory issues with large number of + # conformers. Batched version over K factored_token_lddt_dist_loss_ensemble + # more efficient for small K. + for k in range(K): + true_dists_k = torch.cdist(true_center[k], true_center[k])[ + None + ] # (1, L * L) + + # Compute lddt + disto_lddt_dict_, disto_total_dict_ = factored_token_lddt_dist_loss( + feats=batch, + true_d=true_dists_k, + pred_d=pred_dist_i, + ) + + for key in disto_lddt_dict_: + disto_lddt_dict[key][k, i] = disto_lddt_dict_[key].item() + disto_total_dict[key][k, i] = disto_total_dict_[key].item() + + for key in disto_lddt_dict: + # Take min over distograms and average over conformers. Add batch dimension. + disto_lddt_dict[key] = ( + disto_lddt_dict[key].min(dim=1).values.mean(dim=0)[None] + ) + disto_total_dict[key] = ( + disto_total_dict[key].min(dim=1).values.mean(dim=0)[None] + ) + del true_center + del preds + + return disto_lddt_dict, disto_total_dict + + def get_true_coords( + self, + model, + batch, + out, + diffusion_samples, + symmetry_correction, + expand_to_diffusion_samples, + ) -> dict[str, torch.Tensor]: + # Get true coordinates + # TODO modiy for each validator, for now using default from model + return model.get_true_coordinates( + batch=batch, + out=out, + diffusion_samples=diffusion_samples, + symmetry_correction=symmetry_correction, + expand_to_diffusion_samples=expand_to_diffusion_samples, + ) + + def get_lddt_metrics( + self, + model, + batch, + out, + idx_dataset, + n_samples, + true_coords_resolved_mask, + true_coords, + expand_to_diffusion_samples, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + K = batch["coords"].shape[1] + + if not expand_to_diffusion_samples: + true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat( + (n_samples, 1) + ) + + ### Compute lddt ### + # Implemented in a loop to avoid memory issues with large number + # of conformers + all_lddt_dict = defaultdict(list) + all_total_dict = defaultdict(list) + for ensemble_idx in range(K): + # This OOM for large n_samples. Need to chunk or loop over samples. + + if expand_to_diffusion_samples: + true_coords_k = true_coords[:, ensemble_idx] + else: + true_coords_k = ( + true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) + ) + + all_lddt_dict_s, all_total_dict_s = factored_lddt_loss( + feats=batch, + atom_mask=true_coords_resolved_mask, + true_atom_coords=true_coords_k, # (multiplicity, L, 3) + pred_atom_coords=out["sample_atom_coords"], + multiplicity=n_samples, + ) + for key in all_lddt_dict_s: + all_lddt_dict[key].append(all_lddt_dict_s[key]) + all_total_dict[key].append(all_total_dict_s[key]) + + for key in all_lddt_dict: + all_lddt_dict[key] = torch.stack( + all_lddt_dict[key], dim=1 + ) # (multiplicity, K) + all_total_dict[key] = torch.stack(all_total_dict[key], dim=1) + return all_lddt_dict, all_total_dict + + def get_clash_metrics( + self, + batch, + out, + ): + pair_clash_dict, pair_total_dict = compute_chain_clashes( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + ) + + return pair_clash_dict, pair_total_dict + + def get_pb_metrics( + self, + batch, + out, + ): + ( + num_bond_length_failures, + num_bond_angle_failures, + num_internal_clash_failures, + num_ligands, + ) = compute_pb_geometry_metrics( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + ) + ( + num_chiral_atom_violations, + num_chiral_atoms, + num_stereo_bond_violations, + num_stereo_bonds, + ) = compute_stereo_metrics( + pred_atom_coords=out["sample_atom_coords"], feats=batch + ) + + ( + num_aromatic_5_violations, + num_aromatic_5_rings, + num_aromatic_6_violations, + num_aromatic_6_rings, + num_double_bond_violations, + num_double_bonds, + ) = compute_pb_flatness_metrics( + pred_atom_coords=out["sample_atom_coords"], feats=batch + ) + + pb_failure_dict = { + "bond_length": num_bond_length_failures, + "bond_angle": num_bond_angle_failures, + "internal_clash": num_internal_clash_failures, + "atom_chirality": num_chiral_atom_violations, + "bond_stereochemistry": num_stereo_bond_violations, + "ring_5_flatness": num_aromatic_5_violations, + "ring_6_flatness": num_aromatic_6_violations, + "double_bond_flatness": num_double_bond_violations, + } + pb_total_dict = { + "bond_length": num_ligands, + "bond_angle": num_ligands, + "internal_clash": num_ligands, + "atom_chirality": num_chiral_atoms, + "bond_stereochemistry": num_stereo_bonds, + "ring_5_flatness": num_aromatic_5_rings, + "ring_6_flatness": num_aromatic_6_rings, + "double_bond_flatness": num_double_bonds, + } + return pb_failure_dict, pb_total_dict + + def get_confidence_metrics( + self, + model, + batch, + out, + idx_dataset, + n_samples, + true_coords, + true_coords_resolved_mask, + expand_to_diffusion_samples, + ): + K = batch["coords"].shape[1] + # note: for now we don't have pae predictions so have to use pLDDT instead of pTM + # also, while AF3 differentiates the best prediction per confidence type we are currently not doing it + # consider this in the future as well as weighing the different pLLDT types before aggregation + + msg = "Confidence_prediction is not supported for num_ensembles_val > 1" + assert batch["coords"].shape[1] == 1, msg + + mae_plddt_dicts = defaultdict(list) + total_mae_plddt_dicts = defaultdict(list) + mae_pde_dicts = defaultdict(list) + total_mae_pde_dicts = defaultdict(list) + mae_pae_dicts = defaultdict(list) + total_mae_pae_dicts = defaultdict(list) + + # All ensembles have same mask + if not expand_to_diffusion_samples: + true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat( + (n_samples, 1) + ) + + for ensemble_idx in range(K): + if expand_to_diffusion_samples: + true_coords_k = true_coords[:, ensemble_idx] + else: + true_coords_k = ( + true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) + ) + + mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords_k, + pred_lddt=out["plddt"], + true_coords_resolved_mask=true_coords_resolved_mask, + token_level_confidence=model.token_level_confidence, + multiplicity=n_samples, + ) + for key in mae_plddt_dict: + mae_plddt_dicts[key].append(mae_plddt_dict[key]) + total_mae_plddt_dicts[key].append(total_mae_plddt_dict[key]) + + mae_pde_dict, total_mae_pde_dict = compute_pde_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords_k, + pred_pde=out["pde"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + + for key in mae_pde_dict: + mae_pde_dicts[key].append(mae_pde_dict[key]) + total_mae_pde_dicts[key].append(total_mae_pde_dict[key]) + + mae_pae_dict, total_mae_pae_dict = compute_pae_mae( + pred_atom_coords=out["sample_atom_coords"], + feats=batch, + true_atom_coords=true_coords_k, + pred_pae=out["pae"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + + for key in mae_pae_dict: + mae_pae_dicts[key].append(mae_pae_dict[key]) + total_mae_pae_dicts[key].append(total_mae_pae_dict[key]) + + # Take mean over ensembles + for key in mae_plddt_dicts: + mae_plddt_dicts[key] = torch.stack(mae_plddt_dicts[key], dim=0).mean(dim=0) + total_mae_plddt_dicts[key] = torch.stack( + total_mae_plddt_dicts[key], dim=0 + ).mean(dim=0) + + for key in mae_pde_dicts: + mae_pde_dicts[key] = torch.stack(mae_pde_dicts[key], dim=0).mean(dim=0) + total_mae_pde_dicts[key] = torch.stack( + total_mae_pde_dicts[key], dim=0 + ).mean(dim=0) + + for key in mae_pae_dicts: + mae_pae_dicts[key] = torch.stack(mae_pae_dicts[key], dim=0).mean(dim=0) + total_mae_pae_dicts[key] = torch.stack( + total_mae_pae_dicts[key], dim=0 + ).mean(dim=0) + + return ( + mae_plddt_dicts, + total_mae_plddt_dicts, + mae_pde_dicts, + total_mae_pde_dicts, + mae_pae_dicts, + total_mae_pae_dicts, + ) + + def update_confidence_metrics( + self, + batch, + out, + idx_dataset, + n_samples, + all_lddt_dict, + all_total_dict, + mae_plddt_dicts, + total_mae_plddt_dicts, + mae_pde_dicts, + total_mae_pde_dicts, + mae_pae_dicts, + total_mae_pae_dicts, + pair_clash_dict, + pair_total_dict, + pb_failure_dict, + pb_total_dict, + physicalism_metrics, + ): + K = batch["coords"].shape[1] + + for confidence_metric_name in [ + "complex_plddt", + "complex_iplddt", + "complex_pde", + "complex_ipde", + "ptm", + "iptm", + "ligand_iptm", + "protein_iptm", + ]: + confidence_metric_val = out[confidence_metric_name].reshape(-1, n_samples) + top1_idx = confidence_metric_val.argmax(dim=1) + if confidence_metric_name == "complex_plddt": + confidence_metric_prefix = "top1" + elif "complex" in confidence_metric_name: + confidence_metric_prefix = ( + confidence_metric_name.split("_")[1] + "_top1" + ) + else: + confidence_metric_prefix = confidence_metric_name + "_top1" + for key in all_lddt_dict: + if key == "modified": + continue + top1_val = ( + all_lddt_dict[key] + .reshape(n_samples, K)[top1_idx, torch.arange(K)] + .mean(dim=0) + ) + top1_total = ( + all_total_dict[key] + .reshape(n_samples, K)[top1_idx, torch.arange(K)] + .mean(dim=0) + ) + self.confidence_metrics[confidence_metric_prefix + "_lddt"][ + idx_dataset + ][key].update(top1_val, top1_total) + + if physicalism_metrics: + for key in pair_clash_dict: + top1_val = pair_clash_dict[key][top1_idx] + top1_total = pair_total_dict[key][top1_idx] + self.confidence_metrics[confidence_metric_prefix + "_clash"][ + idx_dataset + ][key].update(top1_val, top1_total) + for key in pb_failure_dict: + top1_val = pb_failure_dict[key][top1_idx] + top1_total = pb_total_dict[key][top1_idx] + self.confidence_metrics[confidence_metric_prefix + "_pb"][ + idx_dataset + ][key].update(top1_val, top1_total) + + for key in all_lddt_dict: + if key == "modified": + continue + self.confidence_metrics["avg_lddt"][idx_dataset][key].update( + all_lddt_dict[key], all_total_dict[key] + ) + self.confidence_metrics["pde_mae"][idx_dataset][key].update( + mae_pde_dicts[key], total_mae_pde_dicts[key] + ) + self.confidence_metrics["pae_mae"][idx_dataset][key].update( + mae_pae_dicts[key], total_mae_pae_dicts[key] + ) + for key in mae_plddt_dicts: + self.confidence_metrics["plddt_mae"][idx_dataset][key].update( + mae_plddt_dicts[key], total_mae_plddt_dicts[key] + ) + + if physicalism_metrics: + for key in pair_clash_dict: + self.confidence_metrics["avg_clash"][idx_dataset][key].update( + pair_clash_dict[key], pair_total_dict[key] + ) + for key in pb_failure_dict: + self.confidence_metrics["avg_pb"][idx_dataset][key].update( + pb_failure_dict[key], pb_total_dict[key] + ) + + def update_lddt_rmsd_metrics( + self, + batch, + disto_lddt_dict, + disto_total_dict, + idx_dataset, + return_dict, + ): + # Folding metrics + for m_ in const.out_types: + if m_ == "ligand_protein": + if torch.any( + batch["contact_conditioning"][ + :, :, :, const.contact_conditioning_info["BINDER>POCKET"] + ].bool() + ): + self.folding_metrics["disto_lddt"][idx_dataset][ + "pocket_ligand_protein" + ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + else: + self.folding_metrics["disto_lddt"][idx_dataset][ + "ligand_protein" + ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + + elif m_ == "protein_protein": + if torch.any( + batch["contact_conditioning"][ + :, :, :, const.contact_conditioning_info["CONTACT"] + ].bool() + ): + self.folding_metrics["disto_lddt"][idx_dataset][ + "contact_protein_protein" + ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + else: + self.folding_metrics["disto_lddt"][idx_dataset][ + "protein_protein" + ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + + else: + self.folding_metrics["disto_lddt"][idx_dataset][m_].update( + disto_lddt_dict[m_], disto_total_dict[m_] + ) + + def update_physcialism_metrics( + self, + pair_clash_dict, + pair_total_dict, + pb_failure_dict, + pb_total_dict, + idx_dataset, + ): + for key in pair_clash_dict: + self.physicalism_metrics["clash"][idx_dataset][key].update( + pair_clash_dict[key], pair_total_dict[key] + ) + + for key in pb_failure_dict: + self.physicalism_metrics["pb"][idx_dataset][key].update( + pb_failure_dict[key], pb_total_dict[key] + ) + + def common_val_step( + self, + model: LightningModule, + batch: dict[str, torch.Tensor], + out: dict[str, torch.Tensor], + idx_dataset: int, + expand_to_diffusion_samples: bool = True, + ) -> None: + """Run a common validation step. + + Parameters + ---------- + model : LightningModule + The LightningModule model. + batch : dict[str, torch.Tensor] + The batch input. + out : dict[str, torch.Tensor] + The output of the model. + """ + symmetry_correction = model.val_group_mapper[idx_dataset][ + "symmetry_correction" + ] # global val index + + # Get the local validation index from the global index + idx_dataset = self.get_local_val_index(model, idx_dataset) + + n_samples = model.validation_args.diffusion_samples + + # Compute distogram loss and update metrics + val_disto_loss = self.compute_disto_loss(model, out, batch, idx_dataset) + + # Compute distogram lddt and update metrics + disto_lddt_dict, disto_total_dict = self.compute_disto_lddt( + model, batch, out, idx_dataset + ) + + # Get true coords + return_dict = self.get_true_coords( + model, + batch, + out, + n_samples, + symmetry_correction, + expand_to_diffusion_samples=expand_to_diffusion_samples, + ) + + # Move this and do better as to when to interleave + true_coords = return_dict[ + "true_coords" + ] # (multiplicity, K, L, 3) if expand_to_diffusion_samples else (K, L, 3) + true_coords_resolved_mask = return_dict[ + "true_coords_resolved_mask" + ] # (multiplicity, L) if expand_to_diffusion_samples else (L) + # rmsds = return_dict["rmsds"] + + # Get lddt metrics + all_lddt_dict, all_total_dict = self.get_lddt_metrics( + model, + batch, + out, + idx_dataset, + n_samples, + true_coords_resolved_mask, + true_coords, + expand_to_diffusion_samples, + ) + + # Get physical realism metrics + if self.physicalism_metrics: + pair_clash_dict, pair_total_dict = self.get_clash_metrics( + batch, + out, + ) + pb_failure_dict, pb_total_dict = self.get_pb_metrics( + batch, + out, + ) + else: + pair_clash_dict, pair_total_dict = None, None + pb_failure_dict, pb_total_dict = None, None + + # Filtering based on confidence + if model.confidence_prediction and n_samples > 1: + ( + mae_plddt_dicts, + total_mae_plddt_dicts, + mae_pde_dicts, + total_mae_pde_dicts, + mae_pae_dicts, + total_mae_pae_dicts, + ) = self.get_confidence_metrics( + model, + batch, + out, + idx_dataset, + n_samples, + true_coords, + true_coords_resolved_mask, + expand_to_diffusion_samples, + ) + + # Update distogram loss + self.folding_metrics["disto_loss"][idx_dataset]["disto_loss"].update( + val_disto_loss + ) + + # Update folding metrics + self.update_lddt_rmsd_metrics( + batch, + disto_lddt_dict, + disto_total_dict, + idx_dataset, + return_dict, + ) + + # Update physcial realism metrics + if self.physicalism_metrics: + self.update_physcialism_metrics( + pair_clash_dict, + pair_total_dict, + pb_failure_dict, + pb_total_dict, + idx_dataset, + ) + + # Update confidence metrics + if model.confidence_prediction and n_samples > 1: + self.update_confidence_metrics( + batch, + out, + idx_dataset, + n_samples, + all_lddt_dict, + all_total_dict, + mae_plddt_dicts, + total_mae_plddt_dicts, + mae_pde_dicts, + total_mae_pde_dicts, + mae_pae_dicts, + total_mae_pae_dicts, + pair_clash_dict, + pair_total_dict, + pb_failure_dict, + pb_total_dict, + physicalism_metrics=self.physicalism_metrics, + ) + + def on_epoch_end(self, model: LightningModule): + raise NotImplementedError + + def common_on_epoch_end(self, model: LightningModule): + avg_lddt = [{} for _ in range(self.num_val_datasets)] + avg_disto_lddt = [{} for _ in range(self.num_val_datasets)] + avg_complex_lddt = [{} for _ in range(self.num_val_datasets)] + avg_clash = [{} for _ in range(self.num_val_datasets)] + avg_pb = [{} for _ in range(self.num_val_datasets)] + + if model.confidence_prediction: + avg_mae_plddt = [{} for _ in range(self.num_val_datasets)] + avg_avg_clash = [{} for _ in range(self.num_val_datasets)] + avg_avg_pb = [{} for _ in range(self.num_val_datasets)] + + avg_top1_clash = [{} for _ in range(self.num_val_datasets)] + avg_iplddt_top1_clash = [{} for _ in range(self.num_val_datasets)] + avg_pde_top1_clash = [{} for _ in range(self.num_val_datasets)] + avg_ipde_top1_clash = [{} for _ in range(self.num_val_datasets)] + avg_ptm_top1_clash = [{} for _ in range(self.num_val_datasets)] + avg_iptm_top1_clash = [{} for _ in range(self.num_val_datasets)] + avg_ligand_iptm_top1_clash = [{} for _ in range(self.num_val_datasets)] + avg_protein_iptm_top1_clash = [{} for _ in range(self.num_val_datasets)] + + avg_top1_pb = [{} for _ in range(self.num_val_datasets)] + avg_iplddt_top1_pb = [{} for _ in range(self.num_val_datasets)] + avg_pde_top1_pb = [{} for _ in range(self.num_val_datasets)] + avg_ipde_top1_pb = [{} for _ in range(self.num_val_datasets)] + avg_ptm_top1_pb = [{} for _ in range(self.num_val_datasets)] + avg_iptm_top1_pb = [{} for _ in range(self.num_val_datasets)] + avg_ligand_iptm_top1_pb = [{} for _ in range(self.num_val_datasets)] + avg_protein_iptm_top1_pb = [{} for _ in range(self.num_val_datasets)] + + for idx_dataset in range(self.num_val_datasets): # local idx_dataset + dataset_name_ori = self.val_names[ + idx_dataset + ] # self.val_group_mapper[idx_dataset]["label"] + + # TODO this is harcodeded for now to compare with Boltz-1 metrics + dataset_name = "" if dataset_name_ori == "RCSB" else f"__{dataset_name_ori}" + + for m_ in [ + *const.out_types, + "pocket_ligand_protein", + "contact_protein_protein", + ]: + avg_disto_lddt[idx_dataset][m_] = self.folding_metrics["disto_lddt"][ + idx_dataset + ][m_].compute() + + avg_disto_lddt[idx_dataset][m_] = ( + 0.0 + if torch.isnan(avg_disto_lddt[idx_dataset][m_]) + else avg_disto_lddt[idx_dataset][m_].item() + ) + self.folding_metrics["disto_lddt"][idx_dataset][m_].reset() + model.log( + f"val/disto_lddt_{m_}{dataset_name}", + avg_disto_lddt[idx_dataset][m_], + ) + + for m in const.out_single_types: + if model.confidence_prediction: + avg_mae_plddt[idx_dataset][m] = ( + self.confidence_metrics["plddt_mae"][idx_dataset][m] + .compute() + .item() + ) + self.confidence_metrics["plddt_mae"][idx_dataset][m].reset() + model.log( + f"val/MAE_plddt_{m}{dataset_name}", + avg_mae_plddt[idx_dataset][m], + ) + + overall_disto_lddt = sum( + avg_disto_lddt[idx_dataset][m] * w + for (m, w) in const.out_types_weights.items() + ) / sum(const.out_types_weights.values()) + model.log( + f"val/disto_lddt{dataset_name}", + overall_disto_lddt, + ) + + # Distogram loss + r = self.folding_metrics["disto_loss"][idx_dataset]["disto_loss"].compute() + model.log(f"val/disto_loss{dataset_name}", r) + self.folding_metrics["disto_loss"][idx_dataset]["disto_loss"].reset() + + # Physical realism metrics + if self.physicalism_metrics: + for m in ["asym_" + m_ for m_ in const.clash_types] + [ + "sym_" + m_ for m_ in const.out_single_types + ]: + avg_clash[idx_dataset][m] = self.physicalism_metrics["clash"][ + idx_dataset + ][m].compute() + avg_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_clash[idx_dataset][m]) + else avg_clash[idx_dataset][m].item() + ) + self.physicalism_metrics["clash"][idx_dataset][m].reset() + model.log( + f"val/clash_{m}{dataset_name}", + avg_clash[idx_dataset][m], + ) + + if model.confidence_prediction: + avg_top1_clash[idx_dataset][m] = self.confidence_metrics[ + "top1_clash" + ][idx_dataset][m].compute() + avg_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_top1_clash[idx_dataset][m]) + else avg_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["top1_clash"][idx_dataset][m].reset() + model.log( + f"val/top1_clash_{m}{dataset_name}", + avg_top1_clash[idx_dataset][m], + ) + + avg_iplddt_top1_clash[idx_dataset][m] = self.confidence_metrics[ + "iplddt_top1_clash" + ][idx_dataset][m].compute() + avg_iplddt_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_iplddt_top1_clash[idx_dataset][m]) + else avg_iplddt_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["iplddt_top1_clash"][idx_dataset][ + m + ].reset() + model.log( + f"val/iplddt_top1_clash_{m}{dataset_name}", + avg_iplddt_top1_clash[idx_dataset][m], + ) + + avg_pde_top1_clash[idx_dataset][m] = self.confidence_metrics[ + "pde_top1_clash" + ][idx_dataset][m].compute() + avg_pde_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_pde_top1_clash[idx_dataset][m]) + else avg_pde_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["pde_top1_clash"][idx_dataset][ + m + ].reset() + model.log( + f"val/pde_top1_clash_{m}{dataset_name}", + avg_pde_top1_clash[idx_dataset][m], + ) + + avg_ipde_top1_clash[idx_dataset][m] = self.confidence_metrics[ + "ipde_top1_clash" + ][idx_dataset][m].compute() + avg_ipde_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_ipde_top1_clash[idx_dataset][m]) + else avg_ipde_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["ipde_top1_clash"][idx_dataset][ + m + ].reset() + model.log( + f"val/ipde_top1_clash_{m}{dataset_name}", + avg_ipde_top1_clash[idx_dataset][m], + ) + + avg_ptm_top1_clash[idx_dataset][m] = self.confidence_metrics[ + "ptm_top1_clash" + ][idx_dataset][m].compute() + avg_ptm_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_ptm_top1_clash[idx_dataset][m]) + else avg_ptm_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["ptm_top1_clash"][idx_dataset][ + m + ].reset() + model.log( + f"val/ptm_top1_clash_{m}{dataset_name}", + avg_ptm_top1_clash[idx_dataset][m], + ) + + avg_iptm_top1_clash[idx_dataset][m] = self.confidence_metrics[ + "iptm_top1_clash" + ][idx_dataset][m].compute() + avg_iptm_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_iptm_top1_clash[idx_dataset][m]) + else avg_iptm_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["iptm_top1_clash"][idx_dataset][ + m + ].reset() + model.log( + f"val/iptm_top1_clash_{m}{dataset_name}", + avg_iptm_top1_clash[idx_dataset][m], + ) + + avg_ligand_iptm_top1_clash[idx_dataset][m] = ( + self.confidence_metrics["ligand_iptm_top1_clash"][ + idx_dataset + ][m].compute() + ) + avg_ligand_iptm_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_ligand_iptm_top1_clash[idx_dataset][m]) + else avg_ligand_iptm_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["ligand_iptm_top1_clash"][idx_dataset][ + m + ].reset() + model.log( + f"val/ligand_iptm_top1_clash_{m}{dataset_name}", + avg_ligand_iptm_top1_clash[idx_dataset][m], + ) + + avg_protein_iptm_top1_clash[idx_dataset][m] = ( + self.confidence_metrics["protein_iptm_top1_clash"][ + idx_dataset + ][m].compute() + ) + avg_protein_iptm_top1_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_protein_iptm_top1_clash[idx_dataset][m]) + else avg_protein_iptm_top1_clash[idx_dataset][m].item() + ) + self.confidence_metrics["protein_iptm_top1_clash"][idx_dataset][ + m + ].reset() + model.log( + f"val/protein_iptm_top1_clash_{m}{dataset_name}", + avg_protein_iptm_top1_clash[idx_dataset][m], + ) + + avg_avg_clash[idx_dataset][m] = self.confidence_metrics[ + "avg_clash" + ][idx_dataset][m].compute() + avg_avg_clash[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_avg_clash[idx_dataset][m]) + else avg_avg_clash[idx_dataset][m].item() + ) + self.confidence_metrics["avg_clash"][idx_dataset][m].reset() + model.log( + f"val/avg_clash_{m}{dataset_name}", + avg_avg_clash[idx_dataset][m], + ) + + for m in [ + "bond_length", + "bond_angle", + "internal_clash", + "atom_chirality", + "bond_stereochemistry", + "ring_5_flatness", + "ring_6_flatness", + "double_bond_flatness", + ]: + avg_pb[idx_dataset][m] = self.physicalism_metrics["pb"][ + idx_dataset + ][m].compute() + avg_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_pb[idx_dataset][m]) + else avg_pb[idx_dataset][m].item() + ) + self.physicalism_metrics["pb"][idx_dataset][m].reset() + model.log( + f"val/pb_{m}{dataset_name}", + avg_pb[idx_dataset][m], + ) + + if model.confidence_prediction: + avg_top1_pb[idx_dataset][m] = self.confidence_metrics[ + "top1_pb" + ][idx_dataset][m].compute() + avg_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_top1_pb[idx_dataset][m]) + else avg_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["top1_pb"][idx_dataset][m].reset() + model.log( + f"val/top1_pb_{m}{dataset_name}", + avg_top1_pb[idx_dataset][m], + ) + + avg_iplddt_top1_pb[idx_dataset][m] = self.confidence_metrics[ + "iplddt_top1_pb" + ][idx_dataset][m].compute() + avg_iplddt_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_iplddt_top1_pb[idx_dataset][m]) + else avg_iplddt_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["iplddt_top1_pb"][idx_dataset][ + m + ].reset() + model.log( + f"val/iplddt_top1_pb_{m}{dataset_name}", + avg_iplddt_top1_pb[idx_dataset][m], + ) + + avg_pde_top1_pb[idx_dataset][m] = self.confidence_metrics[ + "pde_top1_pb" + ][idx_dataset][m].compute() + avg_pde_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_pde_top1_pb[idx_dataset][m]) + else avg_pde_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["pde_top1_pb"][idx_dataset][m].reset() + model.log( + f"val/pde_top1_pb_{m}{dataset_name}", + avg_pde_top1_pb[idx_dataset][m], + ) + + avg_ipde_top1_pb[idx_dataset][m] = self.confidence_metrics[ + "ipde_top1_pb" + ][idx_dataset][m].compute() + avg_ipde_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_ipde_top1_pb[idx_dataset][m]) + else avg_ipde_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["ipde_top1_pb"][idx_dataset][m].reset() + model.log( + f"val/ipde_top1_pb_{m}{dataset_name}", + avg_ipde_top1_pb[idx_dataset][m], + ) + + avg_ptm_top1_pb[idx_dataset][m] = self.confidence_metrics[ + "ptm_top1_pb" + ][idx_dataset][m].compute() + avg_ptm_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_ptm_top1_pb[idx_dataset][m]) + else avg_ptm_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["ptm_top1_pb"][idx_dataset][m].reset() + model.log( + f"val/ptm_top1_pb_{m}{dataset_name}", + avg_ptm_top1_pb[idx_dataset][m], + ) + + avg_iptm_top1_pb[idx_dataset][m] = self.confidence_metrics[ + "iptm_top1_pb" + ][idx_dataset][m].compute() + avg_iptm_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_iptm_top1_pb[idx_dataset][m]) + else avg_iptm_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["iptm_top1_pb"][idx_dataset][m].reset() + model.log( + f"val/iptm_top1_pb_{m}{dataset_name}", + avg_iptm_top1_pb[idx_dataset][m], + ) + + avg_ligand_iptm_top1_pb[idx_dataset][m] = ( + self.confidence_metrics["ligand_iptm_top1_pb"][idx_dataset][ + m + ].compute() + ) + avg_ligand_iptm_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_ligand_iptm_top1_pb[idx_dataset][m]) + else avg_ligand_iptm_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["ligand_iptm_top1_pb"][idx_dataset][ + m + ].reset() + model.log( + f"val/ligand_iptm_top1_pb_{m}{dataset_name}", + avg_ligand_iptm_top1_pb[idx_dataset][m], + ) + + avg_protein_iptm_top1_pb[idx_dataset][m] = ( + self.confidence_metrics["protein_iptm_top1_pb"][ + idx_dataset + ][m].compute() + ) + avg_protein_iptm_top1_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_protein_iptm_top1_pb[idx_dataset][m]) + else avg_protein_iptm_top1_pb[idx_dataset][m].item() + ) + self.confidence_metrics["protein_iptm_top1_pb"][idx_dataset][ + m + ].reset() + model.log( + f"val/protein_iptm_top1_pb_{m}{dataset_name}", + avg_protein_iptm_top1_pb[idx_dataset][m], + ) + + avg_avg_pb[idx_dataset][m] = self.confidence_metrics["avg_pb"][ + idx_dataset + ][m].compute() + avg_avg_pb[idx_dataset][m] = ( + 0.0 + if torch.isnan(avg_avg_pb[idx_dataset][m]) + else avg_avg_pb[idx_dataset][m].item() + ) + self.confidence_metrics["avg_pb"][idx_dataset][m].reset() + model.log( + f"val/avg_pb_{m}{dataset_name}", + avg_avg_pb[idx_dataset][m], + ) From 3d643e390405a1f4b43f42febc1ad8589e5cc379 Mon Sep 17 00:00:00 2001 From: Dejun Lin Date: Sat, 28 Mar 2026 23:37:46 -0700 Subject: [PATCH 2/2] Fold-CP: A Context Parallelism Framework for Biomolecular Modeling Context parallelism (CP) for distributed inference and training for biomolecular folding models across multiple GPUs using a 2D CP mesh combined with data parallelism, demonstrated with the Boltz model. --- .gitignore | 5 +- .pre-commit-config.yaml | 31 + CHANGELOG.md | 13 + CONTRIBUTING.md | 1 + Dockerfile | 60 + LICENSE | 22 +- README.md | 12 + SECURITY.md | 24 + docs/boltz2_cp_prediction.md | 285 ++ docs/boltz2_cp_training.md | 394 ++ licenses/LICENSE | 20 + licenses/third-party-attr.txt | 31 + pre-commit/.secrets.baseline | 132 + pre-commit/license_check.py | 492 +++ pre-commit/license_header | 20 + pyproject.toml | 43 +- scripts/eval/run_evals.py | 121 +- scripts/process/cluster.py | 34 +- scripts/train/configs/structurev2.yaml | 14 +- scripts/train/configs/structurev2_cp.yaml | 69 + scripts/train/configs/structurev2_small.yaml | 33 + .../train/configs/structurev2_small_cp.yaml | 30 + scripts/train/train.py | 53 +- src/boltz/data/crop/boltz.py | 32 +- src/boltz/data/feature/featurizerv2.py | 83 +- src/boltz/data/feature/featurizerv2_train.py | 112 +- src/boltz/data/module/training.py | 95 +- src/boltz/data/module/trainingv2.py | 129 +- src/boltz/data/parse/schema.py | 258 +- src/boltz/data/tokenize/boltz.py | 20 +- src/boltz/data/types.py | 74 +- src/boltz/data/write/writer.py | 61 +- src/boltz/distributed/README.md | 92 + src/boltz/distributed/__init__.py | 20 + src/boltz/distributed/comm.py | 782 ++++ .../distributed/data/feature/featurizer.py | 627 +++ .../data/feature/featurizer_utils.py | 207 + .../distributed/data/feature/symmetry.py | 369 ++ .../distributed/data/module/inferencev2.py | 441 ++ .../distributed/data/module/placements.py | 90 + .../distributed/data/module/trainingv2.py | 519 +++ src/boltz/distributed/data/types.py | 31 + src/boltz/distributed/data/utils.py | 646 +++ src/boltz/distributed/lightning_strategy.py | 248 ++ src/boltz/distributed/main.py | 343 ++ src/boltz/distributed/manager.py | 900 ++++ src/boltz/distributed/model/__init__.py | 20 + .../distributed/model/layers/__init__.py | 20 + .../distributed/model/layers/atom_to_token.py | 653 +++ .../distributed/model/layers/attention.py | 658 +++ .../model/layers/attention_impl.py | 1510 +++++++ .../distributed/model/layers/cat_and_chunk.py | 417 ++ src/boltz/distributed/model/layers/clip.py | 228 + .../model/layers/distribute_module_tools.py | 103 + src/boltz/distributed/model/layers/dropout.py | 289 ++ .../model/layers/dtensor_metadata_tools.py | 174 + .../model/layers/elementwise_op.py | 1024 +++++ .../distributed/model/layers/embedding.py | 273 ++ .../model/layers/flatten_and_unflatten.py | 864 ++++ src/boltz/distributed/model/layers/gather.py | 490 +++ .../distributed/model/layers/layernorm.py | 577 +++ src/boltz/distributed/model/layers/linear.py | 500 +++ .../distributed/model/layers/outer_gather.py | 1309 ++++++ .../distributed/model/layers/outer_op.py | 704 ++++ .../model/layers/outer_product_mean.py | 597 +++ .../model/layers/pair_averaging.py | 726 ++++ .../distributed/model/layers/pairformer.py | 350 ++ .../model/layers/redistribute_transpose.py | 324 ++ .../redistribute_transpose_without_dtensor.py | 133 + .../model/layers/repeat_interleave.py | 193 + .../distributed/model/layers/replicate_op.py | 234 ++ src/boltz/distributed/model/layers/scatter.py | 706 ++++ .../distributed/model/layers/sharded_op.py | 228 + .../distributed/model/layers/shardwise_op.py | 1368 ++++++ .../distributed/model/layers/sigmoid_gate.py | 245 ++ src/boltz/distributed/model/layers/squeeze.py | 414 ++ src/boltz/distributed/model/layers/swiglu.py | 93 + .../distributed/model/layers/transition.py | 87 + .../model/layers/triangular_attention.py | 1658 ++++++++ .../model/layers/triangular_mult.py | 760 ++++ src/boltz/distributed/model/layers/utils.py | 2276 ++++++++++ src/boltz/distributed/model/layers/where.py | 274 ++ src/boltz/distributed/model/loss/__init__.py | 20 + src/boltz/distributed/model/loss/bfactor.py | 305 ++ .../distributed/model/loss/confidencev2.py | 3383 +++++++++++++++ src/boltz/distributed/model/loss/diffusion.py | 1063 +++++ src/boltz/distributed/model/loss/distogram.py | 457 ++ .../distributed/model/loss/triton/__init__.py | 20 + .../model/loss/triton/cdist_lddt.py | 784 ++++ .../model/loss/triton/cdist_pde.py | 1032 +++++ .../model/loss/triton/smooth_lddt_loss.py | 403 ++ .../distributed/model/loss/validation.py | 980 +++++ .../distributed/model/models/__init__.py | 20 + src/boltz/distributed/model/models/boltz2.py | 1394 ++++++ .../distributed/model/modules/__init__.py | 22 + .../model/modules/confidence_utils.py | 942 +++++ .../distributed/model/modules/confidencev2.py | 886 ++++ .../distributed/model/modules/diffusion.py | 1314 ++++++ .../model/modules/diffusion_conditioning.py | 197 + .../distributed/model/modules/encoders.py | 1491 +++++++ .../distributed/model/modules/transformers.py | 864 ++++ .../distributed/model/modules/trunkv2.py | 773 ++++ src/boltz/distributed/model/modules/utils.py | 943 +++++ src/boltz/distributed/model/optim/__init__.py | 20 + src/boltz/distributed/model/optim/ema.py | 266 ++ .../distributed/model/validation/__init__.py | 20 + .../distributed/model/validation/rcsb.py | 135 + .../distributed/model/validation/utils.py | 53 + .../distributed/model/validation/validator.py | 754 ++++ src/boltz/distributed/predict.py | 500 +++ src/boltz/distributed/testing/utils.py | 109 + src/boltz/distributed/train.py | 608 +++ src/boltz/distributed/utils.py | 1161 +++++ src/boltz/main.py | 157 +- src/boltz/model/layers/attentionv2.py | 32 +- src/boltz/model/layers/confidence_utils.py | 58 +- src/boltz/model/layers/outer_product_mean.py | 39 +- src/boltz/model/layers/pairformer.py | 67 +- .../layers/triangular_attention/primitives.py | 71 +- src/boltz/model/layers/triangular_mult.py | 29 +- src/boltz/model/loss/bfactor.py | 31 +- src/boltz/model/loss/confidence.py | 84 +- src/boltz/model/loss/confidencev2.py | 63 +- src/boltz/model/loss/diffusion.py | 57 +- src/boltz/model/loss/diffusionv2.py | 43 +- src/boltz/model/loss/distogramv2.py | 26 +- src/boltz/model/loss/inference.py | 32 +- src/boltz/model/loss/validation.py | 47 +- src/boltz/model/models/boltz2.py | 70 +- src/boltz/model/modules/confidencev2.py | 54 +- src/boltz/model/modules/diffusion.py | 87 +- src/boltz/model/modules/diffusionv2.py | 134 +- src/boltz/model/modules/encoders.py | 46 +- src/boltz/model/modules/encodersv2.py | 164 +- src/boltz/model/modules/transformers.py | 33 +- src/boltz/model/modules/transformersv2.py | 45 +- src/boltz/model/modules/trunkv2.py | 45 +- src/boltz/model/validation/rcsb.py | 26 + src/boltz/model/validation/validator.py | 564 +-- src/boltz/testing/__init__.py | 20 + src/boltz/testing/utils.py | 3730 +++++++++++++++++ src/boltz/workflow/__init__.py | 20 + src/boltz/workflow/utils.py | 208 + tests/conftest.py | 591 +++ tests/data/feature/test_featurizerv2.py | 480 +++ tests/data/feature/test_featurizerv2_train.py | 525 +++ tests/data/write/test_writer.py | 330 ++ tests/distributed/__init__.py | 20 + ...st_dtensor_minimum_lddt_symmetry_coords.py | 355 ++ ...test_dtensor_pack_and_pad_atom_features.py | 323 ++ .../data/test_dtensor_scatter_features.py | 175 + tests/distributed/dtensor_train_harness.py | 365 ++ tests/distributed/model/__init__.py | 20 + ...on_with_dtensor_for_pairformer_use_case.py | 548 +++ .../layers/test_dtensor_atom_to_token.py | 816 ++++ .../model/layers/test_dtensor_attention.py | 908 ++++ .../model/layers/test_dtensor_cat.py | 370 ++ .../model/layers/test_dtensor_chunk.py | 405 ++ .../model/layers/test_dtensor_clip.py | 219 + .../model/layers/test_dtensor_dropout.py | 267 ++ .../layers/test_dtensor_elementwise_op.py | 709 ++++ .../model/layers/test_dtensor_embedding.py | 224 + .../model/layers/test_dtensor_flatten.py | 643 +++ .../model/layers/test_dtensor_gather.py | 187 + .../test_dtensor_layernorm_nocastbf16.py | 369 ++ .../model/layers/test_dtensor_outer_gather.py | 676 +++ .../model/layers/test_dtensor_outer_op.py | 508 +++ .../layers/test_dtensor_outer_product_mean.py | 377 ++ .../test_dtensor_pair_weighted_averaging.py | 429 ++ .../layers/test_dtensor_pairformer_layer.py | 504 +++ .../layers/test_dtensor_pairformer_module.py | 755 ++++ .../test_dtensor_pairformer_no_seq_layer.py | 360 ++ .../test_dtensor_pairformer_no_seq_module.py | 555 +++ .../test_dtensor_redistribute_transpose.py | 494 +++ .../layers/test_dtensor_repeat_interleave.py | 320 ++ .../model/layers/test_dtensor_scatter.py | 265 ++ .../model/layers/test_dtensor_sharded_op.py | 205 + .../model/layers/test_dtensor_shardwise_op.py | 1687 ++++++++ .../model/layers/test_dtensor_sigmoid_gate.py | 198 + .../model/layers/test_dtensor_squeeze.py | 611 +++ .../model/layers/test_dtensor_swiglu.py | 404 ++ .../model/layers/test_dtensor_transition.py | 343 ++ .../layers/test_dtensor_triangle_attention.py | 913 ++++ .../layers/test_dtensor_triangular_mult.py | 429 ++ .../model/layers/test_dtensor_unflatten.py | 742 ++++ .../model/layers/test_dtensor_where.py | 247 ++ .../layers/test_dtensor_window_batch_utils.py | 1021 +++++ .../layers/test_redistribute_transpose.py | 199 + .../model/layers/test_window_batch_utils.py | 551 +++ .../model/layers/test_window_ownership.py | 361 ++ tests/distributed/model/loss/__init__.py | 20 + .../loss/benchmark_smooth_lddt_loss_triton.py | 308 ++ .../model/loss/test_cdist_lddt_triton.py | 430 ++ .../model/loss/test_cdist_pde_triton.py | 378 ++ .../loss/test_compute_plddt_mae_triton.py | 134 + .../model/loss/test_dtensor_bfactor.py | 281 ++ .../loss/test_dtensor_confidence_loss.py | 1039 +++++ .../loss/test_dtensor_confidence_pde_loss.py | 348 ++ .../test_dtensor_confidence_plddt_loss.py | 928 ++++ .../test_dtensor_confidence_resolved_loss.py | 545 +++ .../model/loss/test_dtensor_distogram.py | 371 ++ .../loss/test_dtensor_get_true_coordinates.py | 397 ++ .../model/loss/test_dtensor_pae_loss.py | 412 ++ .../loss/test_dtensor_smooth_lddt_loss.py | 222 + ...st_dtensor_weighted_minimum_rmsd_single.py | 235 ++ .../loss/test_dtensor_weighted_rigid_align.py | 318 ++ .../model/loss/test_get_lddt_metrics.py | 453 ++ .../loss/test_smooth_lddt_loss_triton.py | 603 +++ tests/distributed/model/models/__init__.py | 20 + .../model/models/test_dtensor_boltz2.py | 3389 +++++++++++++++ tests/distributed/model/modules/__init__.py | 22 + .../model/modules/test_dtensor_adaln.py | 215 + .../test_dtensor_atom_attn_decoder_wb.py | 736 ++++ .../test_dtensor_atom_attn_encoder_wb.py | 1040 +++++ .../modules/test_dtensor_atom_encoder_wb.py | 745 ++++ .../modules/test_dtensor_atom_transformer.py | 670 +++ ...st_dtensor_conditioned_transition_block.py | 214 + .../modules/test_dtensor_confidence_utils.py | 731 ++++ .../modules/test_dtensor_confidencev2.py | 1132 +++++ .../model/modules/test_dtensor_diffusion.py | 2368 +++++++++++ .../test_dtensor_diffusion_conditioning.py | 551 +++ .../modules/test_dtensor_diffusion_module.py | 611 +++ ...est_dtensor_diffusion_transformer_layer.py | 576 +++ .../model/modules/test_dtensor_encoders.py | 425 ++ .../modules/test_dtensor_fourier_embedding.py | 188 + .../modules/test_dtensor_input_embedder_wb.py | 749 ++++ .../model/modules/test_dtensor_msa_layer.py | 602 +++ .../model/modules/test_dtensor_msa_module.py | 729 ++++ .../test_dtensor_pairwise_conditioning.py | 254 ++ .../test_dtensor_single_conditioning.py | 347 ++ .../model/modules/test_dtensor_trunkv2.py | 913 ++++ .../model/modules/test_dtensor_utils.py | 747 ++++ ...tensor_utils_center_random_augmentation.py | 206 + tests/distributed/model/optim/__init__.py | 20 + .../model/optim/test_dtensor_ema.py | 333 ++ .../test_dtensor_get_clash_metrics.py | 297 ++ .../validation/test_dtensor_get_pb_metrics.py | 267 ++ .../validation/test_dtensor_rcsb_validator.py | 900 ++++ .../distributed/test_dtensor_boltz2_train.py | 2798 +++++++++++++ .../test_dtensor_cp_dataloader_v2.py | 1188 ++++++ tests/distributed/test_dtensor_layernorm.py | 351 ++ tests/distributed/test_dtensor_linear.py | 429 ++ .../test_dtensor_metadata_tools.py | 304 ++ ...nsor_parallel_assert_factored_lddt_loss.py | 241 ++ tests/distributed/test_dtensor_predict.py | 1290 ++++++ tests/distributed/test_dtensor_stop_and_go.py | 1061 +++++ tests/distributed/test_dtensor_train_utils.py | 224 + tests/distributed/test_layoutmap.py | 310 ++ tests/distributed/test_lightning_strategy.py | 182 + tests/distributed/test_manager.py | 258 ++ .../test_tiled_softmax_attn_update.py | 311 ++ tests/distributed/test_utils.py | 735 ++++ tests/model/layers/test_outer_product_mean.py | 47 +- tests/model/layers/test_triangle_attention.py | 43 +- tests/model/layers/test_triattn_kernel.py | 759 ++++ tests/model/loss/__init__.py | 20 + .../model/loss/test_cdist_lddt_validation.py | 132 + tests/model/loss/test_distogramv2.py | 280 ++ .../test_factored_token_lddt_dist_loss.py | 203 + tests/model/validation/test_validator.py | 467 +++ tests/scripts/__init__.py | 20 + tests/scripts/test_cluster.py | 172 + tests/scripts/test_run_evals.py | 175 + tests/workflow/test_workflow_utils.py | 708 ++++ 264 files changed, 116580 insertions(+), 1148 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 CHANGELOG.md create mode 100644 CONTRIBUTING.md create mode 100644 Dockerfile mode change 100644 => 120000 LICENSE create mode 100644 SECURITY.md create mode 100644 docs/boltz2_cp_prediction.md create mode 100644 docs/boltz2_cp_training.md create mode 100644 licenses/LICENSE create mode 100644 licenses/third-party-attr.txt create mode 100644 pre-commit/.secrets.baseline create mode 100644 pre-commit/license_check.py create mode 100644 pre-commit/license_header create mode 100644 scripts/train/configs/structurev2_cp.yaml create mode 100644 scripts/train/configs/structurev2_small.yaml create mode 100644 scripts/train/configs/structurev2_small_cp.yaml create mode 100644 src/boltz/distributed/README.md create mode 100644 src/boltz/distributed/__init__.py create mode 100644 src/boltz/distributed/comm.py create mode 100644 src/boltz/distributed/data/feature/featurizer.py create mode 100644 src/boltz/distributed/data/feature/featurizer_utils.py create mode 100644 src/boltz/distributed/data/feature/symmetry.py create mode 100644 src/boltz/distributed/data/module/inferencev2.py create mode 100644 src/boltz/distributed/data/module/placements.py create mode 100644 src/boltz/distributed/data/module/trainingv2.py create mode 100644 src/boltz/distributed/data/types.py create mode 100644 src/boltz/distributed/data/utils.py create mode 100644 src/boltz/distributed/lightning_strategy.py create mode 100644 src/boltz/distributed/main.py create mode 100644 src/boltz/distributed/manager.py create mode 100644 src/boltz/distributed/model/__init__.py create mode 100644 src/boltz/distributed/model/layers/__init__.py create mode 100644 src/boltz/distributed/model/layers/atom_to_token.py create mode 100644 src/boltz/distributed/model/layers/attention.py create mode 100644 src/boltz/distributed/model/layers/attention_impl.py create mode 100644 src/boltz/distributed/model/layers/cat_and_chunk.py create mode 100644 src/boltz/distributed/model/layers/clip.py create mode 100644 src/boltz/distributed/model/layers/distribute_module_tools.py create mode 100644 src/boltz/distributed/model/layers/dropout.py create mode 100644 src/boltz/distributed/model/layers/dtensor_metadata_tools.py create mode 100644 src/boltz/distributed/model/layers/elementwise_op.py create mode 100644 src/boltz/distributed/model/layers/embedding.py create mode 100644 src/boltz/distributed/model/layers/flatten_and_unflatten.py create mode 100644 src/boltz/distributed/model/layers/gather.py create mode 100644 src/boltz/distributed/model/layers/layernorm.py create mode 100644 src/boltz/distributed/model/layers/linear.py create mode 100644 src/boltz/distributed/model/layers/outer_gather.py create mode 100644 src/boltz/distributed/model/layers/outer_op.py create mode 100644 src/boltz/distributed/model/layers/outer_product_mean.py create mode 100644 src/boltz/distributed/model/layers/pair_averaging.py create mode 100644 src/boltz/distributed/model/layers/pairformer.py create mode 100755 src/boltz/distributed/model/layers/redistribute_transpose.py create mode 100644 src/boltz/distributed/model/layers/redistribute_transpose_without_dtensor.py create mode 100644 src/boltz/distributed/model/layers/repeat_interleave.py create mode 100644 src/boltz/distributed/model/layers/replicate_op.py create mode 100644 src/boltz/distributed/model/layers/scatter.py create mode 100644 src/boltz/distributed/model/layers/sharded_op.py create mode 100644 src/boltz/distributed/model/layers/shardwise_op.py create mode 100644 src/boltz/distributed/model/layers/sigmoid_gate.py create mode 100644 src/boltz/distributed/model/layers/squeeze.py create mode 100755 src/boltz/distributed/model/layers/swiglu.py create mode 100644 src/boltz/distributed/model/layers/transition.py create mode 100644 src/boltz/distributed/model/layers/triangular_attention.py create mode 100644 src/boltz/distributed/model/layers/triangular_mult.py create mode 100644 src/boltz/distributed/model/layers/utils.py create mode 100644 src/boltz/distributed/model/layers/where.py create mode 100644 src/boltz/distributed/model/loss/__init__.py create mode 100644 src/boltz/distributed/model/loss/bfactor.py create mode 100644 src/boltz/distributed/model/loss/confidencev2.py create mode 100644 src/boltz/distributed/model/loss/diffusion.py create mode 100644 src/boltz/distributed/model/loss/distogram.py create mode 100644 src/boltz/distributed/model/loss/triton/__init__.py create mode 100644 src/boltz/distributed/model/loss/triton/cdist_lddt.py create mode 100644 src/boltz/distributed/model/loss/triton/cdist_pde.py create mode 100644 src/boltz/distributed/model/loss/triton/smooth_lddt_loss.py create mode 100644 src/boltz/distributed/model/loss/validation.py create mode 100644 src/boltz/distributed/model/models/__init__.py create mode 100644 src/boltz/distributed/model/models/boltz2.py create mode 100644 src/boltz/distributed/model/modules/__init__.py create mode 100644 src/boltz/distributed/model/modules/confidence_utils.py create mode 100644 src/boltz/distributed/model/modules/confidencev2.py create mode 100644 src/boltz/distributed/model/modules/diffusion.py create mode 100644 src/boltz/distributed/model/modules/diffusion_conditioning.py create mode 100644 src/boltz/distributed/model/modules/encoders.py create mode 100644 src/boltz/distributed/model/modules/transformers.py create mode 100644 src/boltz/distributed/model/modules/trunkv2.py create mode 100644 src/boltz/distributed/model/modules/utils.py create mode 100644 src/boltz/distributed/model/optim/__init__.py create mode 100644 src/boltz/distributed/model/optim/ema.py create mode 100644 src/boltz/distributed/model/validation/__init__.py create mode 100644 src/boltz/distributed/model/validation/rcsb.py create mode 100644 src/boltz/distributed/model/validation/utils.py create mode 100644 src/boltz/distributed/model/validation/validator.py create mode 100644 src/boltz/distributed/predict.py create mode 100644 src/boltz/distributed/testing/utils.py create mode 100644 src/boltz/distributed/train.py create mode 100644 src/boltz/distributed/utils.py create mode 100644 src/boltz/testing/__init__.py create mode 100644 src/boltz/testing/utils.py create mode 100644 src/boltz/workflow/__init__.py create mode 100644 src/boltz/workflow/utils.py create mode 100644 tests/conftest.py create mode 100644 tests/data/feature/test_featurizerv2.py create mode 100644 tests/data/feature/test_featurizerv2_train.py create mode 100644 tests/data/write/test_writer.py create mode 100644 tests/distributed/__init__.py create mode 100644 tests/distributed/data/test_dtensor_minimum_lddt_symmetry_coords.py create mode 100644 tests/distributed/data/test_dtensor_pack_and_pad_atom_features.py create mode 100644 tests/distributed/data/test_dtensor_scatter_features.py create mode 100644 tests/distributed/dtensor_train_harness.py create mode 100644 tests/distributed/model/__init__.py create mode 100755 tests/distributed/model/layers/test_attention_with_dtensor_for_pairformer_use_case.py create mode 100644 tests/distributed/model/layers/test_dtensor_atom_to_token.py create mode 100644 tests/distributed/model/layers/test_dtensor_attention.py create mode 100644 tests/distributed/model/layers/test_dtensor_cat.py create mode 100755 tests/distributed/model/layers/test_dtensor_chunk.py create mode 100644 tests/distributed/model/layers/test_dtensor_clip.py create mode 100644 tests/distributed/model/layers/test_dtensor_dropout.py create mode 100644 tests/distributed/model/layers/test_dtensor_elementwise_op.py create mode 100644 tests/distributed/model/layers/test_dtensor_embedding.py create mode 100644 tests/distributed/model/layers/test_dtensor_flatten.py create mode 100644 tests/distributed/model/layers/test_dtensor_gather.py create mode 100644 tests/distributed/model/layers/test_dtensor_layernorm_nocastbf16.py create mode 100644 tests/distributed/model/layers/test_dtensor_outer_gather.py create mode 100644 tests/distributed/model/layers/test_dtensor_outer_op.py create mode 100644 tests/distributed/model/layers/test_dtensor_outer_product_mean.py create mode 100644 tests/distributed/model/layers/test_dtensor_pair_weighted_averaging.py create mode 100644 tests/distributed/model/layers/test_dtensor_pairformer_layer.py create mode 100644 tests/distributed/model/layers/test_dtensor_pairformer_module.py create mode 100644 tests/distributed/model/layers/test_dtensor_pairformer_no_seq_layer.py create mode 100644 tests/distributed/model/layers/test_dtensor_pairformer_no_seq_module.py create mode 100755 tests/distributed/model/layers/test_dtensor_redistribute_transpose.py create mode 100644 tests/distributed/model/layers/test_dtensor_repeat_interleave.py create mode 100644 tests/distributed/model/layers/test_dtensor_scatter.py create mode 100644 tests/distributed/model/layers/test_dtensor_sharded_op.py create mode 100644 tests/distributed/model/layers/test_dtensor_shardwise_op.py create mode 100644 tests/distributed/model/layers/test_dtensor_sigmoid_gate.py create mode 100644 tests/distributed/model/layers/test_dtensor_squeeze.py create mode 100755 tests/distributed/model/layers/test_dtensor_swiglu.py create mode 100644 tests/distributed/model/layers/test_dtensor_transition.py create mode 100644 tests/distributed/model/layers/test_dtensor_triangle_attention.py create mode 100644 tests/distributed/model/layers/test_dtensor_triangular_mult.py create mode 100644 tests/distributed/model/layers/test_dtensor_unflatten.py create mode 100644 tests/distributed/model/layers/test_dtensor_where.py create mode 100644 tests/distributed/model/layers/test_dtensor_window_batch_utils.py create mode 100755 tests/distributed/model/layers/test_redistribute_transpose.py create mode 100644 tests/distributed/model/layers/test_window_batch_utils.py create mode 100644 tests/distributed/model/layers/test_window_ownership.py create mode 100644 tests/distributed/model/loss/__init__.py create mode 100644 tests/distributed/model/loss/benchmark_smooth_lddt_loss_triton.py create mode 100644 tests/distributed/model/loss/test_cdist_lddt_triton.py create mode 100644 tests/distributed/model/loss/test_cdist_pde_triton.py create mode 100644 tests/distributed/model/loss/test_compute_plddt_mae_triton.py create mode 100644 tests/distributed/model/loss/test_dtensor_bfactor.py create mode 100644 tests/distributed/model/loss/test_dtensor_confidence_loss.py create mode 100644 tests/distributed/model/loss/test_dtensor_confidence_pde_loss.py create mode 100644 tests/distributed/model/loss/test_dtensor_confidence_plddt_loss.py create mode 100644 tests/distributed/model/loss/test_dtensor_confidence_resolved_loss.py create mode 100644 tests/distributed/model/loss/test_dtensor_distogram.py create mode 100644 tests/distributed/model/loss/test_dtensor_get_true_coordinates.py create mode 100644 tests/distributed/model/loss/test_dtensor_pae_loss.py create mode 100644 tests/distributed/model/loss/test_dtensor_smooth_lddt_loss.py create mode 100644 tests/distributed/model/loss/test_dtensor_weighted_minimum_rmsd_single.py create mode 100644 tests/distributed/model/loss/test_dtensor_weighted_rigid_align.py create mode 100644 tests/distributed/model/loss/test_get_lddt_metrics.py create mode 100644 tests/distributed/model/loss/test_smooth_lddt_loss_triton.py create mode 100644 tests/distributed/model/models/__init__.py create mode 100644 tests/distributed/model/models/test_dtensor_boltz2.py create mode 100644 tests/distributed/model/modules/__init__.py create mode 100644 tests/distributed/model/modules/test_dtensor_adaln.py create mode 100644 tests/distributed/model/modules/test_dtensor_atom_attn_decoder_wb.py create mode 100644 tests/distributed/model/modules/test_dtensor_atom_attn_encoder_wb.py create mode 100644 tests/distributed/model/modules/test_dtensor_atom_encoder_wb.py create mode 100644 tests/distributed/model/modules/test_dtensor_atom_transformer.py create mode 100644 tests/distributed/model/modules/test_dtensor_conditioned_transition_block.py create mode 100644 tests/distributed/model/modules/test_dtensor_confidence_utils.py create mode 100644 tests/distributed/model/modules/test_dtensor_confidencev2.py create mode 100644 tests/distributed/model/modules/test_dtensor_diffusion.py create mode 100644 tests/distributed/model/modules/test_dtensor_diffusion_conditioning.py create mode 100644 tests/distributed/model/modules/test_dtensor_diffusion_module.py create mode 100644 tests/distributed/model/modules/test_dtensor_diffusion_transformer_layer.py create mode 100644 tests/distributed/model/modules/test_dtensor_encoders.py create mode 100644 tests/distributed/model/modules/test_dtensor_fourier_embedding.py create mode 100644 tests/distributed/model/modules/test_dtensor_input_embedder_wb.py create mode 100644 tests/distributed/model/modules/test_dtensor_msa_layer.py create mode 100644 tests/distributed/model/modules/test_dtensor_msa_module.py create mode 100644 tests/distributed/model/modules/test_dtensor_pairwise_conditioning.py create mode 100644 tests/distributed/model/modules/test_dtensor_single_conditioning.py create mode 100644 tests/distributed/model/modules/test_dtensor_trunkv2.py create mode 100644 tests/distributed/model/modules/test_dtensor_utils.py create mode 100644 tests/distributed/model/modules/test_dtensor_utils_center_random_augmentation.py create mode 100644 tests/distributed/model/optim/__init__.py create mode 100644 tests/distributed/model/optim/test_dtensor_ema.py create mode 100644 tests/distributed/model/validation/test_dtensor_get_clash_metrics.py create mode 100644 tests/distributed/model/validation/test_dtensor_get_pb_metrics.py create mode 100644 tests/distributed/model/validation/test_dtensor_rcsb_validator.py create mode 100644 tests/distributed/test_dtensor_boltz2_train.py create mode 100644 tests/distributed/test_dtensor_cp_dataloader_v2.py create mode 100755 tests/distributed/test_dtensor_layernorm.py create mode 100644 tests/distributed/test_dtensor_linear.py create mode 100755 tests/distributed/test_dtensor_metadata_tools.py create mode 100644 tests/distributed/test_dtensor_parallel_assert_factored_lddt_loss.py create mode 100644 tests/distributed/test_dtensor_predict.py create mode 100644 tests/distributed/test_dtensor_stop_and_go.py create mode 100644 tests/distributed/test_dtensor_train_utils.py create mode 100644 tests/distributed/test_layoutmap.py create mode 100644 tests/distributed/test_lightning_strategy.py create mode 100644 tests/distributed/test_manager.py create mode 100644 tests/distributed/test_tiled_softmax_attn_update.py create mode 100644 tests/distributed/test_utils.py create mode 100644 tests/model/layers/test_triattn_kernel.py create mode 100644 tests/model/loss/__init__.py create mode 100644 tests/model/loss/test_cdist_lddt_validation.py create mode 100644 tests/model/loss/test_distogramv2.py create mode 100644 tests/model/loss/test_factored_token_lddt_dist_loss.py create mode 100644 tests/model/validation/test_validator.py create mode 100644 tests/scripts/__init__.py create mode 100644 tests/scripts/test_cluster.py create mode 100644 tests/scripts/test_run_evals.py create mode 100644 tests/workflow/test_workflow_utils.py diff --git a/.gitignore b/.gitignore index 3d20fc11a..47ea960cb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Project specific +tests/test_data + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -163,4 +166,4 @@ cython_debug/ # Boltz prediction outputs # All result files generated from a boltz prediction call -boltz_results_*/ \ No newline at end of file +boltz_results_*/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..f74b6dac0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + exclude: "mkdocs.yml" + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.1 + hooks: + - id: ruff + # Don't fail on docstring checks in pre-commit, and don't remove these unused noqa flags. + args: ["--fix", "--ignore", "D", "--ignore", "RUF100"] + - id: ruff-format + - repo: https://github.com/Yelp/detect-secrets + rev: v1.5.0 + hooks: + - id: detect-secrets + name: detect-secrets (everything but notebooks) + args: ['--baseline', './pre-commit/.secrets.baseline', '--exclude-files', '(.*\.ipynb|.*\.baseline|.*\.a3m)$', ] + exclude: package.lock.json + - repo: local + hooks: + - id: license-header-check + name: Run license-check script + entry: python pre-commit/license_check.py --license-header ./pre-commit/license_header --modify + language: python + additional_dependencies: ["click==8.1.7"] + pass_filenames: false + always_run: true diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..3473bb5f7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,13 @@ +# Boltz Context Parallelism Changelog + +## 0.1.0 (initial release) + +### New Features + +- Distributed inference with DTensor context parallelism (`src/boltz/distributed/predict.py`) +- Distributed training with DTensor context parallelism (`src/boltz/distributed/train.py`) +- 2D CP mesh support (Shard x Shard context parallelism) +- Data parallelism combined with context parallelism (DP x CP) +- Multiple attention kernel backends: cuEquivariance, trifast, FlexAttention +- CUDA memory profiling support +- Preprocessed input format for distributed inference diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..b0ac9d4e8 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1 @@ +This project is currently not accepting contributions. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..754ce8f46 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,60 @@ +# Boltz-1 Dockerfile +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:25.10-py3 + +FROM ${BASE_IMAGE} AS boltz-base + +# Install core apt packages. +RUN --mount=type=cache,id=apt-cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,id=apt-lib,target=/var/lib/apt,sharing=locked \ + bash -c '\ + apt-get update -qy && \ + apt-get install -qyy \ + libsndfile1 \ + ffmpeg \ + git \ + curl \ + pre-commit \ + lsof \ + git-lfs \ + sudo && \ + apt-get upgrade -qyy \ + rsync && \ + rm -rf /tmp/* /var/tmp/*' + +RUN apt-get install -y gnupg + +RUN mkdir -p /workspace/boltz/ + +# Fix for duplicate triton installation due to pytorch_triton renaming +RUN bash -c ' \ + cd /usr/local/lib/python3*/dist-packages/ && \ + PTRITON_DIR=$(ls -d pytorch_triton-*.dist-info 2>/dev/null | head -n 1) && \ + if [ -n "$PTRITON_DIR" ]; then \ + VERSION=${PTRITON_DIR#pytorch_triton-} && \ + VERSION=${VERSION%.dist-info} && \ + NEW_DIR="triton-${VERSION}.dist-info" && \ + cp -r "$PTRITON_DIR" "$NEW_DIR" && \ + sed -i "s/Name: pytorch-triton/Name: triton/" "$NEW_DIR/METADATA" && \ + echo "Successfully aliased pytorch-triton to triton version $VERSION"; \ + fi' + +ENV NVIDIA_TF32_OVERRIDE=0 + +RUN bash -c 'echo "ubuntu ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/ubuntu && \ + chmod 0440 /etc/sudoers.d/ubuntu' + +FROM boltz-base AS dev + +# Install boltz-1 +COPY ./README.md /workspace/boltz/README.md +COPY ./pyproject.toml /workspace/boltz/pyproject.toml +COPY ./src /workspace/boltz/src +COPY ./tests /workspace/boltz/tests +COPY ./scripts /workspace/boltz/scripts +COPY ./examples /workspace/boltz/examples + +WORKDIR /workspace/boltz +RUN bash -c 'find . -name __pycache__ -type d -print | xargs rm -rf' +RUN pip install --no-build-isolation --editable .[lint,test,cuda,dev] + +ENV NVIDIA_TF32_OVERRIDE=0 diff --git a/LICENSE b/LICENSE deleted file mode 100644 index a9d657538..000000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/LICENSE b/LICENSE new file mode 120000 index 000000000..4da2101f2 --- /dev/null +++ b/LICENSE @@ -0,0 +1 @@ +licenses/LICENSE \ No newline at end of file diff --git a/README.md b/README.md index 450e0bdad..746971b48 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,18 @@ Boltz is a family of models for biomolecular interaction prediction. Boltz-1 was All the code and weights are provided under MIT license, making them freely available for both academic and commercial uses. For more information about the model, see the [Boltz-1](https://doi.org/10.1101/2024.11.19.624167) and [Boltz-2](https://doi.org/10.1101/2025.06.14.659707) technical reports. To discuss updates, tools and applications join our [Slack channel](https://boltz.bio/join-slack). +## Fold-CP: A Context Parallelism Framework for Biomolecular Modeling + +This repo also contains context parallelism (CP) for distributed inference and training for biomolecular folding models across multiple GPUs using a 2D CP mesh combined with data parallelism, demonstrated with the Boltz model. See [this README](src/boltz/distributed/README.md) for detail. + +### Copyright and License Compliance + +- The context parallel code is licensed under the terms and conditions as written in [the license file](licenses/LICENSE) + +- The original Boltz code is licensed under their respective MIT license (See the [third-party-attr.txt](licenses/third-party-attr.txt)) + +- This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use + ## Installation > Note: we recommend installing boltz in a fresh python environment diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..9d1a71169 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,24 @@ +## Security + +NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. + +If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub.** If a potential security issue is inadvertently reported via a public issue or pull request, NVIDIA maintainers may limit public discussion and redirect the reporter to the appropriate private disclosure channels. + +## Reporting Potential Security Vulnerability in an NVIDIA Product + +To report a potential security vulnerability in any NVIDIA product: +- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) +- E-Mail: psirt@nvidia.com + - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) + - Please include the following information: + - Product/Driver name and version/branch that contains the vulnerability + - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) + - Instructions to reproduce the vulnerability + - Proof-of-concept or exploit code + - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability + +While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. + +## NVIDIA Product Security + +For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security diff --git a/docs/boltz2_cp_prediction.md b/docs/boltz2_cp_prediction.md new file mode 100644 index 000000000..9fe0f6efd --- /dev/null +++ b/docs/boltz2_cp_prediction.md @@ -0,0 +1,285 @@ +# Boltz CP Distributed Prediction with DTensor Context Parallelism + +> **Note:** The current implementation supports Boltz-2 only. + +This document describes how to run distributed structure prediction +using `src/boltz/distributed/main.py`, which provides a Click CLI that +delegates to `src/boltz/distributed/predict.py::run_predict` for DTensor-based +context parallelism (CP) combined with data parallelism (DP). + +## Entrypoint + +The distributed prediction CLI is: + +``` +src/boltz/distributed/main.py predict [options] +``` + +`DATA` is the path to the input data directory. The CLI resolves the model +checkpoint and CCD molecule directory automatically (downloading them to +`~/.boltz` if not provided), then forwards all options to `run_predict`. + +### Launching with `torchrun` or `srun` + +The script is designed to be launched with **either** `torchrun` (for +single-node or multi-node runs outside SLURM) **or** `srun` (for SLURM +clusters). The `DistributedManager` (`src/boltz/distributed/manager.py`) +auto-detects the launch method: + +- **`torchrun`** — detected via `RANK`, `WORLD_SIZE`, `LOCAL_RANK`, + `MASTER_ADDR`, `MASTER_PORT` environment variables. +- **`srun` (SLURM)** — detected via `SLURM_PROCID`, `SLURM_NPROCS`, + `SLURM_LOCALID`, `SLURM_LAUNCH_NODE_IPADDR` environment variables. + +The detection order can be forced by setting +`BOLTZ_DISTRIBUTED_INIT_METHOD=ENV` or `BOLTZ_DISTRIBUTED_INIT_METHOD=SLURM`. + +Example with `torchrun` (single node, 4 GPUs, dp=1, cp=4): + +```bash +torchrun \ + --nnodes 1 \ + --nproc_per_node 4 \ + src/boltz/distributed/main.py predict \ + /path/to/preprocessed_data \ + --out_dir ./predictions \ + --size_dp 1 \ + --size_cp 4 \ + --recycling_steps 3 \ + --sampling_steps 200 \ + --diffusion_samples 5 +``` + +Example with `srun` (multi-node SLURM, 8 GPUs total, dp=2, cp=4): + +```bash +srun --ntasks-per-node=4 --nodes=2 \ + python src/boltz/distributed/main.py predict \ + /path/to/preprocessed_data \ + --out_dir ./predictions \ + --size_dp 2 \ + --size_cp 4 \ + --checkpoint /path/to/boltz2_conf.ckpt \ + --mol_dir /path/to/mols +``` + +The constraint `size_dp * size_cp == world_size` must hold, and `size_cp` +must be a perfect square (the CP mesh is 2D). + +--- + +## Input Data + +The distributed inference pipeline currently supports only **preprocessed** +input data (`--input_format preprocessed`). The data directory must contain: + +- `manifest.json` — describes the samples to predict. +- `structures/` — preprocessed structure files. +- `msa/` — MSA files for each target. +- `templates/` (optional) — template structure files. +- `extra_mols/` (optional) — additional molecule definitions. + +Rank 0 loads the manifest and broadcasts it to all other ranks via +`torch.distributed.broadcast_object_list` over a CPU (Gloo) process group. + +--- + +## CLI Options + +All options are provided as Click flags on the `predict` subcommand. + +### Required Arguments + +| Argument | Description | +| -------- | ---------------------------------------- | +| `DATA` | Path to the preprocessed data directory. | + +### Common Options + +| Option | Type | Default | Description | +| -------------- | ---- | ------------- | ------------------------------------------------------------------------- | +| `--out_dir` | path | `./` | Output directory for predictions. | +| `--cache` | path | `~/.boltz` | Download cache for checkpoint and CCD molecules. Respects `$BOLTZ_CACHE`. | +| `--checkpoint` | path | auto-download | Path to a Boltz model checkpoint. | +| `--mol_dir` | path | auto-download | Directory containing per-residue CCD molecule pickle files. | + +### Parallelism + +| Option | Type | Default | Description | +| ----------- | ---- | ------- | -------------------------------------------------------------------------- | +| `--size_dp` | int | `1` | Number of data-parallel ranks. | +| `--size_cp` | int | `1` | Total context-parallel ranks (must be a perfect square: 1, 4, 9, 16, ...). | + +The product `size_dp * size_cp` must equal the total world size +(`nproc_per_node * nnodes`). + +### Diffusion Sampling + +| Option | Type | Default | Description | +| ------------------------ | ----- | ------- | ---------------------------------------------------------------------------------------- | +| `--recycling_steps` | int | `3` | Number of trunk recycling iterations. | +| `--sampling_steps` | int | `200` | Number of diffusion denoising steps. | +| `--diffusion_samples` | int | `1` | Number of independent diffusion samples per input. | +| `--max_parallel_samples` | int | `None` | Max samples to run in parallel (`None` = all at once). | +| `--step_scale` | float | `1.5` | Diffusion schedule step scale. Lower values increase sample diversity (recommended 1–2). | + +### Model and Precision + +| Option | Type | Default | Description | +| --------------- | ------ | ------------ | ------------------------------------------------------ | +| `--precision` | enum | `BF16_MIXED` | Model precision: `BF16`, `BF16_MIXED`, `TF32`, `FP32`. | +| `--accelerator` | choice | `gpu` | Device accelerator: `gpu` or `cpu`. | +| `--seed` | int | `None` | Random seed for reproducibility. | + +### Attention Kernel Backends + +| Option | Values | Default | Description | +| ------------------------------------ | ---------------------------------------------------------------- | ----------------- | ----------------------------------------------------------------- | +| `--triattn_backend` | `cueq`, `trifast`, `reference` | `cueq` | Triangle attention kernel. `cueq` requires CUDA + cuequivariance. | +| `--sdpa_with_bias_backend` | `reference`, `torch_flex_attn` | `torch_flex_attn` | SDPA backend for ring-attention layers. | +| `--sdpa_with_bias_shardwise_backend` | `reference`, `torch_sdpa_efficient_attention`, `torch_flex_attn` | `torch_flex_attn` | SDPA backend for window-batched attention layers. | + +### Data Processing + +| Option | Type | Default | Description | +| ----------------------- | ------ | -------------- | --------------------------------------------------------------------------------------------------------------- | +| `--input_format` | choice | `preprocessed` | Input data format. Only `preprocessed` is currently supported. | +| `--max_msa_seqs` | int | `4096` | Maximum number of MSA sequences. | +| `--msa_pad_to_max_seqs` | flag | `False` | Pad MSA to `max_msa_seqs`. | +| `--use_templates` | bool | `True` | Reserved for future use. Template weights are loaded but the distributed TemplateModule is not yet implemented. | + +### Window Batching + +| Option | Type | Default | Description | +| --------------------------------- | ------- | -------- | --------------------------------------------------------------------------------------------- | +| `--atoms_per_window_queries_keys` | int int | `32 128` | (queries, keys) window sizes for atom attention batching. | +| `--pair_mask_mode` | choice | `None` | Pair mask mode: `None` (window batching), `GlobalAtomAttention`, or `SequenceLocalAttention`. | + +### Output + +| Option | Type | Default | Description | +| -------------------- | ------ | ------- | ----------------------------------------------------- | +| `--output_format` | choice | `mmcif` | Output structure format: `pdb` or `mmcif`. | +| `--write_full_pae` | flag | `False` | Write full PAE matrices (requires confidence module). | +| `--local_batch_size` | int | `1` | Per-rank batch size. | +| `--num_ensembles` | int | `1` | Number of ensemble members for structure prediction. | + +### Timeouts and Profiling + +| Option | Type | Default | Description | +| ----------------------- | ----- | ------- | --------------------------------------------------------- | +| `--timeout_nccl` | float | `30` | NCCL timeout in minutes (for CUDA). | +| `--timeout_gloo` | float | `30` | Gloo timeout in minutes (for CPU). | +| `--cuda_memory_profile` | flag | `False` | Dump a CUDA memory snapshot pickle per rank to `out_dir`. | + +--- + +## Inference Pipeline Stages + +The `run_predict` function in `src/boltz/distributed/predict.py` executes +the following stages: + +### 1. Distributed Setup + +- Initializes `DistributedManager` with the appropriate device type and timeout. +- Validates `size_dp * size_cp == world_size` and that `size_cp` is a perfect + square. +- Creates a 2D CP grid: `OrderedDict([("dp", size_dp), ("cp", (sqrt_cp, sqrt_cp))])`. +- Creates CPU-backed Gloo process groups (`world_cpu`, `cp_cpu`) for data + broadcast and metadata exchange. + +### 2. Data Broadcast and Loading + +- Rank 0 loads the manifest and constructs `BoltzProcessedInput`. +- The processed input is broadcast to all ranks via + `broadcast_object_list` over the `world_cpu` group. +- A `Boltz2InferenceDataModuleDTensor` data module is created with the device + mesh, which uses `PredictionDatasetCPWithDTensorV2` internally. Each sample + is featurized, tokenized, and distributed as DTensors across CP ranks. + +### 3. Model Loading + +- The serial Boltz model is loaded from the checkpoint using + `Boltz2Serial.load_from_checkpoint` with `strict=True`. +- Checkpoint hyperparameters (e.g., `pairformer_args.v2`, `msa_args.use_paired_feature`) + are read from the checkpoint and merged to ensure the correct model + architecture is instantiated. +- The serial model is wrapped in the `Boltz2Distributed` wrapper, which + replaces submodules with DTensor-aware distributed counterparts. +- Attention kernel backends are configured via `model.apply(SetTriAttnBackend(...))`, + `model.apply(SetAttnPairBiasBackend(...))`, and + `model.apply(SetAttnPairBiasShardwiseBackend(...))`. + +### 4. Prediction + +- A Lightning `Trainer` with `SingleDeviceStrategy` runs `trainer.predict()`. +- Only CP rank 0 within each DP group writes output files, via `BoltzWriter`. +- Each DP rank writes to its own subdirectory: + `/boltz_results_/predictions_dp{dp_rank}_cp0/`. +- Prediction runs inside a `setup_tf32_env` context when TF32 precision + is selected. + +--- + +## CP-Specific Settings (not in serial `predict`) + +The following settings exist in the distributed `predict` CLI but have no +counterpart in the serial `predict` command: + +### Parallelism topology + +| Key | Type | Description | +| ----------- | ---- | ------------------------------------------------------- | +| `--size_dp` | int | Data-parallel group size. | +| `--size_cp` | int | Context-parallel group size (must be a perfect square). | + +### Precision + +A top-level `--precision` enum (`BF16`, `BF16_MIXED`, `TF32`, `FP32`) that +replaces Lightning's `trainer.precision`. For `BF16` mode, a custom +`HalfPrecisionAllowFrozen` plugin is used to handle Boltz's frozen +dataclasses during input conversion. + +### Attention kernel backends + +| Option | Values | Description | +| ------------------------------------ | ---------------------------------------------------------------- | ----------------------------------------------------------------- | +| `--triattn_backend` | `reference`, `cueq`, `trifast` | Triangular attention kernel. `cueq` does not support FP32 or CPU. | +| `--sdpa_with_bias_backend` | `reference`, `torch_flex_attn` | SDPA backend for ring-attention layers. | +| `--sdpa_with_bias_shardwise_backend` | `reference`, `torch_sdpa_efficient_attention`, `torch_flex_attn` | SDPA backend for window-batched attention layers. | + +### CUDA memory profiling + +When `--cuda_memory_profile` is set, each rank writes a pickle file to +`/cuda_memory_profile_rank.pickle`. + +### Timeouts + +| Key | Type | Default | Description | +| ---------------- | ----- | ------- | --------------------------------------- | +| `--timeout_nccl` | float | `30` | NCCL timeout in minutes (CUDA). | +| `--timeout_gloo` | float | `30` | Gloo timeout in minutes (CPU/metadata). | + +--- + +## Differences from Serial Prediction + +| Aspect | Serial (`src/boltz/main.py predict`) | CP (`src/boltz/distributed/main.py predict`) | +| --------------------- | ------------------------------------------- | ------------------------------------------------------------------------------------- | +| Multi-GPU strategy | Lightning DDP (`DDPStrategy`) | `SingleDeviceStrategy` + DTensor CP mesh | +| Device management | Lightning (`--devices`, `--num_nodes`) | `DistributedManager` via `--size_dp`, `--size_cp` | +| Launch method | `python src/boltz/main.py predict` | `torchrun` or `srun` | +| Input formats | `config_files` (YAML/FASTA), `preprocessed` | `preprocessed` only | +| `num_workers` | Configurable | Fixed at `0` (DTensor CP requires main-process collation) | +| Precision | Lightning `--precision` string | Top-level `--precision` enum | +| Attention backends | Not configurable | `--triattn_backend`, `--sdpa_with_bias_backend`, `--sdpa_with_bias_shardwise_backend` | +| CUDA memory profiling | Not available | `--cuda_memory_profile` flag | +| Confidence prediction | Supported | Not yet supported (`write_confidence_summary=False`) | +| Steering potentials | Supported | Not yet supported | +| Affinity prediction | Supported | Not yet supported | +| Template features | Supported | Weights loaded but distributed TemplateModule not yet implemented | +| Constraint features | Supported | Not yet supported | +| Checkpoint loading | Lightning default | Reads checkpoint hparams, merges v2 flags, loads with `strict=True` | +| Output writing | All ranks write | Only CP rank 0 per DP group writes output | + +--- diff --git a/docs/boltz2_cp_training.md b/docs/boltz2_cp_training.md new file mode 100644 index 000000000..de3af269f --- /dev/null +++ b/docs/boltz2_cp_training.md @@ -0,0 +1,394 @@ +# Boltz CP Distributed Training with DTensor Context Parallelism + +> **Note:** The current implementation supports Boltz-2 only. + +This document describes how to run distributed training using +`src/boltz/distributed/train.py`, which implements DTensor-based context +parallelism (CP) combined with data parallelism (DP). + +## Entrypoint + +The distributed training entrypoint is: + +``` +src/boltz/distributed/train.py [override1=value1] [override2=value2] ... +``` + +It accepts a YAML config file as the first positional argument, followed by +zero or more OmegaConf-style dotlist overrides. The config is loaded via +Hydra (respecting `defaults:` chains), struct mode is disabled to allow +adding new keys, and CLI overrides are merged on top. + +### Launching with `torchrun` or `srun` + +The script is designed to be launched with **either** `torchrun` (for +single-node or multi-node runs outside SLURM) **or** `srun` (for SLURM +clusters). The `DistributedManager` (`src/boltz/distributed/manager.py`) +auto-detects the launch method: + +- **`torchrun`** — detected via `RANK`, `WORLD_SIZE`, `LOCAL_RANK`, + `MASTER_ADDR`, `MASTER_PORT` environment variables. +- **`srun` (SLURM)** — detected via `SLURM_PROCID`, `SLURM_NPROCS`, + `SLURM_LOCALID`, `SLURM_LAUNCH_NODE_IPADDR` environment variables. + +The detection order can be forced by setting +`BOLTZ_DISTRIBUTED_INIT_METHOD=ENV` or `BOLTZ_DISTRIBUTED_INIT_METHOD=SLURM`. + +Example with `torchrun` (single node, 8 GPUs): + +```bash +torchrun \ + --nnodes 1 \ + --nproc_per_node 8 \ + src/boltz/distributed/train.py \ + scripts/train/configs/structurev2_small_cp.yaml \ + parallel_size.size_dp=2 \ + parallel_size.size_cp=4 \ + output= \ + ... +``` + +Example with `srun` (multi-node SLURM): + +```bash +srun --ntasks-per-node=8 --nodes=2 \ + python src/boltz/distributed/train.py \ + scripts/train/configs/structurev2_cp.yaml \ + parallel_size.size_dp=4 \ + parallel_size.size_cp=4 \ + output= \ + ... +``` + +The constraint `size_dp * size_cp == world_size` must hold, and `size_cp` +must be a perfect square (the CP mesh is 2D). + +--- + +## Configuration Hierarchy + +### Base config: `structurev2.yaml` + +`scripts/train/configs/structurev2.yaml` is the base configuration for +Boltz serial training (used by `scripts/train/train.py`). It defines the +full model architecture, data pipeline, training hyperparameters, and +defaults to single-device training: + +```yaml +trainer: + accelerator: cuda + devices: 1 + num_nodes: 1 + precision: bf16-mixed +``` + +### CP overlay: `structurev2_cp.yaml` + +`scripts/train/configs/structurev2_cp.yaml` extends `structurev2.yaml` with +settings required for DTensor context-parallel training: + +```yaml +defaults: + - structurev2 + - _self_ + +trainer: + accelerator: gpu # must be gpu (not cuda) — CP code manages devices + devices: 1 # must be 1 — each Lightning Trainer manages one device + num_nodes: 1 # must be 1 — multi-node is handled by torchrun/SLURM + precision: null # superseded by top-level precision setting + +parallel_size: + size_cp: 1 # context-parallel group size (must be a perfect square) + size_dp: 1 # data-parallel group size + timeout_nccl: 30 # NCCL timeout in minutes + timeout_gloo: 30 # Gloo timeout in minutes + +precision: BF16_MIXED # top-level precision enum (FP32, TF32, BF16, BF16_MIXED, FP16, FP64) + +triattn_backend: cueq +sdpa_with_bias_backend: torch_flex_attn +sdpa_with_bias_shardwise_backend: torch_flex_attn + +data: + num_workers: 0 # must be 0 — DTensor CP requires main-process collation + pin_memory: false +``` + +Key differences from the serial config: + +| Setting | Serial (`structurev2.yaml`) | CP (`structurev2_cp.yaml`) | +|---|---|---| +| `trainer.accelerator` | `cuda` | `gpu` | +| `trainer.devices` | configurable (e.g. 8) | must be `1` | +| `trainer.num_nodes` | configurable (e.g. 4) | must be `1` | +| `trainer.precision` | `bf16-mixed` | `null` (use top-level `precision`) | +| `parallel_size` | not present | `size_dp`, `size_cp`, timeouts | +| `precision` (top-level) | not present | `BF16_MIXED` enum | +| `triattn_backend` | not present | triangular attention kernel backend | +| `sdpa_with_bias_backend` | not present | ring-attention SDPA backend | +| `sdpa_with_bias_shardwise_backend` | not present | window-batched SDPA backend | +| `data.num_workers` | `2` | `0` (required) | +| `data.pin_memory` | `false` | `false` | +| `CUDAMemoryProfile` | not present | optional memory profiling | + +### Small-model variants + +- **`structurev2_small.yaml`** extends `structurev2.yaml` with reduced model + depth (12 pairformer blocks, 3 MSA blocks), smaller sequence limits + (`max_tokens: 256`, `max_atoms: 2048`), and no activation checkpointing. + It also sets `trainer.devices: 8` and `trainer.num_nodes: 4` for multi-GPU + serial DDP training. + +- **`structurev2_small_cp.yaml`** extends `structurev2_cp.yaml` (not the + serial small variant) with the same reduced model depth and sequence + limits. It does **not** set `trainer.devices` or `trainer.num_nodes` + because the CP config already fixes those to `1`. The parallel topology is + controlled entirely via `parallel_size.size_dp` and + `parallel_size.size_cp` CLI overrides. + +--- + +## CP-Specific Settings (not in serial `train.py`) + +The following settings exist in `DistributedTrainConfig` but have no +counterpart in the serial `TrainConfig`: + +### `parallel_size` + +Controls the distributed topology. + +| Key | Type | Description | +|---|---|---| +| `size_dp` | int | Data-parallel group size | +| `size_cp` | int | Context-parallel group size (must be a perfect square: 1, 4, 9, 16, ...) | +| `timeout_nccl` | int | NCCL timeout in minutes (for CUDA) | +| `timeout_gloo` | int | Gloo timeout in minutes (for CPU) | + +The product `size_dp * size_cp` must equal the total world size +(`nproc_per_node * nnodes`). + +### `precision` (top-level) + +An enum (`FP32`, `TF32`, `BF16`, `BF16_MIXED`, `FP16`, `FP64`) that +replaces `trainer.precision`. Setting `trainer.precision` directly raises an +error in the CP entrypoint. + +### Attention kernel backends + +| Key | Values | Description | +|---|---|---| +| `triattn_backend` | `reference`, `cueq`, `trifast`, `cueq_fwd_trifast_bwd` | Triangular attention kernel. `cueq` does not support FP32. | +| `sdpa_with_bias_backend` | `reference`, `torch_sdpa_efficient_attention`, `torch_flex_attn` | SDPA backend for ring-attention layers | +| `sdpa_with_bias_shardwise_backend` | `reference`, `torch_sdpa_efficient_attention`, `torch_flex_attn` | SDPA backend for window-batched (shardwise) attention layers | + +### `OffloadActvCkptToCPU` + +Enables CPU offloading of activation-checkpoint boundary tensors for +selected distributed module types. When activation checkpointing is active, +intermediate activations saved for backward are moved to CPU during the +forward pass and restored on backward, trading extra CPU-GPU transfers for +reduced GPU memory. + +```yaml +# Disabled by default. +OffloadActvCkptToCPU: null + +# Enable for all three supported module types: +OffloadActvCkptToCPU: + - DiffusionTransformer + - MSAModule + - PairformerModule +``` + +CLI override: + +```bash +OffloadActvCkptToCPU='[DiffusionTransformer,MSAModule,PairformerModule]' +``` + +Valid module type names: `DiffusionTransformer`, `MSAModule`, +`PairformerModule`. Any subset may be specified. + +**Prerequisite: activation checkpointing must be enabled on every targeted +module.** The setter raises `ValueError` if a targeted module has +`activation_checkpointing=False`. The standard activation-checkpointing +overrides cover four config sections: + +```bash +model.msa_args.activation_checkpointing=true +model.template_args.activation_checkpointing=true +model.pairformer_args.activation_checkpointing=true +model.score_model_args.activation_checkpointing=true +``` + +However, there is a fifth `DiffusionTransformer` instance nested inside the +`InputEmbedder` (via `AtomAttentionEncoder` -> `AtomTransformer` -> +`DiffusionTransformer`). This instance is controlled by `embedder_args`, not +`score_model_args`. If `DiffusionTransformer` is included in +`OffloadActvCkptToCPU`, you must also enable: + +```bash +model.embedder_args.activation_checkpointing=true +``` + +Otherwise the setter will raise because the `InputEmbedder`'s +`DiffusionTransformer` still has `activation_checkpointing=False`. + +### `CUDAMemoryProfile` + +Optional CUDA memory profiling. Each rank writes a pickle file. + +```yaml +CUDAMemoryProfile: + output_path_prefix: null # set a path prefix to enable, e.g. "profiling/mem" + max_entries: 300000 +``` + +### `checkpoint` + +Overrides for Lightning's `ModelCheckpoint`. The CP entrypoint applies +Boltz defaults (`monitor="val/lddt"`, `save_last=True`, +`save_on_train_epoch_end=True`, `mode="max"`, `every_n_epochs=1`) but any +key can be overridden via CLI: + +```bash +checkpoint.monitor=val/disto_lddt_protein_protein \ +checkpoint.save_top_k=3 +``` + +### `seed` + +When set, seeds are offset by `rank + epoch + global_step` on resume to +avoid replaying identical data samples across ranks and restarts. + +### `validation_only` + +When `true`, runs `trainer.validate()` instead of `trainer.fit()`. Useful +for evaluating a checkpoint without training. + +--- + +## CLI Overrides + +All config keys can be overridden from the command line using OmegaConf +dotlist syntax. The config's struct mode is disabled, so new keys can also +be introduced. + +### Dataset overrides via CLI + +A recently added utility (`convert_datasets_dict_to_list_config` from +`src/boltz/workflow/utils.py`) enables **partial overrides of individual +dataset entries** within the `data.datasets` list directly from the CLI. +This was not possible before because OmegaConf does not natively support +partial updates to ListConfig entries. + +The utility converts dict-style index keys (e.g. `data.datasets.0.key`) +into proper list merges against the base config. The overridable dataset +keys are: + +`_target_`, `target_dir`, `msa_dir`, `prob`, `sampler`, `cropper`, +`template_dir`, `filters`, `split`, `symmetry_correction`, `val_group`, +`use_train_subset`, `moldir`, `override_bfactor`, `override_method` + +Examples: + +```bash +# Override dataset 0's data directory and disable filters +data.datasets.0.target_dir=/path/to/data \ +data.datasets.0.msa_dir=/path/to/msa \ +'data.datasets.0.filters=[]' \ + +# Remove dataset 1 entirely (null removes the entry) +data.datasets.1=null \ +``` + +When `filters` needs to be set to an empty list, it must be quoted to +prevent shell interpretation: `'data.datasets.0.filters=[]'`. + +### Common CLI overrides + +```bash +# Parallelism topology +parallel_size.size_dp= +parallel_size.size_cp= + +# Output and checkpointing +output= +resume= +pretrained= +checkpoint.monitor= + +# Training schedule +trainer.max_epochs= +trainer.accumulate_grad_batches= + +# Logging frequency +trainer.log_every_n_steps= # Lightning: how often to flush to logger (default: 50) +model.log_loss_every_steps= # Model: how often to compute and log losses/norms (default: 50) + +# Data +data.samples_per_epoch= +data.overfit= # Activate overfit mode: use first N samples, validate on train data +data.datasets.0.target_dir= +data.datasets.0.split= + +# Precision +precision=BF16_MIXED + +# WandB (all keys required when wandb section is present) +wandb.name= +wandb.project= +wandb.id= +wandb.entity= + +# Validation only +validation_only=true +``` + +### Logging frequency + +There are two independent gates that control when training metrics appear +in the logger: + +1. **`model.log_loss_every_steps`** (default: 50) — The Boltz model only + calls `self.log()` (and computes parameter/gradient norms) every N + global steps. Steps where this condition is not met skip all logging + computation entirely. + +2. **`trainer.log_every_n_steps`** (default: 50) — PyTorch Lightning only + flushes `self.log()` calls to the configured logger (e.g. WandB) every + N steps. Calls on other steps are buffered but not written. + +Both conditions must be satisfied for metrics to appear. For small +datasets with few steps per epoch, set both to `1`. + +### Overfit mode + +Setting `data.overfit=N` activates overfit mode: + +- Training samples are truncated to the first N per dataset. +- Validation uses the **training** datasets instead of the validation split. +- Sample weights are normalized for uniform sampling. + +When using overfit mode, set `data.datasets.0.split=null` to prevent the +train/val split from reducing the training set. Set +`data.samples_per_epoch` to control the epoch length (divided by `size_dp` +for per-rank step count). + +--- + +## Differences from Serial Training (`scripts/train/train.py`) + +| Aspect | Serial (`scripts/train/train.py`) | CP (`src/boltz/distributed/train.py`) | +|---|---|---| +| Multi-GPU strategy | Lightning DDP (`DDPStrategy`) | `BoltzContextParallelStrategy` (DTensor) | +| Device management | Lightning (`trainer.devices`, `trainer.num_nodes`) | `DistributedManager` via `parallel_size` | +| Launch method | `python scripts/train/train.py` | `torchrun` or `srun` | +| `num_workers` | Configurable (default: 2) | Must be `0` | +| Precision | `trainer.precision` (Lightning string) | Top-level `precision` enum | +| Attention backends | Not configurable | `triattn_backend`, `sdpa_with_bias_backend`, `sdpa_with_bias_shardwise_backend` | +| CUDA memory profiling | Not available | `CUDAMemoryProfile` section | +| Confidence prediction | Supported | Not yet supported (auto-disabled with warning) | +| Steering potentials | Supported | Not yet supported (auto-disabled with warning) | +| Checkpoint strategy | Lightning default | `BoltzContextParallelStrategy` (DTensor-aware save/load) | diff --git a/licenses/LICENSE b/licenses/LICENSE new file mode 100644 index 000000000..d922346aa --- /dev/null +++ b/licenses/LICENSE @@ -0,0 +1,20 @@ +SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: MIT + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/licenses/third-party-attr.txt b/licenses/third-party-attr.txt new file mode 100644 index 000000000..fdd9825d3 --- /dev/null +++ b/licenses/third-party-attr.txt @@ -0,0 +1,31 @@ +This project will download and install additional third-party open source software projects. +Review the license terms of these open source projects before use. + +We directly use code from the following open source project: + +Name: Boltz + Version: https://github.com/jwohlwend/boltz/commit/cb04aeccdd480fd4db707f0bbafde538397fa2ac + License: MIT License + URL: https://github.com/jwohlwend/boltz + License Text: + MIT License + + Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/pre-commit/.secrets.baseline b/pre-commit/.secrets.baseline new file mode 100644 index 000000000..1508efba2 --- /dev/null +++ b/pre-commit/.secrets.baseline @@ -0,0 +1,132 @@ +{ + "version": "1.5.0", + "plugins_used": [ + { + "name": "ArtifactoryDetector" + }, + { + "name": "AWSKeyDetector" + }, + { + "name": "AzureStorageKeyDetector" + }, + { + "name": "Base64HighEntropyString", + "limit": 4.5 + }, + { + "name": "BasicAuthDetector" + }, + { + "name": "CloudantDetector" + }, + { + "name": "DiscordBotTokenDetector" + }, + { + "name": "GitHubTokenDetector" + }, + { + "name": "GitLabTokenDetector" + }, + { + "name": "HexHighEntropyString", + "limit": 3.0 + }, + { + "name": "IbmCloudIamDetector" + }, + { + "name": "IbmCosHmacDetector" + }, + { + "name": "IPPublicDetector" + }, + { + "name": "JwtTokenDetector" + }, + { + "name": "KeywordDetector", + "keyword_exclude": "" + }, + { + "name": "MailchimpDetector" + }, + { + "name": "NpmDetector" + }, + { + "name": "OpenAIDetector" + }, + { + "name": "PrivateKeyDetector" + }, + { + "name": "PypiTokenDetector" + }, + { + "name": "SendGridDetector" + }, + { + "name": "SlackDetector" + }, + { + "name": "SoftlayerDetector" + }, + { + "name": "SquareOAuthDetector" + }, + { + "name": "StripeDetector" + }, + { + "name": "TelegramBotTokenDetector" + }, + { + "name": "TwilioKeyDetector" + } + ], + "filters_used": [ + { + "path": "detect_secrets.filters.allowlist.is_line_allowlisted" + }, + { + "path": "detect_secrets.filters.common.is_baseline_file", + "filename": ".secrets.baseline" + }, + { + "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies", + "min_level": 2 + }, + { + "path": "detect_secrets.filters.heuristic.is_indirect_reference" + }, + { + "path": "detect_secrets.filters.heuristic.is_likely_id_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_lock_file" + }, + { + "path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_potential_uuid" + }, + { + "path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign" + }, + { + "path": "detect_secrets.filters.heuristic.is_sequential_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_swagger_file" + }, + { + "path": "detect_secrets.filters.heuristic.is_templated_secret" + } + ], + "results": { + }, + "generated_at": "2024-10-10T04:23:53Z" +} diff --git a/pre-commit/license_check.py b/pre-commit/license_check.py new file mode 100644 index 000000000..db20b7c8b --- /dev/null +++ b/pre-commit/license_check.py @@ -0,0 +1,492 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from __future__ import annotations + +import ast +import logging +import subprocess +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from typing import Callable + +from dataclasses import dataclass +from functools import partial +from pathlib import Path + +import click + +logger = logging.getLogger(__name__) + +__all__: Sequence[str] = ( + # main license check functionality: per file & per directory + # (recursively, *.py filter) + "license_check", + "check_license_project_files", + "Checked", + # error types + "LicenseCheckError", + "HeaderNotFoundError", + # functions that implement license checking behavior + "append_license_header", + "is_valid_python", + "has_header", + "ensure_license_starts_with_pound", + "remove_existing_license_header", + # default license header + "LICENSE_HEADER", + # to run main CLI program logic, w/o Click runner + "main", +) + +LICENSE_HEADER: str = (Path(__file__).parent / "license_header").read_text().strip() + + +@dataclass(frozen=True) +class HeaderNotFoundError(ValueError): + """Error that indicates the pointed-to file does not have a valid license header.""" + + pyfile: Path + + def __str__(self) -> str: # noqa: D105 + return f"{self.pyfile.name} does not have the license header!" + + +LicenseCheckError = IOError | SyntaxError | HeaderNotFoundError +"""Errors that can be encountered during the license check process. + +Specific errors and their underlying causes: + - IOError: problem reading file + - SyntaxError: the input file for license checking is not valid Python + - HeaderNotFound: the input file was valid Python, but did not have the right + license @ the header +""" + + +def license_check( # noqa: PLR0911 + pyfile: Path, + *, + license_header: str = LICENSE_HEADER, + modify: bool, + replace: bool = False, +) -> LicenseCheckError | None: + """Check file for license header, returning nothing on success or an error.""" + if not pyfile.is_file(): + return OSError(f"{pyfile.name} file does not exist!") + + with open(str(pyfile)) as rt: # noqa: PTH123 + pyfile_contents: str = rt.read() + + maybe_err = is_valid_python(pyfile_contents) + if maybe_err is not None: + return maybe_err + + if has_header(pyfile_contents, license_header=license_header): + return None + + if modify: + # `pyfile` doesn't start with `license_header` text + + if replace: + # does it start with some other license header? + # if so, then we delete that before appending our new `license_header` text + pyfile_contents = remove_existing_license_header(pyfile_contents) + maybe_err = is_valid_python(pyfile_contents) + if maybe_err is not None: + return maybe_err + + pyfile_contents = append_license_header(pyfile_contents, license_header=license_header) + maybe_err = is_valid_python(pyfile_contents) + if maybe_err is not None: + return maybe_err + + with open(str(pyfile), "w") as wt: # noqa: PTH123 + wt.write(pyfile_contents) + return None + return HeaderNotFoundError(pyfile) + + +def is_valid_python(pyfile_contents: str) -> SyntaxError | None: + """Validates python code. Returns None if it is valid, otherwise a SyntaxError.""" + try: + _ = ast.parse(pyfile_contents) + except SyntaxError as error: + return error + else: + return None + + +def has_header(pyfile_contents: str, *, license_header: str = LICENSE_HEADER) -> bool: + """True if the :param:`pyfile_contents` starts with the :param:`license_header`.""" + return pyfile_contents.startswith(license_header) + + +def append_license_header(pyfile_contents: str, *, license_header: str = LICENSE_HEADER, n_sep_lines: int = 2) -> str: + """Appends the :param:`license_header` to the beginning of the input Python code. + + Inserts :param:`n_sep_lines` newlines between the license header & Python file + content. + """ + spacer = "\n" * n_sep_lines + return f"{license_header}{spacer}{pyfile_contents}" + + +def remove_existing_license_header(pyfile_contents: str) -> str: + """Heuristically removes the license header from a Python file's contents. + + Assumes that a license header is identified by a span of commented-out lines + from the beginning of the file. I.e. a big initial block of lines starting + with "#" ==> a license header. + + Will always return the input without this "license header" block. + """ + if not pyfile_contents.startswith("#") or len(pyfile_contents) == 0: + return pyfile_contents + lines: list[str] = pyfile_contents.split("\n") + non_header_lines = lines[_last_index_of_header_comment_line(lines) + 1 :] + return "\n".join(non_header_lines) + + +def _last_index_of_header_comment_line(lines: list[str]) -> int: + """Return index into `lines` with the first line that doesn't start as a comment.""" + if len(lines) == 0: + raise ValueError + last_index_of_line_that_started_with_hash_from_beginning: int = -1 + for i, line in enumerate(lines): + if line.startswith("#"): + last_index_of_line_that_started_with_hash_from_beginning = i + else: + break + if last_index_of_line_that_started_with_hash_from_beginning < 0: + raise ValueError("Must supply non-empty lines of Python!") # noqa: TRY003, EM101 + return last_index_of_line_that_started_with_hash_from_beginning + + +@dataclass(frozen=True) +class Checked: + """Result of running license check across a collection of Python files.""" + + noncompliant_files: Mapping[Path, LicenseCheckError] + """Files that either don't have a license header for some reason or another. + """ + + n_files: int + """Total number of Python files checked. + """ + + +def check_license_project_files( + python_package_directory: Path, *, license_header: str, modify: bool, replace: bool +) -> Checked: + """Check all Python files in a given directory tree, returning non-compliant files. + + Each returned file will be associated with the specific :class:`LicenseCheckError`. + For more details, + see :func:`license_check`. + """ + assert python_package_directory.is_dir(), ( + "Input must be a directory of Python files, not a " f"directory: {python_package_directory}" + ) + noncompliant_files = {} + n_files = 0 + for pyfile in python_package_directory.rglob("*.py"): + n_files += 1 + maybe_error = license_check(pyfile, license_header=license_header, modify=modify, replace=replace) + if maybe_error is not None: + noncompliant_files[pyfile] = maybe_error + return Checked(noncompliant_files, n_files) + + +def ensure_license_starts_with_pound(license_header_contents: str) -> str: + """Ensures lines of the license headers start with "# "; add if necessary.""" + if len(license_header_contents) == 0: + raise ValueError("License header must be non-empty!") # noqa: TRY003, EM101 + safe_license_header_lines: list[str] = [] + for line in license_header_contents.split("\n"): + this_line = f"# {line}" if not line.startswith("#") else line + safe_license_header_lines.append(this_line) + return "\n".join(safe_license_header_lines) + + +def get_staged_files() -> list[str]: + """Returns list of git staged files.""" + try: + result = subprocess.run( # noqa: S603 + ["git", "diff", "--cached", "--name-only", "--diff-filter=ACMR"], # noqa: S607 + capture_output=True, + text=True, + check=True, # Raise exception on non-zero exit code + ) + except subprocess.CalledProcessError as e: + error_msg = f"""Git command failed with code {e.returncode} + Command: {" ".join(e.cmd)} + Error output: {e.stderr.strip() or ""} + """ + raise RuntimeError(error_msg) from None + except FileNotFoundError: + raise RuntimeError("Git executable not found - is Git installed?") from None # noqa: TRY003, EM101 + + if not result.stdout.strip(): + return [] + return result.stdout.splitlines() + + +@click.command(help="Check that Python files start with a license header.") +@click.option( + "--check", + "-c", + multiple=True, + type=str, + help="Either a file or directory. If a directory, then all files that are " + "accessible will be included (directories)" + " are searched recursively). Only files that end with *.py will be included. " + "Acceptable to use multiple " + "times. All --check files will be included. Must specify at least one *.py file.", +) +@click.option( + "--modify", + "-m", + is_flag=True, + help="If present, modifies files that don't have the license header. " + "Otherwise, will error-out if it finds any non-compliant files.", +) +@click.option( + "--license-header", + "-l", + required=False, + help="If present, loads the license header from this file. Defaults to use " "standard license header.", +) +@click.option( + "--add-leading", + "-a", + is_flag=True, + help="If present, will ensure that each line of the license header starts with " + "'#'. " + "If any line doesn't, then this option will make the program append '# ' to the " + "start of each line.", +) +@click.option( + "--replace", + "-r", + is_flag=True, + help="If present, will replace an existing license header. By default, this " + "program simply appends the" + "license header text to each .py file. This option will employ a heuristic to " + "detect if a .py file " + "starts with a license header. If detected, then this text is removed before the" + " normal license header appending logic runs.", +) +@click.option( + "--verbose", + is_flag=True, + help="If present, will perform extra (verbose) logging. ", +) +def entrypoint( # noqa: PLR0915, PLR0912, C901 + check: tuple[str, ...], + modify: bool, + license_header: str | None, + add_leading: bool, + replace: bool, + verbose: bool, +) -> None: + logger.info(f"Files/directories for finding .py files: {check}") # noqa: G004 + logger.info(f"Modify .py files with license header?: {modify}") # noqa: G004 + logger.info(f"Overriding standard license header?: {license_header}") # noqa: G004 + logger.info(f"Force each line to start with '# '?: {add_leading}") # noqa: G004 + logger.info(f"Check for and replace existing header?: {replace}") # noqa: G004 + logger.info(f"Verbose (extra) logging?: {verbose}") # noqa: G004 + logger.info("-" * 100) + + if len(check) == 0: + files_staged = get_staged_files() + py_staged = [] + for f in files_staged: + p = Path(f).absolute() + if p.name.endswith(".py"): + py_staged.append(f) + check = tuple(f for f in py_staged) + logger.info(f"Run license_check.py on the staged files: {check}") # noqa: G004 + + if replace and not modify: + raise ValueError("Must use --modify if also using --replace !") # noqa: TRY003, EM101 + + if modify and not replace: + logger.info( + "WARNING: existing license headers are ignored. " + "To replace any existing header text, re-run with --replace. " + ) + + # get all files / directories from --check + files: list[Path] = [] + directories: list[Path] = [] + unknown: list[Path] = [] + for f in check: + p = Path(f).absolute() + if p.is_file(): + files.append(p) + elif p.is_dir(): + directories.append(p) + else: + unknown.append(p) + + # check that they all exist + if len(unknown) > 0: + raise ValueError(f"Found {len(unknown)} --check things that do not exist!\n\n".join(map(str, unknown))) + + # check that files passed in explicitly from --check end with .py + non_py_files = [f for f in files if not f.name.endswith(".py")] + if len(non_py_files) > 0: + raise ValueError( + f"Found {len(non_py_files)} files from --check that are not Python files! " + "(They don't end with .py):\n" + "\n".join(map(str, non_py_files)) + ) + + # resolve the license header: either the default or user override .txt file + if license_header is not None: + lic_file = Path(license_header).absolute() + if not lic_file.is_file(): + raise ValueError( # noqa: TRY003 + f"Supplied a --license-header, but {license_header} is not a file!" # noqa: EM102 + ) + with open(str(lic_file)) as rt: # noqa: PTH123 + license_header_contents: str = rt.read() + msg_license: str = "Using custom license header" + else: + license_header_contents = LICENSE_HEADER + msg_license = "Using default license header" + + if add_leading: + msg_license += ". Ensuring each line of the license header starts with '#'" + license_header_contents = ensure_license_starts_with_pound(license_header_contents) + + logger.info(f"{msg_license}:\n{license_header_contents}" if verbose else f"{msg_license}.") + + # run license check + try: + checked_n_files: int = main( + modify, + license_header_contents, + files=files, + directories=directories, + replace=replace, + ) + except ValueError as error: + logger.info(str(error)) + sys.exit(1) + else: + logger.info( + f"Success! All {checked_n_files} checked have the required license header!" # noqa: G004 + ) + + +def main( + modify: bool, + license_header_contents: str, + *, + files: list[Path], + directories: list[Path], + replace: bool, +) -> int: + """Runs license check on all files & files accessible from the directories. + + On failure, raises an error with all noncompliant files. Returns nothing on success. + See :func:`check_license_project_files` for details. + + Returns the number of files checked on success. On failure, :raises:`ValueError` + with message containing the # of non-compliant files & their specific + :class:`LicenseCheckError` errors. + + The :param:`replace` option will heuristically check for an existing license header. + It will remove and replace this with the :param:`license_header_contents`. + """ + if len(files) == 0 and len(directories) == 0: + return 0 + if len(license_header_contents) == 0: + raise ValueError("Must supply non-empty license header!") # noqa: TRY003, EM101 + checked = _main(modify, license_header_contents, files, directories, replace) + if len(checked.noncompliant_files) > 0: + raise _error(checked.noncompliant_files, checked.n_files, modify) + return checked.n_files + + +def _main( + modify: bool, + license_header_contents: str, + files: list[Path], + directories: list[Path], + replace: bool, +) -> Checked: + check_file: Callable[[Path], LicenseCheckError | None] = partial( + license_check, + license_header=license_header_contents, + modify=modify, + replace=replace, + ) + check_dir: Callable[[Path], Checked] = partial( + check_license_project_files, + modify=modify, + license_header=license_header_contents, + replace=replace, + ) + + n_files_checked: int = 0 + noncompliant_files: dict[Path, LicenseCheckError] = {} + + # license check all individual files + for f in files: + maybe_err = check_file(f) + if maybe_err is not None: + noncompliant_files[f] = maybe_err + n_files_checked += len(files) + + # license check all directories and their contents, recursively + for d in directories: + checked = check_dir(d) + noncompliant_files.update(checked.noncompliant_files) + n_files_checked += checked.n_files + + return Checked(noncompliant_files=noncompliant_files, n_files=n_files_checked) + + +def _error( + noncompliant_files: Mapping[Path, LicenseCheckError], + n_files_checked: int, + modify: bool, +) -> ValueError: + maybe_modify_msg: str = ( + " You can re-run with '--modify' to automatically add the required" " license header." if not modify else "" + ) + error_message: str = ( + f"ERROR: There are {len(noncompliant_files)} / {n_files_checked} " + f"files that do not have the license header!{maybe_modify_msg}\n" + ) + for pyfile, error in noncompliant_files.items(): + error_message += f" {pyfile!s}: {error}\n" + return ValueError(error_message) + + +if __name__ == "__main__": + entrypoint() diff --git a/pre-commit/license_header b/pre-commit/license_header new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/pre-commit/license_header @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/pyproject.toml b/pyproject.toml index 666a4aab0..6ff88b781 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,37 +36,36 @@ dependencies = [ [project.scripts] boltz = "boltz.main:cli" +download_data = "boltz.data.load.load:entrypoint" [project.optional-dependencies] lint = ["ruff"] -test = ["pytest", "requests"] +test = [ + "pytest", + "biotite", + "gdown==5.2.0", + "nest_asyncio", + "ngcsdk", + "pooch", + "pydantic[email]>=2.7.0", + "tqdm", +] cuda = [ - "cuequivariance_ops_cu12>=0.5.0", - "cuequivariance_ops_torch_cu12>=0.5.0", - "cuequivariance_torch>=0.5.0", + "trifast>=0.1.13", + "cuequivariance_ops_cu13==0.8.0", + "cuequivariance_ops_torch_cu13==0.8.0", + "cuequivariance_torch==0.8.0", ] +dev = ["pre-commit==4.1.0"] [tool.ruff] src = ["src"] extend-exclude = ["conf.py"] -target-version = "py39" -lint.select = ["ALL"] -lint.ignore = [ - "COM812", # Conflicts with the formatter - "ISC001", # Conflicts with the formatter - "ANN101", # "missing-type-self" - "RET504", # Unnecessary assignment to `x` before `return` statementRuff - "S101", # Use of `assert` detected - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package - "PT001", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 - "PT004", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 - "PT005", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 - "PT023", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715 - "FBT001", - "FBT002", - "PLR0913", # Too many arguments to init (> 5) -] +target-version = "py310" +line-length = 120 +# Match boltz-1x-cp: only check basic categories (complexity, errors, pyflakes, isort, warnings) +lint.select = ["C", "E", "F", "I", "W"] +lint.ignore = ["C901", "E741", "E501"] [tool.ruff.lint.per-file-ignores] "**/__init__.py" = [ diff --git a/scripts/eval/run_evals.py b/scripts/eval/run_evals.py index 48f3a0687..82530ea5b 100644 --- a/scripts/eval/run_evals.py +++ b/scripts/eval/run_evals.py @@ -1,46 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import argparse import concurrent.futures +import os import subprocess from pathlib import Path from tqdm import tqdm -OST_COMPARE_STRUCTURE = r""" -#!/bin/bash -# https://openstructure.org/docs/2.7/actions/#ost-compare-structures - -IMAGE_NAME=openstructure-0.2.8 - -command="compare-structures \ --m {model_file} \ --r {reference_file} \ ---fault-tolerant \ ---min-pep-length 4 \ ---min-nuc-length 4 \ --o {output_path} \ ---lddt --bb-lddt --qs-score --dockq \ ---ics --ips --rigid-scores --patch-scores --tm-score" - -sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command -""" - +IMAGE_NAME = "openstructure-0.2.8" -OST_COMPARE_LIGAND = r""" -#!/bin/bash -# https://openstructure.org/docs/2.7/actions/#ost-compare-structures -IMAGE_NAME=openstructure-0.2.8 - -command="compare-ligand-structures \ --m {model_file} \ --r {reference_file} \ ---fault-tolerant \ ---lddt-pli --rmsd \ ---substructure-match \ --o {output_path}" - -sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command -""" +def _docker_compare_structures_cmd( + model_file: str, + reference_file: str, + output_path: str, + mount: str, +) -> list[str]: + """Build the ``docker run … compare-structures`` command as an arg list.""" + uid = os.getuid() + gid = os.getgid() + return [ + "sudo", "docker", "run", + "-u", f"{uid}:{gid}", + "--rm", + "--volume", f"{mount}:{mount}", + IMAGE_NAME, + "compare-structures", + "-m", model_file, + "-r", reference_file, + "--fault-tolerant", + "--min-pep-length", "4", + "--min-nuc-length", "4", + "-o", output_path, + "--lddt", "--bb-lddt", "--qs-score", "--dockq", + "--ics", "--ips", "--rigid-scores", "--patch-scores", "--tm-score", + ] # fmt: skip + + +def _docker_compare_ligands_cmd( + model_file: str, + reference_file: str, + output_path: str, + mount: str, +) -> list[str]: + """Build the ``docker run … compare-ligand-structures`` command as an arg list.""" + uid = os.getuid() + gid = os.getgid() + return [ + "sudo", "docker", "run", + "-u", f"{uid}:{gid}", + "--rm", + "--volume", f"{mount}:{mount}", + IMAGE_NAME, + "compare-ligand-structures", + "-m", model_file, + "-r", reference_file, + "--fault-tolerant", + "--lddt-pli", "--rmsd", + "--substructure-match", + "-o", output_path, + ] # fmt: skip def evaluate_structure( @@ -49,7 +90,6 @@ def evaluate_structure( reference: Path, outdir: str, mount: str, - executable: str = "/bin/bash", ) -> None: """Evaluate the structure.""" # Evaluate polymer metrics @@ -61,15 +101,13 @@ def evaluate_structure( ) else: subprocess.run( - OST_COMPARE_STRUCTURE.format( + _docker_compare_structures_cmd( model_file=str(pred), reference_file=str(reference), output_path=str(out_path), mount=mount, ), - shell=True, # noqa: S602 check=False, - executable=executable, capture_output=True, ) @@ -79,15 +117,13 @@ def evaluate_structure( print(f"Skipping recomputation of {name} as ligand json file already exists") # noqa: T201 else: subprocess.run( - OST_COMPARE_LIGAND.format( + _docker_compare_ligands_cmd( model_file=str(pred), reference_file=str(reference), output_path=str(out_path), mount=mount, ), - shell=True, # noqa: S602 check=False, - executable=executable, capture_output=True, ) @@ -132,7 +168,6 @@ def main(args): reference=str(ref_path), outdir=str(args.outdir), mount=args.mount, - executable=args.executable, ) first_item = False else: @@ -143,7 +178,6 @@ def main(args): reference=str(ref_path), outdir=str(args.outdir), mount=args.mount, - executable=args.executable, ) futures.append(future) @@ -161,7 +195,6 @@ def main(args): parser.add_argument("--format", type=str, default="af3") parser.add_argument("--testset", type=str, default="casp") parser.add_argument("--mount", type=str) - parser.add_argument("--executable", type=str, default="/bin/bash") parser.add_argument("--max-workers", type=int, default=32) args = parser.parse_args() main(args) diff --git a/scripts/process/cluster.py b/scripts/process/cluster.py index c45dcd999..ce13d3c01 100644 --- a/scripts/process/cluster.py +++ b/scripts/process/cluster.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + """Create a mapping from structure and chain ID to MSA indices.""" import argparse @@ -46,8 +69,15 @@ def main(args: argparse.Namespace) -> None: f.write("\n".join(proteins)) subprocess.run( - f"{args.mmseqs} easy-cluster {outdir / 'proteins.fasta'} {outdir / 'clust_prot'} {outdir / 'tmp'} --min-seq-id 0.4", # noqa: E501 - shell=True, # noqa: S602 + [ + args.mmseqs, + "easy-cluster", + str(outdir / "proteins.fasta"), + str(outdir / "clust_prot"), + str(outdir / "tmp"), + "--min-seq-id", + "0.4", + ], check=True, ) diff --git a/scripts/train/configs/structurev2.yaml b/scripts/train/configs/structurev2.yaml index ea09d18b6..d36b19a9a 100644 --- a/scripts/train/configs/structurev2.yaml +++ b/scripts/train/configs/structurev2.yaml @@ -75,7 +75,7 @@ data: _target_: boltz.data.feature.featurizerv2_train.Boltz2Featurizer moldir: #PATH_HERE - max_tokens: 384 # 640 + max_tokens: 384 # 640 # NOTE: cuEq TriAttn backend on sm100f GPUs requires multiples of 8 token counts per CP shard max_atoms: 3456 # 5760 max_seqs: 8192 pad_to_max_tokens: true @@ -224,11 +224,11 @@ model: diffusion_process_args: sigma_min: 0.0004 - sigma_max: 160.0 - sigma_data: 16.0 - rho: 7 - P_mean: -1.2 - P_std: 1.5 + sigma_max: 160.0 + sigma_data: 16.0 + rho: 7 + P_mean: -1.2 + P_std: 1.5 gamma_0: 0.8 gamma_min: 1.0 noise_scale: 1.0 @@ -241,4 +241,4 @@ model: add_smooth_lddt_loss: true nucleotide_loss_weight: 5.0 ligand_loss_weight: 10.0 - filter_by_plddt: 0.0 \ No newline at end of file + filter_by_plddt: 0.0 diff --git a/scripts/train/configs/structurev2_cp.yaml b/scripts/train/configs/structurev2_cp.yaml new file mode 100644 index 000000000..b36b4287c --- /dev/null +++ b/scripts/train/configs/structurev2_cp.yaml @@ -0,0 +1,69 @@ +defaults: + - structurev2 + - _self_ + +# General DTensor context-parallel settings for normal-size model. +# For small-model CP training, use structurev2_small_cp.yaml instead. + +# DTensor CP uses SingleDeviceStrategy; multi-device/node is managed by +# DistributedManager via parallel_size, not by Lightning. +trainer: + accelerator: gpu # must be gpu instead of cuda because CP code manages devices + devices: 1 # must be 1 for DTensor CP (one device per Lightning Trainer) + num_nodes: 1 # must be 1; multi-node is handled by torchrun/SLURM + precision: null # superseded by top-level precision below + +# Context-parallelism and data-parallelism topology. +# Override via CLI: parallel_size.size_cp=4 parallel_size.size_dp=2 +# Constraint: size_cp must be a perfect square (2D CP mesh); size_dp * size_cp == world_size. +parallel_size: + size_cp: 1 # context-parallel group size (must be a perfect square, e.g. 1, 4, 9, 16) + size_dp: 1 # data-parallel group size + timeout_nccl: 30 # NCCL timeout in minutes (for CUDA) + timeout_gloo: 30 # Gloo timeout in minutes (for CPU) + +# Training precision. Values: FP32, TF32, BF16, BF16_MIXED, FP16, FP64 +precision: BF16_MIXED + +# Triangular attention kernel backend. +# Values: reference, cueq, trifast, cueq_fwd_trifast_bwd +# Note: cueq does not support FP32 precision. +triattn_backend: cueq + +# Scaled dot-product attention with bias backend (ring-attention layers). +# Values: reference, torch_sdpa_efficient_attention, torch_flex_attn +sdpa_with_bias_backend: torch_flex_attn + +# SDPA with bias backend for shardwise (window-batched) attention layers. +# Values: reference, torch_sdpa_efficient_attention, torch_flex_attn +sdpa_with_bias_shardwise_backend: torch_flex_attn + +# CUDA memory profiling. Activated only when output_path_prefix is set. +# Each rank writes to: {output_path_prefix}_rank{global_rank}.pickle +# Additional kwargs are forwarded to torch.cuda.memory._record_memory_history(). +CUDAMemoryProfile: + output_path_prefix: null # set a path prefix to enable, e.g. "profiling/mem" + max_entries: 300000 # max allocation/deallocation events to record + +# CPU offloading of activation-checkpoint boundary tensors. +# Lists the distributed module types whose checkpoint-boundary activations +# should be offloaded to CPU during forward and restored on backward. +# Requires the corresponding activation_checkpointing to be enabled. +# Valid types: DiffusionTransformer, MSAModule, PairformerModule +# Set to null or omit to disable. +OffloadActvCkptToCPU: null + +# DTensor CP requires num_workers=0 (main-process collation for distributed +# tensor construction) and pin_memory=false. +data: + num_workers: 0 + pin_memory: false + +model: + validators: + - _target_: boltz.distributed.model.validation.rcsb.DistributedRCSBValidator + val_names: ["RCSB"] + confidence_prediction: ${model.confidence_prediction} + physicalism_metrics: False + rmsd_metrics: True + clash_score_metrics: True diff --git a/scripts/train/configs/structurev2_small.yaml b/scripts/train/configs/structurev2_small.yaml new file mode 100644 index 000000000..9b39e2ee6 --- /dev/null +++ b/scripts/train/configs/structurev2_small.yaml @@ -0,0 +1,33 @@ +defaults: + - structurev2 + - _self_ + +trainer: + devices: 8 + num_nodes: 4 + accumulate_grad_batches: 4 + +wandb: + name: boltz2_small_bf16mixed + project: boltz + id: boltz2_small_bf16mixed + +data: + checkpoint_monitor_val_group: "val/disto_lddt_protein_protein" + max_tokens: 256 + max_atoms: 2048 + max_seqs: 1024 + num_workers: 0 + +model: + checkpoint_diffusion_conditioning: false + msa_args: + msa_blocks: 3 + activation_checkpointing: false + template_args: + activation_checkpointing: false + pairformer_args: + num_blocks: 12 + activation_checkpointing: false + score_model_args: + activation_checkpointing: false diff --git a/scripts/train/configs/structurev2_small_cp.yaml b/scripts/train/configs/structurev2_small_cp.yaml new file mode 100644 index 000000000..f4ae9b79b --- /dev/null +++ b/scripts/train/configs/structurev2_small_cp.yaml @@ -0,0 +1,30 @@ +defaults: + - structurev2_cp + - _self_ + +trainer: + accumulate_grad_batches: 4 + +wandb: + name: boltz2_small_bf16mixed_cp + project: boltz + id: boltz2_small_bf16mixed_cp + +data: + checkpoint_monitor_val_group: "val/disto_lddt_protein_protein" + max_tokens: 256 + max_atoms: 2048 + max_seqs: 1024 + +model: + checkpoint_diffusion_conditioning: false + msa_args: + msa_blocks: 3 + activation_checkpointing: false + template_args: + activation_checkpointing: false + pairformer_args: + num_blocks: 12 + activation_checkpointing: false + score_model_args: + activation_checkpointing: false diff --git a/scripts/train/train.py b/scripts/train/train.py index 98acafa6e..d7296d87e 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import os import random import string @@ -11,6 +34,8 @@ import pytorch_lightning as pl import torch import torch.multiprocessing +from hydra import compose, initialize_config_dir +from hydra.core.global_hydra import GlobalHydra from omegaconf import OmegaConf, listconfig from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -20,6 +45,7 @@ from boltz.data.module.training import BoltzTrainingDataModule, DataConfig from boltz.data.module.trainingv2 import Boltz2TrainingDataModule, DataConfigV2 +from boltz.workflow.utils import _DATASET_KEYS_TO_OVERRIDE, convert_datasets_dict_to_list_config @dataclass @@ -92,11 +118,25 @@ def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR Any command line overrides. """ - # Load the configuration - raw_config = omegaconf.OmegaConf.load(raw_config) + # Load the configuration (with optional Hydra defaults support) + raw_config_path = raw_config + raw_config = omegaconf.OmegaConf.load(raw_config_path) + if "defaults" in raw_config: + config_path = Path(raw_config_path) + GlobalHydra.instance().clear() + with initialize_config_dir(config_dir=str(config_path.parent.absolute()), version_base=None): + raw_config = compose(config_name=config_path.stem) + omegaconf.OmegaConf.set_struct(raw_config, False) # Apply input arguments args = omegaconf.OmegaConf.from_dotlist(args) + if "data" in args and "datasets" in args.data and "data" in raw_config and "datasets" in raw_config.data: + args["data"]["datasets"] = convert_datasets_dict_to_list_config( + raw_config.data.datasets, + args.data.datasets, + keys_to_override=_DATASET_KEYS_TO_OVERRIDE, + remove_null_datasets=True, + ) raw_config = omegaconf.OmegaConf.merge(raw_config, args) # Instantiate the task @@ -166,8 +206,11 @@ def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR file_path = cfg.pretrained print(f"Loading model from {file_path}") + hparams = dict(model_module.hparams) + if getattr(model_module, "validate_structure", False) and hasattr(model_module, "validators"): + hparams["validators"] = model_module.validators model_module = type(model_module).load_from_checkpoint( - file_path, map_location="cpu", strict=False, **(model_module.hparams) + file_path, map_location="cpu", strict=False, **hparams ) if cfg.load_confidence_from_trunk: @@ -189,12 +232,16 @@ def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR # Create wandb logger loggers = [] if wandb: + wandb_id = wandb.get("id") + wandb_resume = "allow" if wandb_id else None wdb_logger = WandbLogger( name=wandb["name"], group=wandb["name"], save_dir=cfg.output, project=wandb["project"], entity=wandb["entity"], + id=wandb_id, + resume=wandb_resume, log_model=False, ) loggers.append(wdb_logger) diff --git a/src/boltz/data/crop/boltz.py b/src/boltz/data/crop/boltz.py index 1d2e20e25..91533fb49 100644 --- a/src/boltz/data/crop/boltz.py +++ b/src/boltz/data/crop/boltz.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + from dataclasses import replace from typing import Optional @@ -28,7 +51,7 @@ def pick_random_token( The selected token. """ - return tokens[random.integers(len(tokens))] + return tokens[random.randint(len(tokens))] def pick_chain_token( @@ -255,11 +278,11 @@ def crop( # noqa: PLR0915 interface = interfaces[interface_id] query = pick_interface_token(valid_tokens, interface, random) elif valid_interfaces.size: - idx = random.integers(len(valid_interfaces)) + idx = random.randint(len(valid_interfaces)) interface = valid_interfaces[idx] query = pick_interface_token(valid_tokens, interface, random) else: - idx = random.integers(len(valid_chains)) + idx = random.randint(len(valid_chains)) chain_id = valid_chains[idx]["asym_id"] query = pick_chain_token(valid_tokens, chain_id, random) @@ -354,7 +377,8 @@ def crop( # noqa: PLR0915 # We switch to the res_idx instead of the token_idx to always # include all tokens from modified residues or from ligands. min_idx = max_idx = center_token["res_idx"] - while new_tokens.size < neighborhood_size_to_use: + target_size = min(neighborhood_size_to_use, max_token_set.size) + while new_tokens.size < target_size: min_idx = min_idx - 1 max_idx = max_idx + 1 new_tokens = max_token_set diff --git a/src/boltz/data/feature/featurizerv2.py b/src/boltz/data/feature/featurizerv2.py index f8a2231ec..52ddff887 100644 --- a/src/boltz/data/feature/featurizerv2.py +++ b/src/boltz/data/feature/featurizerv2.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +# fmt: off import math from collections import deque from typing import Optional @@ -276,9 +299,23 @@ def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 first_residues["res_type"][idx] == const.token_ids["UNK"] ) if (np.all(is_met) and np.all(is_msa_unk)) or np.all(is_unk): - msa_residues[first_start:first_end]["res_type"] = residues[ + # BUG FIX: The original code mutated data.msa[chain_id].residues + # in-place via the msa_residues view. MSA is a frozen dataclass + # but its numpy arrays are still mutable. If construct_paired_msa + # is called twice on the same Tokenized (e.g. when retrying a + # failed sample), the second call sees residues already patched by + # the first — the MET/UNK mismatch check passes silently with + # corrupted data. Copy-on-write: create a new MSA with copied + # residues so the caller's data is never modified. + patched_residues = msa_residues.copy() + patched_residues[first_start:first_end][ "res_type" - ] + ] = residues["res_type"] + msa[chain_id] = MSA( + sequences=data.msa[chain_id].sequences, + deletions=data.msa[chain_id].deletions, + residues=patched_residues, + ) else: print( warning, @@ -320,8 +357,20 @@ def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 ) # Keep track of the sequences available per chain, keeping the original - # order of the sequences in the MSA to favor the best matching sequences - visited = {(c, s) for c, items in taxonomy_map for s in items} + # order of the sequences in the MSA to favor the best matching sequences. + # + # BUG FIX: The original comprehension was {(c, s) for c, items in taxonomy_map + # for s in items}. After sorted(), taxonomy_map is a list of (taxon, pairs) + # tuples, so `c` was the taxon key and `s` was a (chain_id, seq_idx) tuple, + # producing {(taxon, (chain_id, seq_idx)), ...}. The downstream check + # `(c, i) not in visited` uses (chain_id, seq_idx) — an (int, int) pair that + # never matched the (int, tuple) entries, so `visited` never filtered anything. + # Example: taxonomy_map = [(9606, [(0, 1), (1, 1)])] produced + # visited = {(9606, (0, 1)), (9606, (1, 1))} but the check looked for (0, 1) + # which was never found. All taxonomy-assigned sequences leaked into the + # `available` pool, causing duplicate MSA rows that waste capacity and dilute + # the paired co-evolutionary signal. + visited = {s for _, items in taxonomy_map for s in items} available = {} for c in chain_ids: available[c] = deque( @@ -426,12 +475,23 @@ def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 value_type=numba.types.int64, ) for chain_id, chain_msa in msa.items(): - chain_deletions = chain_msa.deletions for sequence in chain_msa.sequences: seq_idx = sequence["seq_idx"] del_start = sequence["del_start"] del_end = sequence["del_end"] - chain_deletions = chain_deletions[del_start:del_end] + # BUG FIX: The original code assigned + # chain_deletions = chain_msa.deletions (full array, outer loop) + # chain_deletions = chain_deletions[del_start:del_end] (inner loop) + # On each inner iteration, chain_deletions was re-sliced from the + # *previous* iteration's already-shrunken slice, not from the original + # full array. Example: if seq 0 (query) has del_start=0, del_end=0, + # then after seq 0 chain_deletions = full_array[0:0] = []. For seq 1, + # [][del_start_1:del_end_1] = [] regardless of del_start_1/del_end_1, + # silently dropping ALL deletion data for every subsequent sequence. + # This affects all inference predictions — has_deletion, deletion_value, + # and deletion_mean features are all zero, removing structural loop/ + # insertion information from the MSA module's input. + chain_deletions = chain_msa.deletions[del_start:del_end] for deletion_data in chain_deletions: res_idx = deletion_data["res_idx"] deletion_values = deletion_data["deletion"] @@ -1156,6 +1216,7 @@ def process_atom_features( frame_data = [] resolved_frame_data = [] atom_to_token = [] + atom_counts_per_token = [] # consumed by distributed featurizer for sharding token_to_rep_atom = [] # index on cropped atom table r_set_to_rep_atom = [] disto_coords_ensemble = [] @@ -1203,6 +1264,7 @@ def process_atom_features( # Map atoms to token indices ref_space_uid.extend([new_idx] * token["atom_num"]) atom_to_token.extend([token_id] * token["atom_num"]) + atom_counts_per_token.append(token["atom_num"]) # Add atom data start = token["atom_idx"] @@ -1389,6 +1451,12 @@ def process_atom_features( atom_idx += len(token_atoms) disto_coords_ensemble = np.array(disto_coords_ensemble) # (N_TOK, N_ENS, 3) + if disto_coords_ensemble.ndim != 3: + msg = ( + f"disto_coords_ensemble has shape {disto_coords_ensemble.shape} " + f"(expected 3D: N_TOK x N_ENS x 3) for record {data.record.id}" + ) + raise ValueError(msg) # Compute ensemble distogram L = len(data.tokens) @@ -1439,6 +1507,7 @@ def process_atom_features( resolved_mask = from_numpy(atom_data["is_present"]) pad_mask = torch.ones(len(atom_data), dtype=torch.float) atom_to_token = torch.tensor(atom_to_token, dtype=torch.long) + atom_counts_per_token = torch.tensor(atom_counts_per_token, dtype=torch.long) token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long) r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long) token_to_center_atom = torch.tensor(token_to_center_atom, dtype=torch.long) @@ -1531,6 +1600,7 @@ def process_atom_features( pad_len = max_tokens - token_to_rep_atom.shape[0] if pad_len > 0: atom_to_token = pad_dim(atom_to_token, 1, pad_len) + atom_counts_per_token = pad_dim(atom_counts_per_token, 0, pad_len) token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len) r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len) token_to_center_atom = pad_dim(token_to_center_atom, 0, pad_len) @@ -1553,6 +1623,7 @@ def process_atom_features( "coords": coords, "atom_pad_mask": pad_mask, "atom_to_token": atom_to_token, + "atom_counts_per_token": atom_counts_per_token, "token_to_rep_atom": token_to_rep_atom, "r_set_to_rep_atom": r_set_to_rep_atom, "token_to_center_atom": token_to_center_atom, diff --git a/src/boltz/data/feature/featurizerv2_train.py b/src/boltz/data/feature/featurizerv2_train.py index 5e08441cf..92a4c55de 100644 --- a/src/boltz/data/feature/featurizerv2_train.py +++ b/src/boltz/data/feature/featurizerv2_train.py @@ -1,5 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import math -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import networkx as nx import numba @@ -29,6 +52,14 @@ ) from boltz.model.modules.utils import center_random_augmentation +# The training callers (trainingv2.py) pass either np.random (the module, which +# exposes the legacy RandomState API as module-level functions) or an explicit +# np.random.RandomState instance. Both support .randint(), .choice(), .random(). +# np.random.Generator (used by the inference featurizer) is a different API that +# uses .integers() instead of .randint(). The type annotation below reflects +# what the training pipeline actually provides. +NumpyRNG = Union[np.random.RandomState, "np.random"] + #################################################################################################### # HELPERS #################################################################################################### @@ -58,7 +89,7 @@ def sample_d( min_d: float, max_d: float, n_samples: int, - random: np.random.Generator, + random: NumpyRNG, ) -> np.ndarray: """Generate samples from a 1/d distribution between min_d and max_d. @@ -216,7 +247,7 @@ def dummy_msa(residues: np.ndarray) -> MSA: def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 data: Input, - random: np.random.Generator, + random: NumpyRNG, max_seqs: int, max_pairs: int = 8192, max_total: int = 16384, @@ -279,9 +310,23 @@ def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 first_residues["res_type"][idx] == const.token_ids["UNK"] ) if (np.all(is_met) and np.all(is_msa_unk)) or np.all(is_unk): - msa_residues[first_start:first_end]["res_type"] = residues[ + # BUG FIX: The original code mutated data.msa[chain_id].residues + # in-place via the msa_residues view. MSA is a frozen dataclass + # but its numpy arrays are still mutable. If construct_paired_msa + # is called twice on the same Input (e.g. when retrying a failed + # sample), the second call sees residues already patched by the + # first — the MET/UNK mismatch check passes silently with + # corrupted data. Copy-on-write: create a new MSA with copied + # residues so the caller's data is never modified. + patched_residues = msa_residues.copy() + patched_residues[first_start:first_end][ "res_type" - ] + ] = residues["res_type"] + msa[chain_id] = MSA( + sequences=data.msa[chain_id].sequences, + deletions=data.msa[chain_id].deletions, + residues=patched_residues, + ) else: print( warning, @@ -323,8 +368,20 @@ def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 ) # Keep track of the sequences available per chain, keeping the original - # order of the sequences in the MSA to favor the best matching sequences - visited = {(c, s) for c, items in taxonomy_map for s in items} + # order of the sequences in the MSA to favor the best matching sequences. + # + # BUG FIX: The original comprehension was {(c, s) for c, items in taxonomy_map + # for s in items}. After sorted(), taxonomy_map is a list of (taxon, pairs) + # tuples, so `c` was the taxon key and `s` was a (chain_id, seq_idx) tuple, + # producing {(taxon, (chain_id, seq_idx)), ...}. The downstream check + # `(c, i) not in visited` uses (chain_id, seq_idx) — an (int, int) pair that + # never matched the (int, tuple) entries, so `visited` never filtered anything. + # Example: taxonomy_map = [(9606, [(0, 1), (1, 1)])] produced + # visited = {(9606, (0, 1)), (9606, (1, 1))} but the check looked for (0, 1) + # which was never found. All taxonomy-assigned sequences leaked into the + # `available` pool, causing duplicate MSA rows that waste capacity and dilute + # the paired co-evolutionary signal. + visited = {s for _, items in taxonomy_map for s in items} available = {} for c in chain_ids: available[c] = [ @@ -426,7 +483,6 @@ def construct_paired_msa( # noqa: C901, PLR0915, PLR0912 # Map (chain_id, seq_idx, res_idx) to deletion deletions = {} for chain_id, chain_msa in msa.items(): - chain_deletions = chain_msa.deletions for sequence in chain_msa.sequences: del_start = sequence["del_start"] del_end = sequence["del_end"] @@ -586,7 +642,7 @@ def _prepare_msa_arrays_inner( #################################################################################################### -def select_subset_from_mask(mask, p, random: np.random.Generator) -> np.ndarray: +def select_subset_from_mask(mask, p, random: NumpyRNG) -> np.ndarray: num_true = np.sum(mask) v = random.geometric(p) + 1 k = min(v, num_true) @@ -616,7 +672,7 @@ def get_range_bin(value: float, range_dict: Dict[Tuple[float, float], int], defa def process_token_features( # noqa: C901, PLR0915, PLR0912 data: Input, - random: np.random.Generator, + random: NumpyRNG, max_tokens: Optional[int] = None, binder_pocket_conditioned_prop: Optional[float] = 0.0, contact_conditioned_prop: Optional[float] = 0.0, @@ -946,7 +1002,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912 assert len(pairs) > 0 - pair = random.choice(pairs) + pair = pairs[random.randint(len(pairs))] token_1_mask = token_data["token_idx"] == pair[0] token_2_mask = token_data["token_idx"] == pair[1] @@ -994,7 +1050,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912 pairs.append((token_1["token_idx"], token_2["token_idx"])) if len(pairs) > 0: - pair = random.choice(pairs) + pair = pairs[random.randint(len(pairs))] token_1_mask = token_data["token_idx"] == pair[0] token_2_mask = token_data["token_idx"] == pair[1] @@ -1101,7 +1157,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912 def process_atom_features( data: Input, - random: np.random.Generator, + random: NumpyRNG, ensemble_features: Dict, molecules: Dict[str, Mol], atoms_per_window_queries: int = 32, @@ -1144,6 +1200,7 @@ def process_atom_features( frame_data = [] resolved_frame_data = [] atom_to_token = [] + atom_counts_per_token = [] # consumed by distributed featurizer for sharding token_to_rep_atom = [] # index on cropped atom table r_set_to_rep_atom = [] disto_coords_ensemble = [] @@ -1190,6 +1247,7 @@ def process_atom_features( # Map atoms to token indices ref_space_uid.extend([new_idx] * token["atom_num"]) atom_to_token.extend([token_id] * token["atom_num"]) + atom_counts_per_token.append(token["atom_num"]) # Add atom data start = token["atom_idx"] @@ -1375,6 +1433,12 @@ def process_atom_features( atom_idx += len(token_atoms) disto_coords_ensemble = np.array(disto_coords_ensemble) # (N_TOK, N_ENS, 3) + if disto_coords_ensemble.ndim != 3: + msg = ( + f"disto_coords_ensemble has shape {disto_coords_ensemble.shape} " + f"(expected 3D: N_TOK x N_ENS x 3) for record {data.record.id}" + ) + raise ValueError(msg) # Compute ensemble distogram L = len(data.tokens) @@ -1386,11 +1450,6 @@ def process_atom_features( # Only use a sampled structures to create distogram idx_list = ensemble_features["ensemble_ref_idxs"] - # Save a numpy array of the distogram to a path - # pdb_id = data.record.id - # with open(f"/afs/csail.mit.edu/u/m/mreveiz/rbg/temp_while_cp_rsg/temp/disto_outs_atlas10ns/{pdb_id}_disto_coords_ensemble.npy", "wb") as f: - # np.save(f, disto_coords_ensemble) - # Create distogram disto_target = torch.zeros(L, L, len(idx_list), num_bins) # TODO1 @@ -1430,6 +1489,7 @@ def process_atom_features( resolved_mask = from_numpy(atom_data["is_present"]) pad_mask = torch.ones(len(atom_data), dtype=torch.float) atom_to_token = torch.tensor(atom_to_token, dtype=torch.long) + atom_counts_per_token = torch.tensor(atom_counts_per_token, dtype=torch.long) token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long) r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long) bfactor = from_numpy(atom_data["bfactor"].copy()) @@ -1521,6 +1581,7 @@ def process_atom_features( pad_len = max_tokens - token_to_rep_atom.shape[0] if pad_len > 0: atom_to_token = pad_dim(atom_to_token, 1, pad_len) + atom_counts_per_token = pad_dim(atom_counts_per_token, 0, pad_len) token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len) r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len) disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len) @@ -1542,6 +1603,7 @@ def process_atom_features( "coords": coords, "atom_pad_mask": pad_mask, "atom_to_token": atom_to_token, + "atom_counts_per_token": atom_counts_per_token, "token_to_rep_atom": token_to_rep_atom, "r_set_to_rep_atom": r_set_to_rep_atom, "disto_target": disto_target, @@ -1559,7 +1621,7 @@ def process_atom_features( def process_msa_features( data: Input, - random: np.random.Generator, + random: NumpyRNG, max_seqs_batch: int, max_seqs: int, max_tokens: Optional[int] = None, @@ -1572,7 +1634,7 @@ def process_msa_features( ---------- data : Input The input to the model. - random : np.random.Generator + random : NumpyRNG The random number generator. max_seqs : int The maximum number of MSA sequences. @@ -1753,7 +1815,7 @@ def process_symmetry_features(cropped: Input, symmetries: Dict) -> Dict[str, Ten def process_ensemble_features( data: Input, - random: np.random.Generator, + random: NumpyRNG, num_ensembles: int, ensemble_sample_replacement: bool, fix_single_ensemble: bool, @@ -1764,7 +1826,7 @@ def process_ensemble_features( ---------- data : Input The input to the model. - random : np.random.Generator + random : NumpyRNG The random number generator. num_ensembles : int The maximum number of ensembles to sample. @@ -1792,7 +1854,7 @@ def process_ensemble_features( else: if ensemble_sample_replacement: # Used in training - ensemble_ref_idxs = random.integers(0, s_ensemble_num, (num_ensembles,)) + ensemble_ref_idxs = random.randint(0, s_ensemble_num, (num_ensembles,)) else: # Used in validation if s_ensemble_num < num_ensembles: @@ -1817,7 +1879,7 @@ class Boltz2Featurizer: def process( self, data: Input, - random: np.random.Generator, + random: NumpyRNG, molecules: Dict[str, Mol], training: bool, max_seqs: int, @@ -1876,7 +1938,7 @@ def process( # Compute random number of sequences if training and max_seqs is not None: if random.random() > single_sequence_prop: - max_seqs_batch = random.integers(1, max_seqs + 1) + max_seqs_batch = random.randint(1, max_seqs + 1) else: max_seqs_batch = 1 else: diff --git a/src/boltz/data/module/training.py b/src/boltz/data/module/training.py index 36583b6cf..bbdf545a1 100644 --- a/src/boltz/data/module/training.py +++ b/src/boltz/data/module/training.py @@ -66,6 +66,7 @@ class DataConfig: binder_pocket_sampling_geometric_p: float = 0.0 val_batch_size: int = 1 compute_constraint_features: bool = False + max_data_retries: int = 5 @dataclass @@ -208,6 +209,7 @@ def __init__( binder_pocket_sampling_geometric_p: Optional[float] = 0.0, return_symmetries: Optional[bool] = False, compute_constraint_features: bool = False, + max_data_retries: int = 5, ) -> None: """Initialize the training dataset.""" super().__init__() @@ -230,6 +232,8 @@ def __init__( self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p self.return_symmetries = return_symmetries self.compute_constraint_features = compute_constraint_features + self.max_data_retries = max_data_retries + self._fallback_depth = 0 self.samples = [] for dataset in datasets: records = dataset.manifest.records @@ -238,6 +242,20 @@ def __init__( iterator = dataset.sampler.sample(records, np.random) self.samples.append(iterator) + def _raise_or_return_item_0(self, e: Exception) -> dict[str, Tensor]: + if self.max_data_retries <= 0: + raise e + if self._fallback_depth >= self.max_data_retries: + raise RuntimeError( + f"Data loading failed {self.max_data_retries} consecutive times. " f"Last error: {e}" + ) from e + self._fallback_depth += 1 + try: + fallback_idx = np.random.randint(0, len(self)) + return self.__getitem__(fallback_idx) + finally: + self._fallback_depth -= 1 + def __getitem__(self, idx: int) -> dict[str, Tensor]: """Get an item from the dataset. @@ -265,18 +283,18 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: # Get the structure try: input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir) - except Exception as e: - print( + except Exception as e: # noqa: BLE001 + print( # noqa: T201 f"Failed to load input for {sample.record.id} with error {e}. Skipping." ) - return self.__getitem__(idx) + return self._raise_or_return_item_0(e) # Tokenize structure try: tokenized = dataset.tokenizer.tokenize(input_data) - except Exception as e: - print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.") - return self.__getitem__(idx) + except Exception as e: # noqa: BLE001 + print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Compute crop try: @@ -289,9 +307,9 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: chain_id=sample.chain_id, interface_id=sample.interface_id, ) - except Exception as e: - print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.") - return self.__getitem__(idx) + except Exception as e: # noqa: BLE001 + print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Check if there are tokens if len(tokenized.tokens) == 0: @@ -318,9 +336,9 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p, compute_constraint_features=self.compute_constraint_features, ) - except Exception as e: - print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.") - return self.__getitem__(idx) + except Exception as e: # noqa: BLE001 + print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) return features @@ -360,6 +378,7 @@ def __init__( binder_pocket_conditioned_prop: Optional[float] = 0.0, binder_pocket_cutoff: Optional[float] = 6.0, compute_constraint_features: bool = False, + max_data_retries: int = 5, ) -> None: """Initialize the validation dataset.""" super().__init__() @@ -383,6 +402,22 @@ def __init__( self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop self.binder_pocket_cutoff = binder_pocket_cutoff self.compute_constraint_features = compute_constraint_features + self.max_data_retries = max_data_retries + self._fallback_depth = 0 + + def _raise_or_return_item_0(self, e: Exception) -> dict[str, Tensor]: + if self.max_data_retries <= 0: + raise e + if self._fallback_depth >= self.max_data_retries: + raise RuntimeError( + f"Data loading failed {self.max_data_retries} consecutive times. " f"Last error: {e}" + ) from e + self._fallback_depth += 1 + try: + fallback_idx = np.random.randint(0, len(self)) + return self.__getitem__(fallback_idx) + finally: + self._fallback_depth -= 1 def __getitem__(self, idx: int) -> dict[str, Tensor]: """Get an item from the dataset. @@ -413,16 +448,16 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: # Get the structure try: input_data = load_input(record, dataset.target_dir, dataset.msa_dir) - except Exception as e: - print(f"Failed to load input for {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + except Exception as e: # noqa: BLE001 + print(f"Failed to load input for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Tokenize structure try: tokenized = dataset.tokenizer.tokenize(input_data) - except Exception as e: - print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + except Exception as e: # noqa: BLE001 + print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Compute crop try: @@ -433,9 +468,9 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: random=self.random, max_atoms=self.max_atoms, ) - except Exception as e: - print(f"Cropper failed on {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + except Exception as e: # noqa: BLE001 + print(f"Cropper failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Check if there are tokens if len(tokenized.tokens) == 0: @@ -466,9 +501,9 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: only_ligand_binder_pocket=True, compute_constraint_features=self.compute_constraint_features, ) - except Exception as e: - print(f"Featurizer failed on {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + except Exception as e: # noqa: BLE001 + print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) return features @@ -542,17 +577,11 @@ def __init__(self, cfg: DataConfig) -> None: val_records = [] # Filter training records - train_records = [ - record - for record in train_records - if all(f.filter(record) for f in cfg.filters) - ] + train_records = [record for record in train_records if all(f.filter(record) for f in cfg.filters)] # Filter training records if data_config.filters is not None: train_records = [ - record - for record in train_records - if all(f.filter(record) for f in data_config.filters) + record for record in train_records if all(f.filter(record) for f in data_config.filters) ] # Create train dataset @@ -616,6 +645,7 @@ def __init__(self, cfg: DataConfig) -> None: binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p, return_symmetries=cfg.return_train_symmetries, compute_constraint_features=cfg.compute_constraint_features, + max_data_retries=cfg.max_data_retries, ) self._val_set = ValidationDataset( datasets=train if cfg.overfit is not None else val, @@ -637,6 +667,7 @@ def __init__(self, cfg: DataConfig) -> None: binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop, binder_pocket_cutoff=cfg.binder_pocket_cutoff, compute_constraint_features=cfg.compute_constraint_features, + max_data_retries=cfg.max_data_retries, ) def setup(self, stage: Optional[str] = None) -> None: diff --git a/src/boltz/data/module/trainingv2.py b/src/boltz/data/module/trainingv2.py index 99891fe47..156a0f5ba 100644 --- a/src/boltz/data/module/trainingv2.py +++ b/src/boltz/data/module/trainingv2.py @@ -1,4 +1,25 @@ -import json +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -98,6 +119,10 @@ class DataConfigV2: moldir: Optional[str] = None compute_frames: bool = False bfactor_md_correction: Optional[bool] = False + val_skip_sample_threshold_tokens: Optional[int] = None + val_skip_sample_threshold_atoms: Optional[int] = None + val_skip_sample_threshold_seqs: Optional[int] = None + max_data_retries: int = 5 @dataclass @@ -248,7 +273,7 @@ def load_templates( # Sample for training, pick firsts for validation if training: - max_chain_templates = random.integers(1, max_chain_templates + 1) + max_chain_templates = random.randint(1, max_chain_templates + 1) template_indices = torch.randperm(len(template_ids)) template_indices = template_indices[:max_chain_templates] template_ids = [template_ids[idx.item()] for idx in template_indices] @@ -368,6 +393,7 @@ def __init__( msa_sampling: bool = False, compute_frames: bool = False, bfactor_md_correction: bool = False, + max_data_retries: int = 5, ) -> None: """Initialize the training dataset. @@ -420,11 +446,28 @@ def __init__( self.overfit = overfit self.compute_frames = compute_frames self.bfactor_md_correction = bfactor_md_correction + self.max_data_retries = max_data_retries + self._fallback_depth = 0 self.samples: list[pd.DataFrame] = [] for d in datasets: self.samples.append(d.samples[:overfit] if overfit else d.samples) + def _raise_or_return_item_0(self, e: Exception) -> dict[str, Tensor]: + if self.max_data_retries <= 0: + raise e + if self._fallback_depth >= self.max_data_retries: + raise RuntimeError( + f"Data loading failed {self.max_data_retries} consecutive times. " + f"Last error: {e}" + ) from e + self._fallback_depth += 1 + try: + fallback_idx = np.random.randint(0, len(self)) + return self.__getitem__(fallback_idx) + finally: + self._fallback_depth -= 1 + def __getitem__(self, idx: int) -> dict[str, Tensor]: """Get an item from the dataset. @@ -434,8 +477,8 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: The sampled data features. """ - # Set a random state - random = np.random.default_rng() + # Use global NumPy RNG state (v1-style), seeded by test/train harness. + random = np.random # Pick a random dataset dataset_idx = random.choice(len(self.datasets), p=self.probs) @@ -472,15 +515,15 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: try: structure = load_structure(record, dataset.struct_dir) except Exception as e: # noqa: BLE001 - print(f"Failed to load input for {record.id} with error {e}. Skipping.") - return self.__getitem__(idx) + print(f"Failed to load input for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Tokenize structure try: tokenized = dataset.tokenizer.tokenize(structure) except Exception as e: # noqa: BLE001 - print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") - return self.__getitem__(idx) + print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Compute crop try: @@ -497,8 +540,8 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: msg = "No tokens in cropped structure." raise ValueError(msg) # noqa: TRY301 except Exception as e: # noqa: BLE001 - print(f"Cropper failed on {record.id} with error {e}. Skipping.") - return self.__getitem__(idx) + print(f"Cropper failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Get unique chain ids chain_ids = set(tokenized.tokens["asym_id"]) @@ -511,8 +554,8 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: msa_dir=dataset.msa_dir, ) except Exception as e: # noqa: BLE001 - print(f"MSA loading failed for {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + print(f"MSA loading failed for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Load templates templates = FileNotFoundError @@ -528,7 +571,7 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: random=random, ) except Exception as e: # noqa: BLE001 - print( + print( # noqa: T201 f"Template loading failed for {record.id} with error {e}. Using no templates." ) templates = None @@ -549,8 +592,8 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: mol_names = mol_names - set(molecules.keys()) molecules.update(load_molecules(self.moldir, mol_names)) except Exception as e: # noqa: BLE001 - print(f"Molecule loading failed for {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + print(f"Molecule loading failed for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Finalize input data input_data = InputTraining( @@ -597,11 +640,11 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: bfactor_md_correction=self.bfactor_md_correction, ) except Exception as e: # noqa: BLE001 - print(f"Featurizer failed on {record.id} with error {e}. Skipping.") + print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 import traceback traceback.print_exc() - return self.__getitem__(idx) + return self._raise_or_return_item_0(e) features["pdb_id"] = record.id return features @@ -651,6 +694,7 @@ def __init__( no_template_prob: float = 0.0, compute_frames: bool = False, bfactor_md_correction: bool = False, + max_data_retries: int = 5, ) -> None: """Initialize the training dataset. @@ -674,6 +718,7 @@ def __init__( self.max_tokens = max_tokens self.max_seqs = max_seqs self.seed = seed + self.random = np.random if overfit else np.random.RandomState(self.seed) self.pad_to_max_tokens = pad_to_max_tokens self.pad_to_max_atoms = pad_to_max_atoms self.pad_to_max_seqs = pad_to_max_seqs @@ -695,6 +740,23 @@ def __init__( self.no_template_prob = no_template_prob self.compute_frames = compute_frames self.bfactor_md_correction = bfactor_md_correction + self.max_data_retries = max_data_retries + self._fallback_depth = 0 + + def _raise_or_return_item_0(self, e: Exception) -> dict[str, torch.Tensor]: + if self.max_data_retries <= 0: + raise e + if self._fallback_depth >= self.max_data_retries: + raise RuntimeError( + f"Data loading failed {self.max_data_retries} consecutive times. " + f"Last error: {e}" + ) from e + self._fallback_depth += 1 + try: + fallback_idx = np.random.randint(0, len(self)) + return self.__getitem__(fallback_idx) + finally: + self._fallback_depth -= 1 def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """Get an item from the dataset. @@ -705,9 +767,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: The sampled data features. """ - # Set random state - seed = self.seed if self.overfit is None else None - random = np.random.default_rng(seed) + # Use persistent RNG state (v1-style semantics). + random = self.random # Pick dataset based on idx for idx_dataset, dataset in enumerate(self.datasets): # noqa: B007 @@ -726,15 +787,15 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: try: structure = load_structure(record, dataset.struct_dir) except Exception as e: # noqa: BLE001 - print(f"Failed to load input for {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + print(f"Failed to load input for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Tokenize structure try: tokenized = dataset.tokenizer.tokenize(structure) except Exception as e: # noqa: BLE001 - print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Get unique chains chain_ids = set(np.unique(tokenized.tokens["asym_id"]).tolist()) @@ -743,8 +804,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: try: msas = load_msas(chain_ids, record, dataset.msa_dir) except Exception as e: # noqa: BLE001 - print(f"MSA loading failed for {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + print(f"MSA loading failed for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Load templates if self.use_templates and dataset.template_dir is not None: @@ -759,7 +820,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: random=random, ) except Exception as e: # noqa: BLE001 - print( + print( # noqa: T201 f"Template loading failed for {record.id} with error {e}. Using no templates." ) templates = None @@ -778,8 +839,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: mol_names = mol_names - set(molecules.keys()) molecules.update(load_molecules(self.moldir, mol_names)) except Exception as e: # noqa: BLE001 - print(f"Molecule loading failed for {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + print(f"Molecule loading failed for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Finalize input data input_data = InputTraining( @@ -798,8 +859,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: molecules=molecules, random=random, training=False, - max_atoms=None, - max_tokens=None, + max_atoms=self.max_atoms if self.pad_to_max_atoms else None, + max_tokens=self.max_tokens if self.pad_to_max_tokens else None, max_seqs=self.max_seqs, pad_to_max_seqs=self.pad_to_max_seqs, atoms_per_window_queries=self.atoms_per_window_queries, @@ -827,8 +888,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: ) except Exception as e: # noqa: BLE001 - print(f"Featurizer failed on {record.id} with error {e}. Skipping.") - return self.__getitem__(0) + print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) # Add dataset idx idx_dataset = torch.tensor([idx_dataset], dtype=torch.long) @@ -1051,6 +1112,7 @@ def __init__(self, cfg: DataConfigV2) -> None: no_template_prob=cfg.no_template_prob_train, compute_frames=cfg.compute_frames, bfactor_md_correction=cfg.bfactor_md_correction, + max_data_retries=cfg.max_data_retries, ) self._val_set = ValidationDataset( datasets=train if cfg.overfit is not None else val, @@ -1081,6 +1143,7 @@ def __init__(self, cfg: DataConfigV2) -> None: no_template_prob=cfg.no_template_prob_val, compute_frames=cfg.compute_frames, bfactor_md_correction=cfg.bfactor_md_correction, + max_data_retries=cfg.max_data_retries, ) def setup(self, stage: Optional[str] = None) -> None: # noqa: ARG002 (unused) diff --git a/src/boltz/data/parse/schema.py b/src/boltz/data/parse/schema.py index 9ffe7ad01..eb97edd4f 100644 --- a/src/boltz/data/parse/schema.py +++ b/src/boltz/data/parse/schema.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path @@ -9,7 +31,7 @@ from chembl_structure_pipeline.exclude_flag import exclude_flag from chembl_structure_pipeline.standardizer import standardize_mol from rdkit import Chem, rdBase -from rdkit.Chem import AllChem, HybridizationType +from rdkit.Chem import AllChem, Descriptors, HybridizationType from rdkit.Chem.MolStandardize import rdMolStandardize from rdkit.Chem.rdchem import BondStereo, Conformer, Mol from rdkit.Chem.rdDistGeom import GetMoleculeBoundsMatrix @@ -20,7 +42,6 @@ from boltz.data.mol import load_molecules from boltz.data.parse.mmcif import parse_mmcif from boltz.data.parse.pdb import parse_pdb - from boltz.data.types import ( AffinityInfo, Atom, @@ -317,13 +338,8 @@ def compute_geometry_constraints(mol: Mol, idx_map): doTriangleSmoothing=True, useMacrocycle14config=False, ) - bonds = set( - tuple(sorted(b)) for b in mol.GetSubstructMatches(Chem.MolFromSmarts("*~*")) - ) - angles = set( - tuple(sorted([a[0], a[2]])) - for a in mol.GetSubstructMatches(Chem.MolFromSmarts("*~*~*")) - ) + bonds = {tuple(sorted(b)) for b in mol.GetSubstructMatches(Chem.MolFromSmarts("*~*"))} + angles = {tuple(sorted([a[0], a[2]])) for a in mol.GetSubstructMatches(Chem.MolFromSmarts("*~*~*"))} constraints = [] for i, j in zip(*np.triu_indices(mol.GetNumAtoms(), k=1)): @@ -341,18 +357,11 @@ def compute_geometry_constraints(mol: Mol, idx_map): def compute_chiral_atom_constraints(mol, idx_map): constraints = [] - if all([atom.HasProp("_CIPRank") for atom in mol.GetAtoms()]): - for center_idx, orientation in Chem.FindMolChiralCenters( - mol, includeUnassigned=False - ): + if all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()): + for center_idx, orientation in Chem.FindMolChiralCenters(mol, includeUnassigned=False): center = mol.GetAtomWithIdx(center_idx) - neighbors = [ - (neighbor.GetIdx(), int(neighbor.GetProp("_CIPRank"))) - for neighbor in center.GetNeighbors() - ] - neighbors = sorted( - neighbors, key=lambda neighbor: neighbor[1], reverse=True - ) + neighbors = [(neighbor.GetIdx(), int(neighbor.GetProp("_CIPRank"))) for neighbor in center.GetNeighbors()] + neighbors = sorted(neighbors, key=lambda neighbor: neighbor[1], reverse=True) neighbors = tuple(neighbor[0] for neighbor in neighbors) is_r = orientation == "R" @@ -389,7 +398,7 @@ def compute_chiral_atom_constraints(mol, idx_map): def compute_stereo_bond_constraints(mol, idx_map): constraints = [] - if all([atom.HasProp("_CIPRank") for atom in mol.GetAtoms()]): + if all(atom.HasProp("_CIPRank") for atom in mol.GetAtoms()): for bond in mol.GetBonds(): stereo = bond.GetStereo() if stereo in {BondStereo.STEREOE, BondStereo.STEREOZ}: @@ -402,18 +411,14 @@ def compute_stereo_bond_constraints(mol, idx_map): for neighbor in mol.GetAtomWithIdx(start_atom_idx).GetNeighbors() if neighbor.GetIdx() != end_atom_idx ] - start_neighbors = sorted( - start_neighbors, key=lambda neighbor: neighbor[1], reverse=True - ) + start_neighbors = sorted(start_neighbors, key=lambda neighbor: neighbor[1], reverse=True) start_neighbors = [neighbor[0] for neighbor in start_neighbors] end_neighbors = [ (neighbor.GetIdx(), int(neighbor.GetProp("_CIPRank"))) for neighbor in mol.GetAtomWithIdx(end_atom_idx).GetNeighbors() if neighbor.GetIdx() != start_atom_idx ] - end_neighbors = sorted( - end_neighbors, key=lambda neighbor: neighbor[1], reverse=True - ) + end_neighbors = sorted(end_neighbors, key=lambda neighbor: neighbor[1], reverse=True) end_neighbors = [neighbor[0] for neighbor in end_neighbors] is_e = stereo == BondStereo.STEREOE @@ -453,9 +458,7 @@ def compute_stereo_bond_constraints(mol, idx_map): def compute_flatness_constraints(mol, idx_map): planar_double_bond_smarts = Chem.MolFromSmarts("[C;X3;^2](*)(*)=[C;X3;^2](*)(*)") aromatic_ring_5_smarts = Chem.MolFromSmarts("[ar5^2]1[ar5^2][ar5^2][ar5^2][ar5^2]1") - aromatic_ring_6_smarts = Chem.MolFromSmarts( - "[ar6^2]1[ar6^2][ar6^2][ar6^2][ar6^2][ar6^2]1" - ) + aromatic_ring_6_smarts = Chem.MolFromSmarts("[ar6^2]1[ar6^2][ar6^2][ar6^2][ar6^2][ar6^2]1") planar_double_bond_constraints = [] aromatic_ring_5_constraints = [] @@ -467,14 +470,10 @@ def compute_flatness_constraints(mol, idx_map): ) for match in mol.GetSubstructMatches(aromatic_ring_5_smarts): if all(i in idx_map for i in match): - aromatic_ring_5_constraints.append( - ParsedPlanarRing5Constraint(atom_idxs=tuple(idx_map[i] for i in match)) - ) + aromatic_ring_5_constraints.append(ParsedPlanarRing5Constraint(atom_idxs=tuple(idx_map[i] for i in match))) for match in mol.GetSubstructMatches(aromatic_ring_6_smarts): if all(i in idx_map for i in match): - aromatic_ring_6_constraints.append( - ParsedPlanarRing6Constraint(atom_idxs=tuple(idx_map[i] for i in match)) - ) + aromatic_ring_6_constraints.append(ParsedPlanarRing6Constraint(atom_idxs=tuple(idx_map[i] for i in match))) return ( planar_double_bond_constraints, @@ -674,9 +673,7 @@ def parse_ccd_residue( pos = (0, 0, 0) ref_atom = ref_mol.GetAtoms()[0] - chirality_type = const.chirality_type_ids.get( - str(ref_atom.GetChiralTag()), unk_chirality - ) + chirality_type = const.chirality_type_ids.get(str(ref_atom.GetChiralTag()), unk_chirality) atom = ParsedAtom( name=ref_atom.GetProp("name"), element=ref_atom.GetAtomicNum(), @@ -725,9 +722,7 @@ def parse_ccd_residue( element = atom.GetAtomicNum() ref_coords = conformer.GetAtomPosition(atom.GetIdx()) ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z) - chirality_type = const.chirality_type_ids.get( - str(atom.GetChiralTag()), unk_chirality - ) + chirality_type = const.chirality_type_ids.get(str(atom.GetChiralTag()), unk_chirality) # Get PDB coordinates, if any coords = (0, 0, 0) @@ -770,8 +765,8 @@ def parse_ccd_residue( rdkit_bounds_constraints = compute_geometry_constraints(ref_mol, idx_map) chiral_atom_constraints = compute_chiral_atom_constraints(ref_mol, idx_map) stereo_bond_constraints = compute_stereo_bond_constraints(ref_mol, idx_map) - planar_bond_constraints, planar_ring_5_constraints, planar_ring_6_constraints = ( - compute_flatness_constraints(ref_mol, idx_map) + planar_bond_constraints, planar_ring_5_constraints, planar_ring_6_constraints = compute_flatness_constraints( + ref_mol, idx_map ) unk_prot_id = const.unk_token_ids["PROTEIN"] @@ -888,9 +883,7 @@ def parse_polymer( coords=coords, conformer=ref_coords, is_present=atom_is_present, - chirality=const.chirality_type_ids.get( - str(ref_atom.GetChiralTag()), unk_chirality - ), + chirality=const.chirality_type_ids.get(str(ref_atom.GetChiralTag()), unk_chirality), ) ) @@ -926,9 +919,7 @@ def parse_polymer( ) -def token_spec_to_ids( - chain_name, residue_index_or_atom_name, chain_to_idx, atom_idx_map, chains -): +def token_spec_to_ids(chain_name, residue_index_or_atom_name, chain_to_idx, atom_idx_map, chains): if chains[chain_name].type == const.chain_type_ids["NONPOLYMER"]: # Non-polymer chains are indexed by atom name _, _, atom_idx = atom_idx_map[(chain_name, 0, residue_index_or_atom_name)] @@ -1063,10 +1054,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 raise ValueError(msg) if chain_name_to_entity_type[binder] != "ligand": - msg = ( - f"Chain {binder} is not a ligand! " - "Affinity is currently only supported for ligands." - ) + msg = f"Chain {binder} is not a ligand! " "Affinity is currently only supported for ligands." raise ValueError(msg) affinity_ligands.add(binder) @@ -1199,16 +1187,21 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 ref_mol = get_mol(code, ccd, mol_dir) if affinity: - affinity_mw = AllChem.Descriptors.MolWt(ref_mol) + # Fixed to used 'rdkit.Chem.Descriptors.MolWt' as module 'rdkit.Chem.AllChem' has no attribute 'Descriptors' + affinity_mw = Descriptors.MolWt(ref_mol) # Add error and warning messaging when computing affinity with ligands too large if ref_mol.GetNumAtoms() > 128: - msg = f"The ligand for affinity is too large, ligands with more than 128 atoms are not " \ - f"supported in the affinity prediction module" + msg = ( + "The ligand for affinity is too large, ligands with more than 128 atoms are not " + "supported in the affinity prediction module" + ) raise ValueError(msg) elif ref_mol.GetNumAtoms() > 56: - print("WARNING: the ligand used for affinity calculation is larger than 56 heavy-atoms, which " - "was the maximum during training, therefore the affinity output might be inaccurate.") + print( + "WARNING: the ligand used for affinity calculation is larger than 56 heavy-atoms, which " + "was the maximum during training, therefore the affinity output might be inaccurate." + ) # Parse residue residue = parse_ccd_residue( @@ -1229,9 +1222,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 affinity_mw=affinity_mw, ) - assert not items[0][entity_type].get("cyclic", False), ( - "Cyclic flag is not supported for ligands" - ) + assert not items[0][entity_type].get("cyclic", False), "Cyclic flag is not supported for ligands" elif (entity_type == "ligand") and ("smiles" in items[0][entity_type]): seq = items[0][entity_type]["smiles"] @@ -1248,10 +1239,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 for atom, can_idx in zip(mol.GetAtoms(), canonical_order): atom_name = atom.GetSymbol().upper() + str(can_idx + 1) if len(atom_name) > 4: - msg = ( - f"{seq} has an atom with a name longer than " - f"4 characters: {atom_name}." - ) + msg = f"{seq} has an atom with a name longer than " f"4 characters: {atom_name}." raise ValueError(msg) atom.SetProp("name", atom_name) @@ -1265,13 +1253,16 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 if affinity: # Add error and warning messaging when computing affinity with ligands too large if mol_no_h.GetNumAtoms() > 128: - msg = f"The ligand for affinity is too large, ligands with more than 128 atoms are not supported in the affinity prediction module" + msg = "The ligand for affinity is too large, ligands with more than 128 atoms are not supported in the affinity prediction module" raise ValueError(msg) elif mol_no_h.GetNumAtoms() > 56: - print("WARNING: the ligand used for affinity calculation is larger than 56 heavy-atoms, " - "which was the maximum during training, therefore the affinity output might be inaccurate.") + print( + "WARNING: the ligand used for affinity calculation is larger than 56 heavy-atoms, " + "which was the maximum during training, therefore the affinity output might be inaccurate." + ) - affinity_mw = AllChem.Descriptors.MolWt(mol_no_h) if affinity else None + # Fixed to used rdkit.Chem.Descriptors.MolWt as module 'rdkit.Chem.AllChem' has no attribute 'Descriptors' + affinity_mw = Descriptors.MolWt(mol_no_h) if affinity else None extra_mols[f"LIG{ligand_id}"] = mol_no_h residue = parse_ccd_residue( name=f"LIG{ligand_id}", @@ -1290,9 +1281,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 affinity_mw=affinity_mw, ) - assert not items[0][entity_type].get("cyclic", False), ( - "Cyclic flag is not supported for ligands" - ) + assert not items[0][entity_type].get("cyclic", False), "Cyclic flag is not supported for ligands" else: msg = f"Invalid entity type: {entity_type}" @@ -1404,10 +1393,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 for constraint in res.rdkit_bounds_constraints: rdkit_bounds_constraint_data.append( # noqa: PERF401 ( - tuple( - c_atom_idx + atom_idx - for c_atom_idx in constraint.atom_idxs - ), + tuple(c_atom_idx + atom_idx for c_atom_idx in constraint.atom_idxs), constraint.is_bond, constraint.is_angle, constraint.upper_bound, @@ -1418,10 +1404,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 for constraint in res.chiral_atom_constraints: chiral_atom_constraint_data.append( # noqa: PERF401 ( - tuple( - c_atom_idx + atom_idx - for c_atom_idx in constraint.atom_idxs - ), + tuple(c_atom_idx + atom_idx for c_atom_idx in constraint.atom_idxs), constraint.is_reference, constraint.is_r, ) @@ -1430,10 +1413,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 for constraint in res.stereo_bond_constraints: stereo_bond_constraint_data.append( # noqa: PERF401 ( - tuple( - c_atom_idx + atom_idx - for c_atom_idx in constraint.atom_idxs - ), + tuple(c_atom_idx + atom_idx for c_atom_idx in constraint.atom_idxs), constraint.is_check, constraint.is_e, ) @@ -1441,32 +1421,17 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 if res.planar_bond_constraints is not None: for constraint in res.planar_bond_constraints: planar_bond_constraint_data.append( # noqa: PERF401 - ( - tuple( - c_atom_idx + atom_idx - for c_atom_idx in constraint.atom_idxs - ), - ) + (tuple(c_atom_idx + atom_idx for c_atom_idx in constraint.atom_idxs),) ) if res.planar_ring_5_constraints is not None: for constraint in res.planar_ring_5_constraints: planar_ring_5_constraint_data.append( # noqa: PERF401 - ( - tuple( - c_atom_idx + atom_idx - for c_atom_idx in constraint.atom_idxs - ), - ) + (tuple(c_atom_idx + atom_idx for c_atom_idx in constraint.atom_idxs),) ) if res.planar_ring_6_constraints is not None: for constraint in res.planar_ring_6_constraints: planar_ring_6_constraint_data.append( # noqa: PERF401 - ( - tuple( - c_atom_idx + atom_idx - for c_atom_idx in constraint.atom_idxs - ), - ) + (tuple(c_atom_idx + atom_idx for c_atom_idx in constraint.atom_idxs),) ) for bond in res.bonds: @@ -1516,7 +1481,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 for constraint in constraints: if "bond" in constraint: if "atom1" not in constraint["bond"] or "atom2" not in constraint["bond"]: - msg = f"Bond constraint was not properly specified" + msg = "Bond constraint was not properly specified" raise ValueError(msg) c1, r1, a1 = tuple(constraint["bond"]["atom1"]) @@ -1525,29 +1490,24 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 c2, r2, a2 = atom_idx_map[(c2, r2 - 1, a2)] # 1-indexed connections.append((c1, c2, r1, r2, a1, a2)) elif "pocket" in constraint: - if ( - "binder" not in constraint["pocket"] - or "contacts" not in constraint["pocket"] - ): - msg = f"Pocket constraint was not properly specified" + if "binder" not in constraint["pocket"] or "contacts" not in constraint["pocket"]: + msg = "Pocket constraint was not properly specified" raise ValueError(msg) if len(pocket_constraints) > 0 and not boltz_2: - msg = f"Only one pocket binders is supported in Boltz-1!" + msg = "Only one pocket binders is supported in Boltz-1!" raise ValueError(msg) max_distance = constraint["pocket"].get("max_distance", 6.0) if max_distance != 6.0 and not boltz_2: - msg = f"Max distance != 6.0 is not supported in Boltz-1!" + msg = "Max distance != 6.0 is not supported in Boltz-1!" raise ValueError(msg) binder = constraint["pocket"]["binder"] binder = chain_to_idx[binder] contacts = [] - for chain_name, residue_index_or_atom_name in constraint["pocket"][ - "contacts" - ]: + for chain_name, residue_index_or_atom_name in constraint["pocket"]["contacts"]: contact = token_spec_to_ids( chain_name, residue_index_or_atom_name, @@ -1560,15 +1520,12 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 force = constraint["pocket"].get("force", False) pocket_constraints.append((binder, contacts, max_distance, force)) elif "contact" in constraint: - if ( - "token1" not in constraint["contact"] - or "token2" not in constraint["contact"] - ): - msg = f"Contact constraint was not properly specified" + if "token1" not in constraint["contact"] or "token2" not in constraint["contact"]: + msg = "Contact constraint was not properly specified" raise ValueError(msg) if not boltz_2: - msg = f"Contact constraint is not supported in Boltz-1!" + msg = "Contact constraint is not supported in Boltz-1!" raise ValueError(msg) max_distance = constraint["contact"].get("max_distance", 6.0) @@ -1630,20 +1587,16 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 if template_chain_ids is not None and not isinstance(template_chain_ids, list): template_chain_ids = [template_chain_ids] - if ( - template_chain_ids is not None - and chain_ids is not None - ): - - if len(template_chain_ids) == len(chain_ids): - if len(template_chain_ids) > 0 and len(chain_ids) > 0: - matched = True - else: - msg = ( - "When providing both the chain_id and template_id, the number of" - "template_ids provided must match the number of chain_ids!" - ) - raise ValueError(msg) + if template_chain_ids is not None and chain_ids is not None: + if len(template_chain_ids) == len(chain_ids): + if len(template_chain_ids) > 0 and len(chain_ids) > 0: + matched = True + else: + msg = ( + "When providing both the chain_id and template_id, the number of" + "template_ids provided must match the number of chain_ids!" + ) + raise ValueError(msg) # Get relevant chains ids if chain_ids is None: @@ -1651,10 +1604,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 for chain_id in chain_ids: if chain_id not in protein_chains: - msg = ( - f"Chain {chain_id} assigned for template" - f"{template_id} is not one of the protein chains!" - ) + msg = f"Chain {chain_id} assigned for template" f"{template_id} is not one of the protein chains!" raise ValueError(msg) # Get relevant template chain ids @@ -1675,9 +1625,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 compute_interfaces=False, ) template_proteins = { - str(c["name"]) - for c in parsed_template.data.chains - if c["mol_type"] == const.chain_type_ids["PROTEIN"] + str(c["name"]) for c in parsed_template.data.chains if c["mol_type"] == const.chain_type_ids["PROTEIN"] } if template_chain_ids is None: template_chain_ids = list(template_proteins) @@ -1733,24 +1681,12 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 chains = np.array(chain_data, dtype=Chain) interfaces = np.array([], dtype=Interface) mask = np.ones(len(chain_data), dtype=bool) - rdkit_bounds_constraints = np.array( - rdkit_bounds_constraint_data, dtype=RDKitBoundsConstraint - ) - chiral_atom_constraints = np.array( - chiral_atom_constraint_data, dtype=ChiralAtomConstraint - ) - stereo_bond_constraints = np.array( - stereo_bond_constraint_data, dtype=StereoBondConstraint - ) - planar_bond_constraints = np.array( - planar_bond_constraint_data, dtype=PlanarBondConstraint - ) - planar_ring_5_constraints = np.array( - planar_ring_5_constraint_data, dtype=PlanarRing5Constraint - ) - planar_ring_6_constraints = np.array( - planar_ring_6_constraint_data, dtype=PlanarRing6Constraint - ) + rdkit_bounds_constraints = np.array(rdkit_bounds_constraint_data, dtype=RDKitBoundsConstraint) + chiral_atom_constraints = np.array(chiral_atom_constraint_data, dtype=ChiralAtomConstraint) + stereo_bond_constraints = np.array(stereo_bond_constraint_data, dtype=StereoBondConstraint) + planar_bond_constraints = np.array(planar_bond_constraint_data, dtype=PlanarBondConstraint) + planar_ring_5_constraints = np.array(planar_ring_5_constraint_data, dtype=PlanarRing5Constraint) + planar_ring_6_constraints = np.array(planar_ring_6_constraint_data, dtype=PlanarRing6Constraint) if boltz_2: atom_data = [(a[0], a[3], a[5], 0.0, 1.0) for a in atom_data] @@ -1803,9 +1739,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912 ) chain_infos.append(chain_info) - options = InferenceOptions( - pocket_constraints=pocket_constraints, contact_constraints=contact_constraints - ) + options = InferenceOptions(pocket_constraints=pocket_constraints, contact_constraints=contact_constraints) record = Record( id=name, structure=struct_info, diff --git a/src/boltz/data/tokenize/boltz.py b/src/boltz/data/tokenize/boltz.py index 06c126458..8dbe54782 100644 --- a/src/boltz/data/tokenize/boltz.py +++ b/src/boltz/data/tokenize/boltz.py @@ -80,6 +80,7 @@ def tokenize(self, data: Input) -> Tokenized: # Filter to valid chains only chains = struct.chains[struct.mask] + has_cyclic_period = "cyclic_period" in chains.dtype.names for chain in chains: # Get residue indices @@ -122,13 +123,12 @@ def tokenize(self, data: Input) -> Tokenized: disto_coords=d_coords, resolved_mask=is_present, disto_mask=is_disto_present, - cyclic_period=chain["cyclic_period"], + cyclic_period=chain["cyclic_period"] if has_cyclic_period else 0, ) token_data.append(token_astuple(token)) # Update atom_idx to token_idx - atom_to_token.update( - dict.fromkeys(range(atom_start, atom_end), token_idx)) + atom_to_token.update(dict.fromkeys(range(atom_start, atom_end), token_idx)) token_idx += 1 @@ -165,9 +165,7 @@ def tokenize(self, data: Input) -> Tokenized: disto_coords=atom_coords[i], resolved_mask=is_present, disto_mask=is_present, - cyclic_period=chain[ - "cyclic_period" - ], # Enforced to be False in chain parser + cyclic_period=(chain["cyclic_period"] if has_cyclic_period else 0), ) token_data.append(token_astuple(token)) @@ -180,10 +178,7 @@ def tokenize(self, data: Input) -> Tokenized: # Add atom-atom bonds from ligands for bond in struct.bonds: - if ( - bond["atom_1"] not in atom_to_token - or bond["atom_2"] not in atom_to_token - ): + if bond["atom_1"] not in atom_to_token or bond["atom_2"] not in atom_to_token: continue token_bond = ( atom_to_token[bond["atom_1"]], @@ -193,10 +188,7 @@ def tokenize(self, data: Input) -> Tokenized: # Add connection bonds (covalent) for conn in struct.connections: - if ( - conn["atom_1"] not in atom_to_token - or conn["atom_2"] not in atom_to_token - ): + if conn["atom_1"] not in atom_to_token or conn["atom_2"] not in atom_to_token: continue token_bond = ( atom_to_token[conn["atom_1"]], diff --git a/src/boltz/data/types.py b/src/boltz/data/types.py index ca330a718..767d9d8b5 100644 --- a/src/boltz/data/types.py +++ b/src/boltz/data/types.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + import json from dataclasses import asdict, dataclass from pathlib import Path @@ -250,12 +272,8 @@ def remove_invalid_chains(self) -> "Structure": # noqa: PLR0915 # Update the residue new_res = res.copy() new_res["atom_idx"] = atom_idx - new_res["atom_center"] = ( - atom_idx + new_res["atom_center"] - res["atom_idx"] - ) - new_res["atom_disto"] = ( - atom_idx + new_res["atom_disto"] - res["atom_idx"] - ) + new_res["atom_center"] = atom_idx + new_res["atom_center"] - res["atom_idx"] + new_res["atom_disto"] = atom_idx + new_res["atom_disto"] - res["atom_idx"] residues.append(new_res) res_map[res_start + j] = res_idx res_idx += 1 @@ -333,6 +351,34 @@ class StructureV2(NumpySerializable): ensemble: np.ndarray pocket: Optional[np.ndarray] = None + @classmethod + def load(cls: "StructureV2", path: Path) -> "StructureV2": + """Load a StructureV2 from an NPZ file, patching legacy data. + + Parameters + ---------- + path : Path + The path to the file. + + Returns + ------- + StructureV2 + The loaded structure. + + """ + data = dict(np.load(path, allow_pickle=True)) + # Legacy structures may lack cyclic_period; the tokenizer and + # distributed featurizer both rely on this field existing on chains. + chains = data["chains"] + if "cyclic_period" not in chains.dtype.names: + new_dtype = chains.dtype.descr + [("cyclic_period", "i4")] + new_chains = np.empty(chains.shape, dtype=new_dtype) + for name in chains.dtype.names: + new_chains[name] = chains[name] + new_chains["cyclic_period"] = 0 + data["chains"] = new_chains + return cls(**data) + def remove_invalid_chains(self) -> "StructureV2": # noqa: PLR0915 """Remove invalid chains. @@ -380,12 +426,8 @@ def remove_invalid_chains(self) -> "StructureV2": # noqa: PLR0915 # Update the residue new_res = res.copy() new_res["atom_idx"] = atom_idx - new_res["atom_center"] = ( - atom_idx + new_res["atom_center"] - res["atom_idx"] - ) - new_res["atom_disto"] = ( - atom_idx + new_res["atom_disto"] - res["atom_idx"] - ) + new_res["atom_center"] = atom_idx + new_res["atom_center"] - res["atom_idx"] + new_res["atom_disto"] = atom_idx + new_res["atom_disto"] - res["atom_idx"] residues.append(new_res) res_map[res_start + j] = res_idx res_idx += 1 @@ -522,12 +564,8 @@ class InterfaceInfo: class InferenceOptions: """InferenceOptions datatype.""" - pocket_constraints: Optional[ - list[tuple[int, list[tuple[int, int]], float, bool]] - ] = None - contact_constraints: Optional[ - list[tuple[tuple[int, int], tuple[int, int], float, bool]] - ] = None + pocket_constraints: Optional[list[tuple[int, list[tuple[int, int]], float, bool]]] = None + contact_constraints: Optional[list[tuple[tuple[int, int], tuple[int, int], float, bool]]] = None @dataclass(frozen=True) diff --git a/src/boltz/data/write/writer.py b/src/boltz/data/write/writer.py index 984be2ae5..d24319cbe 100644 --- a/src/boltz/data/write/writer.py +++ b/src/boltz/data/write/writer.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import json from dataclasses import asdict, replace from pathlib import Path @@ -66,20 +89,24 @@ def write_on_batch_end( # Get the predictions coords = prediction["coords"] - coords = coords.unsqueeze(0) - pad_masks = prediction["masks"] - # Get ranking - if "confidence_score" in prediction: - argsort = torch.argsort(prediction["confidence_score"], descending=True) - idx_to_rank = {idx.item(): rank for rank, idx in enumerate(argsort)} - # Handles cases where confidence summary is False - else: - idx_to_rank = {i: i for i in range(len(records))} + # Get ranking (will be calculated per record) + confidence_scores = prediction.get("confidence_score", None) + + # Calculate diffusion samples per record + n_records = len(records) + n_diffusion_samples = coords.shape[0] // n_records # Iterate over the records - for record, coord, pad_mask in zip(records, coords, pad_masks): + for record_idx, record in enumerate(records): + pad_mask = pad_masks[record_idx] + + # Get the diffusion samples for this record + start_idx = record_idx * n_diffusion_samples + end_idx = start_idx + n_diffusion_samples + record_coords = coords[start_idx:end_idx] + # Load the structure path = self.data_dir / f"{record.id}.npz" if self.boltz2: @@ -96,9 +123,17 @@ def write_on_batch_end( # Remove masked chains completely structure = structure.remove_invalid_chains() - for model_idx in range(coord.shape[0]): + # Calculate ranking for this record's diffusion samples + if confidence_scores is not None: + record_confidence_scores = confidence_scores[start_idx:end_idx] + argsort = torch.argsort(record_confidence_scores, descending=True) + idx_to_rank = {idx.item(): rank for rank, idx in enumerate(argsort)} + else: + idx_to_rank = {i: i for i in range(record_coords.shape[0])} + + for model_idx in range(record_coords.shape[0]): # Get model coord - model_coord = coord[model_idx] + model_coord = record_coords[model_idx] # Unpad coord_unpad = model_coord[pad_mask.bool()] coord_unpad = coord_unpad.cpu().numpy() @@ -245,7 +280,7 @@ def write_on_batch_end( / f"pde_{record.id}_model_{idx_to_rank[model_idx]}.npz" ) np.savez_compressed(path, pde=pde.cpu().numpy()) - + # Save embeddings if self.write_embeddings and "s" in prediction and "z" in prediction: s = prediction["s"].cpu().numpy() diff --git a/src/boltz/distributed/README.md b/src/boltz/distributed/README.md new file mode 100644 index 000000000..6f99fada1 --- /dev/null +++ b/src/boltz/distributed/README.md @@ -0,0 +1,92 @@ +# Fold-CP: A Context Parallelism Framework for Biomolecular Modeling + +Context parallelism (CP) for distributed inference and training for +biomolecular folding models across multiple GPUs using a 2D CP mesh combined +with data parallelism, demonstrated with the Boltz model. + +⚠️ **Note**
+This repository demonstrates a proof-of-concept implementation of Fold-CP with Boltz-2.
+Learn more about Fold-CP here: https://research.nvidia.com/labs/dbr/assets/data/manuscripts/fold_cp.pdf + +For an introduction to the Boltz family of biomolecular interaction models, +see the [public Boltz repository](https://github.com/jwohlwend/boltz). + +## Copyright and License Compliance + +- The context parallel code is licensed under the terms and conditions as written in [the license file](../../../licenses/LICENSE) + +- The original Boltz code is licensed under their respective MIT license (See the [third-party-attr.txt](../../../licenses/third-party-attr.txt)) + +- This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use + +## Key Capabilities + +- **Distributed inference** with DTensor context parallelism +- **Distributed training** with DTensor context parallelism +- Combined data parallelism (DP) and context parallelism (CP) +- Multiple attention kernel backends: cuEquivariance, trifast, FlexAttention +- Support for BF16, BF16-mixed, TF32, and FP32 precision modes + +## Requirements + +- Python 3.10+ +- PyTorch 2.9+ with CUDA support +- Multiple NVIDIA GPUs (CP requires at least 4 GPUs; CP size must be a + perfect square) +- `torchrun` or SLURM `srun` for multi-process launching + +## Distributed Inference + +Distributed inference uses `src/boltz/distributed/main.py predict` to run +structure prediction with DTensor context parallelism. + +```bash +torchrun \ + --nnodes 1 \ + --nproc_per_node 4 \ + src/boltz/distributed/main.py predict \ + /path/to/preprocessed_data \ + --out_dir ./predictions \ + --size_dp 1 \ + --size_cp 4 \ + --recycling_steps 3 \ + --sampling_steps 200 \ + --diffusion_samples 5 +``` + +For full documentation of all options, the inference pipeline stages, and +differences from serial prediction, see the +[Distributed Inference Guide](docs/boltz2_cp_prediction.md). + +## Distributed Training + +Distributed training uses `src/boltz/distributed/train.py` with a YAML +config file to run training with DTensor context parallelism. + +```bash +torchrun \ + --nnodes 1 \ + --nproc_per_node 8 \ + src/boltz/distributed/train.py \ + scripts/train/configs/structurev2_small_cp.yaml \ + parallel_size.size_dp=2 \ + parallel_size.size_cp=4 \ + output= +``` + +For full documentation of the configuration hierarchy, CP-specific settings, +CLI overrides, and differences from serial training, see the +[Distributed Training Guide](docs/boltz2_cp_training.md). + +## Contributing + +This project is currently not accepting contributions. + +## Security + +See [SECURITY.md](SECURITY.md) for vulnerability reporting instructions. + +## License + +This project is licensed under the MIT License. See [LICENSE](LICENSE) for +details. diff --git a/src/boltz/distributed/__init__.py b/src/boltz/distributed/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/comm.py b/src/boltz/distributed/comm.py new file mode 100644 index 000000000..05a902825 --- /dev/null +++ b/src/boltz/distributed/comm.py @@ -0,0 +1,782 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from copy import deepcopy +from typing import Optional + +import torch +import torch.distributed as dist + +from boltz.distributed.utils import LayoutMap, get_group_rank_from_axial_shift + + +class One2OneComm: + def __init__(self, group: dist.ProcessGroup, rank_send_to: int, rank_recv_from: int, parity: Optional[bool] = None): + """ + Initializes a One2OneComm instance for point-to-point communication. + + Arguments: + group (dist.ProcessGroup): The process group that provides the communication. + rank_send_to (int): The rank within the group to send data to. + rank_recv_from (int): The rank within the group to receive data from. + parity (bool): If True, issue [isend, irecv]; otherwise issue [irecv, isend] + in batch_isend_irecv. If None, parity is `rank % 2`, where `rank` is the + calling rank's index in the WORLD group. If self.is_self_comm is True, i.e., + `rank_send_to == rank and rank_recv_from == rank`, this argument has no effect. + The motivation of setting the parity is to avoid potential deadlocks in NCCL backend + when doing batch_isend_irecv. + + Note: rank_send_to and rank_recv_from must be ranks within the input process group. + + Raises: + ValueError: If rank_send_to or rank_recv_from is not a valid rank within the group. + """ + self.group = group + + self.rank = dist.get_rank(self.group) + self.world_size = dist.get_world_size(self.group) + + if rank_send_to >= self.world_size: + raise ValueError(f"rank_send_to >= world_size {self.world_size}") + if rank_recv_from >= self.world_size: + raise ValueError(f"rank_recv_from >= world_size {self.world_size}") + # make all comm functions no-ops if self-send and self-recv + is_self_send = rank_send_to == self.rank + is_self_recv = rank_recv_from == self.rank + if is_self_send != is_self_recv: + raise ValueError( + "Asymmetric send/recv tends to cause NCCL backend deadlocking " + f"and it's not supported: is_self_send: {is_self_send}, " + f"is_self_recv: {is_self_recv}" + ) + self.is_self_comm = is_self_send + self._rank_in_group_send_to = rank_send_to + self._rank_in_group_recv_from = rank_recv_from + + self.parity = parity + + if not self.is_self_comm: + # convert to global rank + self.rank_send_to = dist.get_global_rank(self.group, rank_send_to) + self.rank_recv_from = dist.get_global_rank(self.group, rank_recv_from) + + if self.parity is None: + self.parity = self.rank % 2 + + self._queue_send_recv = [] + self._work_to_finish = None + + def __deepcopy__(self, memo): + """ + Create a deep copy of the One2OneComm instance. + + This method enables the One2OneComm object to be deep copied using the copy.deepcopy() function. + It creates a new One2OneComm instance with the same communication parameters as the original. + + Args: + memo (dict): Dictionary used by deepcopy to avoid circular references. + + Returns: + One2OneComm: A new One2OneComm instance with identical configuration to the original. + """ + return One2OneComm(self.group, self._rank_in_group_send_to, self._rank_in_group_recv_from, self.parity) + + def _prep_batch_isend_irecv( + self, + to_send: torch.Tensor, + to_recv: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Prepare tensors and communication operations for batch send/receive. + + This private method sets up the tensor operations and queues the communication + operations for later dispatch. It handles both self-communication (where send + and receive are on the same rank) and inter-rank communication. + + Args: + to_send (torch.Tensor): The tensor to be sent to the target rank. + to_recv (Optional[torch.Tensor], optional): The tensor buffer to receive data into. + If None, a new tensor with the same shape and properties as `to_send` will be created. + + Returns: + torch.Tensor: The tensor that will contain the received data. For self-communication, + this is either a clone of `to_send` or `to_recv` with data copied from `to_send`. + For inter-rank communication, this is the buffer where received data will be stored. + + Note: + - For self-communication (`is_self_comm=True`), the data is immediately copied + and no communication operations are queued. + - For inter-rank communication, P2P operations are queued based on parity to + avoid potential deadlocks in NCCL backend. + - The order of send/receive operations depends on the parity flag to ensure + consistent ordering across ranks. + """ + if self.is_self_comm: + # the copy semantics remain even if self.is_self_comm + if to_recv is None: + ans = to_send.detach().clone() + else: + ans = to_recv + ans.copy_(to_send) + return ans + + ans = torch.empty_like(to_send) if to_recv is None else to_recv + + if self.parity: + # TODO: verify if the order of P2POp calls matter + # and consolidate the two branches' P2POp calls if not + send_op = dist.P2POp( + dist.isend, + to_send, + self.rank_send_to, + group=self.group, + ) + recv_op = dist.P2POp( + dist.irecv, + ans, + self.rank_recv_from, + group=self.group, + ) + self._queue_send_recv.append(send_op) + self._queue_send_recv.append(recv_op) + else: + recv_op = dist.P2POp( + dist.irecv, + ans, + self.rank_recv_from, + group=self.group, + ) + send_op = dist.P2POp( + dist.isend, + to_send, + self.rank_send_to, + group=self.group, + ) + self._queue_send_recv.append(recv_op) + self._queue_send_recv.append(send_op) + return ans + + def _dispatch(self): + """ + Dispatch all queued communication operations. + + This private method initiates all point-to-point communication operations that have been + queued by previous calls to `_prep_batch_isend_irecv`. The operations are dispatched + asynchronously using `dist.batch_isend_irecv`. + + Raises: + RuntimeError: If there are already unfinished communications in the queue when trying + to dispatch new operations. This prevents overlapping communication operations + which could lead to undefined behavior. + + Note: + - For self-communication (`is_self_comm=True`), this method does nothing as no + actual network communication is required. + - After dispatching, the work handles are stored in `_work_to_finish` for later + synchronization via `wait_until_finished()`. + - This method should only be called after communication operations have been + queued via `_prep_batch_isend_irecv()`. + """ + if self.is_self_comm: + return + if self._work_to_finish is not None: + raise RuntimeError("There is unfinished communication in queue. Cannot dispatch new communication") + self._work_to_finish = dist.batch_isend_irecv(self._queue_send_recv) + + def wait_until_finished(self): + """ + Wait for all dispatched communication operations to complete. + + This method blocks until all previously dispatched communication operations have + finished. It ensures data consistency by synchronizing all pending send/receive + operations before proceeding. + + Raises: + RuntimeError: If called when there are no unfinished communications in the queue. + This typically happens when `wait_until_finished()` is called without a + preceding `_dispatch()` call. + + Note: + - For self-communication (`is_self_comm=True`), this method returns immediately + as no actual network communication needs to be synchronized. + - After completion, the internal communication queue and work handles are reset, + allowing new communication operations to be queued. + - This method must be called after `_dispatch()` to ensure communication + operations have completed before accessing the received data. + """ + if self.is_self_comm: + return + if self._work_to_finish is None: + raise RuntimeError("Cannot wait without unfinished communication in queue") + for work in self._work_to_finish: + work.wait() + self._work_to_finish = None + self._queue_send_recv = [] + + def enqueue_to_dispatch( + self, + to_send: torch.Tensor, + to_recv: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Enqueue a communication operation and immediately dispatch it. + + This method combines the functionality of `_prep_batch_isend_irecv()` and `_dispatch()` + in a single call. It prepares the tensors for communication, queues the operations, + and immediately dispatches them for execution. + + Args: + to_send (torch.Tensor): The tensor to be sent to the target rank. + to_recv (Optional[torch.Tensor], optional): The tensor buffer to receive data into. + If None, a new tensor with the same shape and properties as `to_send` will be created. + + Returns: + torch.Tensor: The tensor that will contain the received data. For self-communication, + this contains the copied data immediately. For inter-rank communication, this is + the buffer where data will be received once the communication completes. + + Note: + - For self-communication (`is_self_comm=True`), the data is immediately available + in the returned tensor. + - For inter-rank communication, you must call `wait_until_finished()` before + accessing the data in the returned tensor to ensure the communication has completed. + - This is a convenience method that internally calls `_prep_batch_isend_irecv()` + followed by `_dispatch()`. + + Example: + ```python + comm = One2OneComm(group, send_rank, recv_rank) + recv_tensor = comm.enqueue_to_dispatch(send_tensor) + comm.wait_until_finished() # Wait for completion before using recv_tensor + ``` + """ + recv = self._prep_batch_isend_irecv(to_send, to_recv) + if self.is_self_comm: + return recv + self._dispatch() + return recv + + +class TransposeComm(One2OneComm): + def __init__(self, process_group: dist.ProcessGroup, group_layout: LayoutMap): + if group_layout.shape is None: + raise ValueError("group_layout must have a shape") + + self.world_size = dist.get_world_size(process_group) + if self.world_size != group_layout.numel: + raise ValueError("Inconsistent world_size with the num elements of group_layout") + + if len(group_layout.shape) != 2: + raise ValueError(f"{self.__class__} only supports 2D group layout") + + if group_layout.shape[0] != group_layout.shape[1]: + raise ValueError(f"group_layout.shape {group_layout.shape} is not square") + + self.group_layout = group_layout + + self.global_rank = dist.get_rank() + self.group_rank = dist.get_rank(process_group) + self.rank_coords: tuple[int, int] = self.group_layout.unravel(self.group_rank) + + transpose_group_rank = self.group_layout(self.rank_coords[::-1]) + self.transpose_rank = dist.get_global_rank(process_group, transpose_group_rank) + + self.parity_transpose = self.rank_coords[0] < self.rank_coords[1] + + # Call One2OneComm's __init__ instead of creating a separate comm instance + super().__init__(process_group, transpose_group_rank, transpose_group_rank, parity=self.parity_transpose) + + def __deepcopy__(self, memo): + return TransposeComm(self.group, self.group_layout) + + +def ternary_parity(my_rank: int, send_rank: int, recv_rank: int) -> bool: + """ + Determines parity for communication ordering based on rank relationships. + + Used to establish consistent communication ordering between three ranks to avoid deadlocks. + Returns True if the current rank is less than both the send and receive ranks. + + Args: + my_rank: Current process rank + send_rank: Rank to send data to + recv_rank: Rank to receive data from + + Returns: + bool: True if current rank is less than both send and receive ranks, False otherwise + """ + return my_rank < min(send_rank, recv_rank) + + +class Ring2DComm: + """ + Implements communication primitives for distributed operations on a 2D grid of devices. + + This class provides general-purpose ring communication patterns for operations like + TriangleMultiplication and OuterProductMean across a 2D grid of devices. Unlike + Ring2DCommTriAttn which is specialized for triangular attention, this class provides + more general ring communication patterns. + + The communication patterns implemented include: + 1. Transpose communication for matrix operations + 2. Row-wise ring communication (left shifts) + 3. Column-wise ring communication (up shifts) + + Parameters + ---------- + group_2d : dist.ProcessGroup + The process group representing the 2D grid of devices. This should include + all processes in the distributed computation. + group_col : dist.ProcessGroup + A subprocess group that provides communication between ranks in the same column. + group_layout : LayoutMap + A mapping from the 2D grid index to the flattened index of the devices on the 2D grid. + Must represent a square grid (same dimensions in both axes). + + Notes + ----- + The class implements various communication patterns needed for distributed matrix + operations, including initial communication (with different shift patterns based on + coordinates) and subsequent iterations (with fixed shifts). + + Communication ordering is carefully managed to prevent deadlocks by using + ternary_parity to determine consistent send/receive ordering across different ranks. + """ + + def __init__( + self, + group_2d: dist.ProcessGroup, + group_col: dist.ProcessGroup, + group_layout: LayoutMap, + ): + """ + Ring comm over a 2d grid of devices with comm happening along both axes + Arguments: + group_2d: Group torch process group that provides communication + across the full cross-device + group_col: Subprocess group that provides communication + between ranks in the same column + group_layout: mapping from the 2d grid index to the flatten index + of the devices on the 2d grid + """ + # TODO: consolidate the ring 2d comm groups with other modules e,g. triangle attn + self.group_2d = group_2d + self.group_col = group_col + self.group_layout = group_layout + ranks_group_2d = set(dist.get_process_group_ranks(self.group_2d)) + ranks_group_col = set(dist.get_process_group_ranks(self.group_col)) + + if not ranks_group_col.issubset(ranks_group_2d): + raise ValueError("The col ranks are not a subset of ranks_group_2d") + + self.size_2d = dist.get_world_size(self.group_2d) + + if self.size_2d != self.group_layout.numel: + raise ValueError( + f"size of group_2d {self.size_2d} differs from the number of elements in group_layout {self.group_layout.numel}" + ) + + if self.group_layout.shape[0] != self.group_layout.shape[1]: + raise ValueError(f"group_layout.shape {self.group_layout.shape} is not square") + + self.rank_2d = dist.get_rank(self.group_2d) + self.coord_2d = self.group_layout.unravel(self.rank_2d) + + # all the send/recv ranks must be global in order to use isend/irecv + # only need transpose at the beginning of the batched GEMM for b or a + self.comm_2d_trans = TransposeComm(self.group_2d, self.group_layout) + + # always do left shift per row + # for iteration 0, i'th row left shift by i column + self.send_rank_row_init = get_group_rank_from_axial_shift( + self.coord_2d, 1, -self.coord_2d[0], self.group_layout + ) + self.recv_rank_row_init = get_group_rank_from_axial_shift(self.coord_2d, 1, self.coord_2d[0], self.group_layout) + + self.comm_row_init = One2OneComm( + self.group_2d, + self.send_rank_row_init, + self.recv_rank_row_init, + parity=ternary_parity(self.rank_2d, self.send_rank_row_init, self.recv_rank_row_init), + ) + # for other iterations left shift by 1 col + self.send_rank_row = get_group_rank_from_axial_shift(self.coord_2d, 1, -1, self.group_layout) + self.recv_rank_row = get_group_rank_from_axial_shift(self.coord_2d, 1, 1, self.group_layout) + + self.comm_row = One2OneComm( + self.group_2d, + self.send_rank_row, + self.recv_rank_row, + parity=ternary_parity(self.rank_2d, self.send_rank_row, self.recv_rank_row), + ) + + # always do up shift per col + # for iteration 0, j'th col up shift by j row + self.send_rank_col_init = get_group_rank_from_axial_shift( + self.coord_2d, 0, -self.coord_2d[1], self.group_layout + ) + self.recv_rank_col_init = get_group_rank_from_axial_shift(self.coord_2d, 0, self.coord_2d[1], self.group_layout) + self.comm_col_init = One2OneComm( + self.group_2d, + self.send_rank_col_init, + self.recv_rank_col_init, + parity=ternary_parity(self.rank_2d, self.send_rank_col_init, self.recv_rank_col_init), + ) + # for other iterations, up shift by 1 row + self.send_rank_col = get_group_rank_from_axial_shift(self.coord_2d, 0, -1, self.group_layout) + self.recv_rank_col = get_group_rank_from_axial_shift(self.coord_2d, 0, 1, self.group_layout) + self.comm_col = One2OneComm( + self.group_2d, + self.send_rank_col, + self.recv_rank_col, + parity=ternary_parity(self.rank_2d, self.send_rank_col, self.recv_rank_col), + ) + + # fused communication for transposition and initial row/col shift in backward + coords_transpose = self.coord_2d[::-1] + self.send_rank_transpose_row_init = get_group_rank_from_axial_shift( + coords_transpose, 1, -coords_transpose[0], self.group_layout + ) # shifting the transposed rank + recv_rank_transpose_row_init = get_group_rank_from_axial_shift( + self.coord_2d, 1, self.coord_2d[0], self.group_layout + ) # counter-shifting + self.recv_rank_transpose_row_init = self.group_layout( + self.group_layout.unravel(recv_rank_transpose_row_init)[::-1] + ) # counter-transposition + self.comm_transpose_row_init = One2OneComm( + self.group_2d, + self.send_rank_transpose_row_init, + self.recv_rank_transpose_row_init, + parity=ternary_parity(self.rank_2d, self.send_rank_transpose_row_init, self.recv_rank_transpose_row_init), + ) + + self.send_rank_transpose_col_init = get_group_rank_from_axial_shift( + coords_transpose, 0, -coords_transpose[1], self.group_layout + ) # shifting the transposed rank + recv_rank_transpose_col_init = get_group_rank_from_axial_shift( + self.coord_2d, 0, self.coord_2d[1], self.group_layout + ) # counter-shifting + self.recv_rank_transpose_col_init = self.group_layout( + self.group_layout.unravel(recv_rank_transpose_col_init)[::-1] + ) # counter-transposition + self.comm_transpose_col_init = One2OneComm( + self.group_2d, + self.send_rank_transpose_col_init, + self.recv_rank_transpose_col_init, + parity=ternary_parity(self.rank_2d, self.send_rank_transpose_col_init, self.recv_rank_transpose_col_init), + ) + + +class AttentionPairBiasComm: + def __init__( + self, + process_group: dist.ProcessGroup, + group_layout: LayoutMap, + cp_axis_0_group: dist.ProcessGroup, + cp_axis_1_group: dist.ProcessGroup, + ): + self.process_group = process_group + self.cp_axis_0_group = cp_axis_0_group + self.cp_axis_1_group = cp_axis_1_group + + if group_layout.shape is None: + raise ValueError("group_layout must have a shape") + + self.world_size = dist.get_world_size(self.process_group) + if self.world_size != group_layout.numel: + raise ValueError("Inconsistent world_size with the num elements of group_layout") + + if len(group_layout.shape) != 2: + raise ValueError(f"{self.__class__} only supports 2D group layout") + + if group_layout.shape[0] != group_layout.shape[1]: + raise ValueError(f"group_layout.shape {group_layout.shape} is not square") + + self.group_layout = group_layout + + self.global_rank = dist.get_rank() + self.group_rank = dist.get_rank(self.process_group) + self.rank_coords: tuple[int, int] = self.group_layout.unravel(self.group_rank) + + self.comm_transpose_k = TransposeComm(self.process_group, self.group_layout) # also used for backward + self.comm_transpose_v = TransposeComm(self.process_group, self.group_layout) # also used for backward + self.comm_transpose_mask = TransposeComm(self.process_group, self.group_layout) + + # for k, v and z comm + self.send_rank_kvz = get_group_rank_from_axial_shift(self.rank_coords, 1, 1, self.group_layout) + self.recv_rank_kvz = get_group_rank_from_axial_shift(self.rank_coords, 1, -1, self.group_layout) + self.parity = self.rank_coords[1] % 2 == 1 + self.comm_k = One2OneComm(self.process_group, self.send_rank_kvz, self.recv_rank_kvz, parity=self.parity) + self.comm_v = One2OneComm(self.process_group, self.send_rank_kvz, self.recv_rank_kvz, parity=self.parity) + self.comm_z = One2OneComm(self.process_group, self.send_rank_kvz, self.recv_rank_kvz, parity=self.parity) + + def __deepcopy__(self, memo): + return AttentionPairBiasComm( + self.process_group, + self.group_layout, + self.cp_axis_0_group, + self.cp_axis_1_group, + ) + + +class Ring2DCommTriAttn: + """ + Implements communication primitives for triangular attention in a 2D device grid. + + This class handles the specialized communication patterns required for triangular attention + operations across a 2D grid of devices, with particular focus on avoiding cross-rail traffic + and NCCL deadlocks. It's used in both TriangleAttentionStartingNode (axis_cp=1) and + TriangleAttentionEndingNode (axis_cp=0) implementations. + + The communication is designed to efficiently handle: + 1. Triangle bias reshuffling (in two stages to avoid cross-rail traffic) + 2. Key/value pair initial shuffling + 3. Iterative ring-based communication during attention computation + + Parameters + ---------- + group_2d : dist.ProcessGroup + The process group representing the 2D grid of devices across which the triangular + attention is distributed. + group_layout : LayoutMap + A mapping from 2D grid indices to flattened rank indices in the process group. + Must represent a square grid (same dimensions in both axes). + axis_cp : int + Specifies the axis for the context parallelism (CP): + - 0: For TriangleAttentionEndingNode (operating on columns) + - 1: For TriangleAttentionStartingNode (operating on rows) + + Notes + ----- + The triangle attention requires special data distribution patterns where triangle bias + is reorganized in a two-stage process: + - First stage flattens diagonals onto rows or columns + - Second stage rotates elements to meet ring attention requirements + + This implementation carefully manages communication ordering to prevent deadlocks by using + parity flags that ensure consistent send/receive ordering across different ranks. + """ + + def __init__( + self, + group_2d: dist.ProcessGroup, + group_layout: LayoutMap, + axis_cp: int, + ): + # The triangle bias requires 2d grid group while q/k/v communicates in a ring + # To prevent NCCL from hanging due to the assymetric isend/irecv, these two + # groups with different topo need to be separated so that the associated ops + # are launched into different cuda stream + self.group_2d = group_2d + self.group_layout = group_layout + self.axis_cp = axis_cp + + if self.axis_cp not in (0, 1): + raise NotImplementedError("axis_cp is not 0 or 1") + + self.size_2d = dist.get_world_size(self.group_2d) + + if self.group_layout.numel != self.size_2d: + raise ValueError( + f"Inconsistent group_layout.numel {self.group_layout.numel} with size of group_2d {self.size_2d}" + ) + + if self.group_layout.shape[0] != self.group_layout.shape[1]: + raise ValueError(f"group_layout.shape {self.group_layout.shape} is not square") + + self.rank_2d = dist.get_rank(self.group_2d) + self.coord_2d = self.group_layout.unravel(self.rank_2d) + + # comm handle for the initial shuffling of triangle bias + # initially, device[i, j] owns bias[i, j]. We reorganize + # the data by: + # if self.axis_cp == 1: + # device[i, j] sends its initial bias to device[(i - j) % self.size_row, i] + # device[i, j] receives new bias from device[j, (j - i) % self.size_row] + # if self.axis_cp == 0: + # device[i, j] sends its initial bias to device[j, (j - i) % self.size_col] + # device[i, j] receives new bias from device[(i - j) % self.size_col, i] + # To avoid cross-rail traffic on certain Slurm clusters, e.g., OCI-based ones, + # we need to broken down the communication into the following 2 stages, which + # together is equivalent to the aforementioned communication: + # 1. flatten the k'th lower (self.axis_cp == 1) or upper (self.axis_cp == 0) diagonals + # onto the k'th row (self.axis_cp == 1) or column (self.axis_cp == 0) + # if self.axis_cp == 1: + # j col up-shift by j place + # if self.axis_cp == 0: + # i row left-shift by i place + # + # Example of stage 1 on a 3x3 device grid: + # Original Data Ownership After Stage 1 (for axis_cp=1) + # ┌───┬───┬───┐ ┌───┬───┬───┐ + # │0,0│0,1│0,2│ │0,0│1,1│2,2│ original lower diagonal 0 + # ├───┼───┼───┤ ├───┼───┼───┤ + # │1,0│1,1│1,2│ → │1,0│2,1│0,2│ original lower diagonal 1 + # ├───┼───┼───┤ ├───┼───┼───┤ + # │2,0│2,1│2,2│ │2,0│0,1│1,2│ original lower diagonal 2 + # └───┴───┴───┘ └───┴───┴───┘ + # + # Original Data Ownership After Stage 1 (for axis_cp=0) + # ┌───┬───┬───┐ ┌───┬───┬───┐ + # │0,0│0,1│0,2│ │0,0│0,1│0,2│ (col 0: original upper diagonal 0) + # ├───┼───┼───┤ ├───┼───┼───┤ + # │1,0│1,1│1,2│ → │1,1│1,2│1,0│ (col 1: original upper diagonal 1) + # ├───┼───┼───┤ ├───┼───┼───┤ + # │2,0│2,1│2,2│ │2,2│2,0│2,1│ (col 2: original upper diagonal 2) + # └───┴───┴───┘ └───┴───┴───┘ + # 2. rotate elements in each row or col to meet the ring attention initialization requirement: + # if self.axis_cp == 1: + # i row right-shift by i place + # if self.axis_cp == 0: + # j col down-shift by j place + # + # After Stage 2 (for axis_cp=1) + # ┌───┬───┬───┐ + # │0,0│1,1│2,2│ original lower diagonal 0 + # ├───┼───┼───┤ + # │0,2│1,0│2,1│ original lower diagonal 1 + # ├───┼───┼───┤ + # │0,1│1,2│2,0│ original lower diagonal 2 + # └───┴───┴───┘ + # + # After Stage 2 (for axis_cp=0) + # ┌───┬───┬───┐ + # │0,0│2,0│1,0│ (col 0: original upper diagonal 0) + # ├───┼───┼───┤ + # │1,1│0,1│2,1│ (col 1: original upper diagonal 1) + # ├───┼───┼───┤ + # │2,2│1,2│0,2│ (col 2: original upper diagonal 2) + # └───┴───┴───┘ + self.axis_shift_bias = self.axis_cp ^ 1 + + # stage 1: flatten the k'th lower (self.axis_cp == 1) or upper (self.axis_cp == 0) diagonals + # i.e., j col up-shift by j place (self.axis_cp == 1) or i row left-shift by i place (self.axis_cp == 0) + self.rank_send_bias_init0 = get_group_rank_from_axial_shift( + self.coord_2d, self.axis_shift_bias, -self.coord_2d[self.axis_cp], self.group_layout + ) + self.rank_recv_bias_init0 = get_group_rank_from_axial_shift( + self.coord_2d, self.axis_shift_bias, self.coord_2d[self.axis_cp], self.group_layout + ) + self.comm_bias_init0 = One2OneComm( + self.group_2d, + self.rank_send_bias_init0, + self.rank_recv_bias_init0, + parity=self.coord_2d[self.axis_shift_bias] % 2 == 1, + ) + + # stage 2: rotate elements in each row or col to meet the ring attention initialization requirement + # i.e., i row right-shift by i place (self.axis_cp == 1) or j col down-shift by j place (self.axis_cp == 0) + self.rank_send_bias_init1 = get_group_rank_from_axial_shift( + self.coord_2d, self.axis_cp, self.coord_2d[self.axis_shift_bias], self.group_layout + ) + self.rank_recv_bias_init1 = get_group_rank_from_axial_shift( + self.coord_2d, self.axis_cp, -self.coord_2d[self.axis_shift_bias], self.group_layout + ) + self.comm_bias_init1 = One2OneComm( + self.group_2d, + self.rank_send_bias_init1, + self.rank_recv_bias_init1, + parity=self.coord_2d[self.axis_cp] % 2 == 1, + ) + + # every subsequent iteration, the triangle bias is up- or left-shift by 1 + self.rank_send_bias = get_group_rank_from_axial_shift( + self.coord_2d, self.axis_shift_bias, -1, self.group_layout + ) + self.rank_recv_bias = get_group_rank_from_axial_shift(self.coord_2d, self.axis_shift_bias, 1, self.group_layout) + + self.comm_bias = One2OneComm( + self.group_2d, + self.rank_send_bias, + self.rank_recv_bias, + parity=self.coord_2d[self.axis_shift_bias] % 2 == 1, + ) + + # comm handle for the initial shuffling of k/v pairs + # to offset the computation along the attention matrix's diagonal + # along axis_cp, i'th group right-/down- shift by i places + self.rank_send_kv_init = get_group_rank_from_axial_shift( + self.coord_2d, self.axis_cp, self.coord_2d[self.axis_cp ^ 1], self.group_layout + ) + self.rank_recv_kv_init = get_group_rank_from_axial_shift( + self.coord_2d, self.axis_cp, -self.coord_2d[self.axis_cp ^ 1], self.group_layout + ) + + parity_kv_init = ternary_parity(self.rank_2d, self.rank_send_kv_init, self.rank_recv_kv_init) + self.comm_k_init = One2OneComm( + self.group_2d, + self.rank_send_kv_init, + self.rank_recv_kv_init, + parity=parity_kv_init, + ) + self.comm_v_init = One2OneComm( + self.group_2d, + self.rank_send_kv_init, + self.rank_recv_kv_init, + parity=parity_kv_init, + ) + # the padding mask is along the K/V axis of the attn matrix + self.comm_mask_init = One2OneComm( + self.group_2d, + self.rank_send_kv_init, + self.rank_recv_kv_init, + parity=parity_kv_init, + ) + + # At every iteration, i'th group right-/down-shift by 1 + self.rank_send_kv = get_group_rank_from_axial_shift(self.coord_2d, self.axis_cp, 1, self.group_layout) + self.rank_recv_kv = get_group_rank_from_axial_shift(self.coord_2d, self.axis_cp, -1, self.group_layout) + + parity_kv = self.coord_2d[self.axis_cp] % 2 == 1 + + self.comm_k = One2OneComm(self.group_2d, self.rank_send_kv, self.rank_recv_kv, parity=parity_kv) + self.comm_v = One2OneComm(self.group_2d, self.rank_send_kv, self.rank_recv_kv, parity=parity_kv) + self.comm_mask = One2OneComm(self.group_2d, self.rank_send_kv, self.rank_recv_kv, parity=parity_kv) + + # comm handles for backward pass + self.comm_dk = deepcopy(self.comm_k) + self.comm_dv = deepcopy(self.comm_v) + self.comm_dbias = deepcopy(self.comm_bias) + + # these are used at the final stage of the backward pass to + # revert the data ownership of k, v and triangle bias to the initial state + # so the send/recv ranks are reversed of the comm_*_init + self.comm_dk_final = One2OneComm( + self.group_2d, + self.comm_k_init._rank_in_group_recv_from, + self.comm_k_init._rank_in_group_send_to, + parity=self.comm_k_init.parity ^ 1, + ) + + # for triangle bias, reverse the stage and send/recv ranks + self.comm_dbias_final0 = One2OneComm( + self.group_2d, + self.comm_bias_init1._rank_in_group_recv_from, + self.comm_bias_init1._rank_in_group_send_to, + parity=self.comm_bias_init1.parity ^ 1, + ) + self.comm_dbias_final1 = One2OneComm( + self.group_2d, + self.comm_bias_init0._rank_in_group_recv_from, + self.comm_bias_init0._rank_in_group_send_to, + parity=self.comm_bias_init0.parity ^ 1, + ) diff --git a/src/boltz/distributed/data/feature/featurizer.py b/src/boltz/distributed/data/feature/featurizer.py new file mode 100644 index 000000000..781c74375 --- /dev/null +++ b/src/boltz/distributed/data/feature/featurizer.py @@ -0,0 +1,627 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import warnings +from collections import OrderedDict + +import torch +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Placement, Replicate, Shard + +from boltz.data.pad import pad_dim +from boltz.distributed.data.feature.featurizer_utils import ( + _ATOM_IGNORE_FIELDS_FOR_CONTEXT_PARALLEL_PREDICT, + remap_atom_indices_unpadded_to_padded, +) +from boltz.distributed.data.utils import TensorMetadata +from boltz.distributed.model.layers.shardwise_op import shardwise_argmax, shardwise_offset +from boltz.distributed.model.layers.squeeze import shardwise_unsqueeze +from boltz.distributed.model.layers.utils import distributed_pack_and_pad +from boltz.distributed.utils import LayoutRightMap + + +def pad_and_scatter_atom_features_dtensor( + features: dict[str, torch.Tensor] | None, + placements: dict[str, Placement], + group: ProcessGroup, + src_rank_global: int, + device_mesh: DeviceMesh, +) -> dict[str, DTensor]: + """Pad and distribute atom features as DTensors for the AtomTransformer in distributed setting. + + This function performs sophisticated sharding of molecular features across a 2D device mesh, + handling both 1D (atom-based) and 2D (pair-based) features with appropriate padding and + scattering. Tokens are sharded along the first dimension (i) and duplicated across the + second dimension (j), ensuring balanced distribution of computational work. + + Parameters + ---------- + features : dict[str, torch.Tensor] | None + Dictionary mapping feature names to tensors containing molecular data. Must be provided + on the source rank and None on all other ranks. Features may include random prefixes + for testing ordering robustness. Key features include: + - atom_counts_per_token: Number of atoms per token for sharding calculations + - coords: Atom coordinates with leading ensemble dimension (E, A, 3) + - pair_mask: Pairwise interaction masks requiring 2D sharding + - atom_to_token, token_to_rep_atom: Mapping features with diagonal replication + - frames_idx: Frame indices requiring cumulative padding adjustments + placements : dict[str, Placement] + Dictionary mapping feature names to their desired tensor placements. Must have the + same keys as features on the source rank. Expected placements: + - Single representation features: (Shard(0), Replicate()) or (Shard(1), Replicate()) for coords + - 2D features: (Shard(0), Shard(1)) for pair_mask, (Shard(0), Replicate()) for mapping features + group : ProcessGroup + The distributed process group containing all ranks in the device mesh. Used for + broadcasting metadata and scattering tensor shards. + src_rank_global : int + The global rank that serves as the source for the original feature tensors. + This rank must be included in the process group. + device_mesh : DeviceMesh + A 2D square device mesh defining the distributed tensor layout. Currently only + supports square meshes (n_rows == n_cols). The mesh ranks must match the process group ranks. + + Returns + ------- + dict[str, DTensor] + Dictionary mapping feature names to distributed tensors (DTensors) with appropriate + sharding and padding applied. Features are processed in deterministic order to ensure + consistency across ranks. Some features may be filtered out based on usage context. + + Raises + ------ + ValueError + - If placements is not a dict + - If group is not a ProcessGroup + - If device_mesh is not a DeviceMesh or not 2D/square + - If device is not a torch.device or doesn't match device_mesh type + - If features/placements key mismatch on source rank + - If features is not None on non-source ranks + - If process group ranks don't match device mesh ranks + - If source rank is not in the process group + - If number of tokens is not divisible by shard dimension + - If placement dimensions don't match expected patterns for each feature type + + Notes + ----- + - Only supports 2D square device meshes currently + - Filters out fields in _ATOM_IGNORE_FIELDS_FOR_CONTEXT_PARALLEL_PREDICT (training/confidence-only) + - Uses OrderedDict internally to ensure consistent feature iteration order across ranks + - Handles complex padding logic for both 1D and 2D features: + * 1D features: Padded to max_atoms_per_shard along atom dimension + * 2D features: Complex 2D padding with different strategies per feature type + - Special handling for coordinates (leading ensemble dimension) and frames_idx (cumulative padding) + - Broadcasts tensor metadata before scattering to ensure consistent shard shapes + - Each feature requires specific placement patterns based on its semantic meaning + + Examples + -------- + >>> features = {"coords": torch.randn(1, 100, 3), "atom_counts_per_token": torch.ones(100)} + >>> placements = {"coords": (Shard(1), Replicate()), "atom_counts_per_token": (Shard(0), Replicate())} + >>> dtensors = pad_and_scatter_atom_features_dtensor(features, placements, group, 0, mesh) + """ + rank_global = torch.distributed.get_rank() + is_src_rank = rank_global == src_rank_global + + if not isinstance(placements, dict): + raise ValueError(f"Placements must be a dict, got {placements}") + + if not isinstance(group, ProcessGroup): + raise ValueError(f"Group must be a ProcessGroup, got {group}") + + if not isinstance(device_mesh, DeviceMesh): + raise ValueError(f"Device mesh must be a DeviceMesh, got {device_mesh}") + + if is_src_rank: + if not isinstance(features, dict): + raise ValueError(f"Features must be a dict on source rank {rank_global}, got {features}") + if features.keys() != placements.keys(): + raise ValueError( + f"Features and placements must have the same keys, got {features.keys()} and {placements.keys()}" + ) + else: + if features is not None: + raise ValueError(f"Features must be None on non-source rank {rank_global}, got {features}") + + # check consistency of process group and device mesh + ranks_in_group = torch.distributed.get_process_group_ranks(group) + + ranks_in_mesh = device_mesh.mesh.clone() # for inplace modification later + if ranks_in_group != ranks_in_mesh.flatten().tolist(): + raise ValueError( + f"Ranks in group {ranks_in_group} do not match ranks in mesh {ranks_in_mesh}, got {ranks_in_group} and {ranks_in_mesh}" + ) + + if src_rank_global not in ranks_in_group: + raise ValueError(f"Source rank {src_rank_global} not in group {ranks_in_group}") + + # we hardcode the sharding of token and atom features along a square sub-mesh of the device mesh + if device_mesh.ndim != 2: + raise ValueError(f"Only 2D device meshes currently supported, got {device_mesh.ndim}") + n_rows, n_cols = device_mesh.shape + if n_rows != n_cols: + raise ValueError(f"Only square device grids currently supported, got {n_rows},{n_cols}") + + # note that ranks_in_mesh might not be consecutive so we need a dictionary + # this is equivalent to + # rank_global_to_idx_scatter_list = { + # r_global: torch.distributed.get_group_rank(r_global, group) for r_global in ranks_in_mesh + # } + rank_global_to_idx_scatter_list = {r: i for i, r in enumerate(ranks_in_mesh.flatten().tolist())} + # mesh_coord_to_idx_scatter_list[i_row, j_col] -> idx in the group + # Later in the loop, while we iterate the shards in a LayoutRight order, + # the mesh_coord_to_idx_scatter_list mapping always respect the layout of the device mesh + mesh_coord_to_idx_scatter_list = ranks_in_mesh.apply_(rank_global_to_idx_scatter_list.get) + + # torch init_device_mesh and DeviceMesh ctor will do something like torch.cuda.set_device(rank) + # so the associated device is rank specific, which we can rely on to set the device for the resulting + # tensors + device = torch.device(device_mesh.device_type) + + if is_src_rank: + # metadata only relevant to source rank + token_atom_counts = features["atom_counts_per_token"] + + N_tokens_total = token_atom_counts.shape[0] + if N_tokens_total % n_rows != 0: + raise ValueError(f"Number of tokens ({N_tokens_total}) is not divisible by shard dimension ({n_rows})") + + # atom_counts_per_token is padded to CP-divisible length by the + # caller, but atom_to_token's token dimension (dim 1) may lag behind + # because its placement Shard(0) only pads the atom dimension. + if "atom_to_token" in features and features["atom_to_token"].shape[-1] < N_tokens_total: + features["atom_to_token"] = pad_dim( + features["atom_to_token"], + features["atom_to_token"].ndim - 1, + N_tokens_total - features["atom_to_token"].shape[-1], + ) + + token_atom_count_cumsum = torch.concatenate( + (torch.tensor([0], device=token_atom_counts.device), torch.cumsum(token_atom_counts, dim=0)) + ) + N_tokens_per_shard = N_tokens_total // n_rows + + shard_atom_counts_token = token_atom_counts.unflatten(dim=0, sizes=(n_rows, N_tokens_per_shard)).sum(dim=1) + max_atoms_per_shard_token = shard_atom_counts_token.amax().item() + + # max_atoms_per_shard for consistent intersperse padding + max_atoms_per_shard = max_atoms_per_shard_token + + # Pre-compute r_set_to_rep_atom metadata (loop-independent, computed once) + # r_set_to_rep_atom: [size_r_set, N_atoms] one-hot + # + # CRITICAL: R-set sharding must ALIGN with token sharding! + # Each R-set element corresponds to a specific token (via its rep atom). + # We must put each R-set element in the same shard as its token. + # + # The co-sharding of N_tokens and N_R is bijective: + # - Forward: R-set elements of shard i are a SUBSET of tokens of shard i. + # No R-set element can correspond to a token in a different shard. + # - Inverse: Tokens of shard i are a SUPERSET of its R-set elements. + # This holds by definition since each R element maps to exactly one token. + # + # PURPOSE: This co-sharding enables LOCAL matmul for mapping coordinates + # from atom space to R-set space (r_coords = r_set_to_rep_atom @ atom_coords). + # This mirrors the atom_to_token co-sharding strategy where atoms and tokens + # are co-sharded as diagonal blocks, enabling local atom-to-token mappings. + # + # Algorithm: + # 1. Compute r_set_to_token = r_set_to_rep_atom @ atom_to_token + # 2. Determine which shard each R-set element belongs to + # 3. Group R-set elements by their token shard + # 4. Each shard (i, *) gets only R-set elements for tokens in shard i + r_set_precomputed = None + if "r_set_to_rep_atom" in features and "atom_to_token" in features: + r_set_to_rep_atom_v = features["r_set_to_rep_atom"] + # Get valid R-set elements (filter out padding rows - all zeros) + r_set_valid_mask = r_set_to_rep_atom_v.any(dim=-1) # [size_r_set_total] + r_set_valid = r_set_to_rep_atom_v[r_set_valid_mask] # [size_r_set_valid, N_atom_global] + size_r_set_valid = r_set_valid.shape[0] + + if size_r_set_valid == 0: + r_set_precomputed = {"size_r_set_valid": 0} + else: + # Get atom_to_token for computing r_set -> token mapping. + # The general padding loop may have padded dim 0 (atoms) for + # Shard(0), but r_set_valid has the original atom count in + # dim 1. Slice to match. + n_atoms_r_set = r_set_valid.shape[-1] + atom_to_token = features["atom_to_token"][:n_atoms_r_set] + # Both r_set_valid and atom_to_token are one-hot, so the matmul + # r_set_valid @ atom_to_token is equivalent to two index lookups: + # 1. argmax over r_set_valid rows -> rep atom index per r_set element + # 2. argmax over atom_to_token rows -> token index per atom + # This avoids an O(N_r * N_atoms * N_tokens) int64 scalar loop + # (no BLAS path for int64) and replaces it with two O(N * M) argmax passes. + r_set_rep_atom_idx = r_set_valid.argmax(dim=-1) # [size_r_set_valid] + atom_to_token_idx = atom_to_token.argmax(dim=-1) # [N_atoms] + r_set_token_idx = atom_to_token_idx[r_set_rep_atom_idx] # [size_r_set_valid] + + # Determine which shard each R-set element belongs to + r_set_shard_idx = r_set_token_idx // N_tokens_per_shard # [size_r_set_valid] + + # Count R-set elements per shard using scatter_add + ones = torch.ones(size_r_set_valid, dtype=torch.long, device=device) + shard_counts = torch.zeros(n_rows, dtype=torch.long, device=device) + shard_counts.scatter_add_(0, r_set_shard_idx, ones) + + # Max R-set size per shard for padding + max_size_r_set_per_shard = max(1, shard_counts.max().item()) + + r_set_precomputed = { + "size_r_set_valid": size_r_set_valid, + "r_set_valid": r_set_valid, + "r_set_shard_idx": r_set_shard_idx, + "max_size_r_set_per_shard": max_size_r_set_per_shard, + } + + _2D_feats = { + "pair_mask", + "atom_to_token", + "token_to_rep_atom", + "frames_idx", + "r_set_to_rep_atom", # 2D one-hot [N_R, N_atoms] + } + + _2D_features_placement_as_single = { + "atom_to_token", + "token_to_rep_atom", + "frames_idx", + "r_set_to_rep_atom", # Diagonal block sharding like token_to_rep_atom + } + + # loop over each feature, create a list of tensors, scatter them then call DTensor.from_local + # To guarantee the order of iterating thru the features in the dictionary, we need to + # convert placements to a OrderedDict first so that all ranks go thru the keys in the same order. + placements_ordered = OrderedDict(sorted(placements.items())) + + result: dict[str, DTensor] = {} + for k, placement in placements_ordered.items(): + if k in _ATOM_IGNORE_FIELDS_FOR_CONTEXT_PARALLEL_PREDICT: + continue + + if k == "atom_counts_per_token": + # Only used in sharding preprocessing. + continue + + if len(placement) != device_mesh.ndim: + raise ValueError( + f"Placement for {k} has {len(placement)} dimensions, expected {device_mesh.ndim} from the device_mesh" + ) + + is_single_repr = k not in _2D_feats + if is_single_repr: + if k == "coords": + placement_expected = (Shard(1), Replicate()) + else: + placement_expected = (Shard(0), Replicate()) + else: + if k == "frames_idx": + placement_expected = (Shard(1), Replicate()) + elif k in _2D_features_placement_as_single: + placement_expected = (Shard(0), Replicate()) + else: + placement_expected = (Shard(0), Shard(1)) + + if is_src_rank and k == "frames_idx": + v_src = features[k] + if v_src.ndim != 3: + raise ValueError( + "frames_idx must have v2 ensemble-aware shape (E, T, 3) " + f"with ndim=3, got ndim={v_src.ndim} with shape={tuple(v_src.shape)}" + ) + + if is_src_rank and k == "frame_resolved_mask": + v_src = features[k] + if v_src.ndim != 2: + raise ValueError( + "frame_resolved_mask must have v2 ensemble-aware shape (E, T) " + f"with ndim=2, got ndim={v_src.ndim} with shape={tuple(v_src.shape)}" + ) + + placement_valid = placement == placement_expected + if not placement_valid: + raise ValueError(f"Placement for {k} is {placement}, expected {placement_expected} from the device_mesh") + + if is_src_rank: + v = features[k] + # create a list of tensors on the src rank + scatter_list = [None] * ranks_in_mesh.numel() + + for i in range(n_rows): + # Entries duplicated over j + if k not in _2D_feats: + # single representation + token_start = N_tokens_per_shard * i + token_end = N_tokens_per_shard * (i + 1) + atom_start = token_atom_count_cumsum[token_start] + atom_end = token_atom_count_cumsum[token_end] + num_atoms_in_range = atom_end - atom_start + pad_amount = max_atoms_per_shard - num_atoms_in_range + + if k == "coords": + # Leading dimension is ensemble count (E, A, 3); shard and pad along atom dim (1). + j_duplicates_val = pad_dim(v[:, atom_start:atom_end, ...], 1, pad_amount) + else: + j_duplicates_val = pad_dim(v[atom_start:atom_end], 0, pad_amount) + for j in range(n_cols): + # TODO: see if we can avoid the clone here + if j_duplicates_val.dtype in [torch.int8, torch.bool]: + j_duplicates_val = j_duplicates_val.clone() + scatter_list[mesh_coord_to_idx_scatter_list[i, j].item()] = j_duplicates_val + else: + # pair representation + # find token and atom ranges for 2d padding + row_token_start = N_tokens_per_shard * i + row_token_end = N_tokens_per_shard * (i + 1) + shard_atom_start = token_atom_count_cumsum[row_token_start] + shard_atom_end = token_atom_count_cumsum[row_token_end] + shard_atoms_in_range = shard_atom_end - shard_atom_start + + # 2D entries need separate calculation for each i, j + for j in range(n_cols): + if k == "pair_mask": + col_token_start = N_tokens_per_shard * j + col_token_end = N_tokens_per_shard * (j + 1) + + # 2D indexing and padding, (atoms * atoms) + res = torch.zeros( + size=(max_atoms_per_shard, max_atoms_per_shard), dtype=v.dtype, device=device + ) + col_atom_start = token_atom_count_cumsum[col_token_start] + col_atom_end = token_atom_count_cumsum[col_token_end] + col_atoms_in_range = col_atom_end - col_atom_start + res[:shard_atoms_in_range, :col_atoms_in_range] = v[ + shard_atom_start:shard_atom_end, col_atom_start:col_atom_end + ] + elif k == "atom_to_token": + # 2D indexing and padding, (atoms * tokens), internal padding only needed in atom dim. + # NOTE: Each j column gets the diagonal representation (i,i) - so columns are i-based here, j is ignored + # except for computing the output shard. + col_token_start = N_tokens_per_shard * i + col_token_end = N_tokens_per_shard * (i + 1) + col_tokens_in_range = col_token_end - col_token_start + res = torch.zeros( + size=(max_atoms_per_shard, N_tokens_per_shard), dtype=v.dtype, device=device + ) + res[:shard_atoms_in_range, :col_tokens_in_range] = v[ + shard_atom_start:shard_atom_end, col_token_start:col_token_end + ] + elif k == "frames_idx": + # frames_idx shape is (E, T, 3): E ensembles, T tokens, + # and for each token 3 global atom indices that define + # its local coordinate frame. The atom indices refer to + # positions in the unpadded, unsharded atom array. + # After sharding, each shard is padded to + # max_atoms_per_shard, shifting atom positions. + frame_idx = v[:, row_token_start:row_token_end, :] + res = remap_atom_indices_unpadded_to_padded( + frame_idx, shard_atom_counts_token, max_atoms_per_shard + ) + elif k == "token_to_rep_atom": + # 2D indexing and padding, (tokens * atoms), internal padding only needed in atom dim. Similar to atom_to_token. + # NOTE: Each j column gets the diagonal representation (i,i) - so columns are i-based here, j is ignored + # except for computing the output shard. + col_token_start = N_tokens_per_shard * i + col_token_end = N_tokens_per_shard * (i + 1) + col_tokens_in_range = col_token_end - col_token_start + res = torch.zeros( + size=(N_tokens_per_shard, max_atoms_per_shard), dtype=v.dtype, device=device + ) + res[:col_tokens_in_range, :shard_atoms_in_range] = v[ + col_token_start:col_token_end, + shard_atom_start:shard_atom_end, + ] + elif k == "r_set_to_rep_atom": + # Use pre-computed metadata (computed once before loops) + # See r_set_precomputed initialization for full documentation on: + # - Co-sharding of N_tokens and N_R (bijective relationship) + # - Purpose: enabling LOCAL matmul for atom->R-set coordinate mapping + # - Algorithm for token-aligned R-set sharding + + if r_set_precomputed["size_r_set_valid"] == 0: + # No valid R-set elements - create empty shard + res = torch.zeros((1, max_atoms_per_shard), dtype=v.dtype, device=device) + else: + r_set_valid = r_set_precomputed["r_set_valid"] + r_set_shard_idx = r_set_precomputed["r_set_shard_idx"] + max_size_r_set_per_shard = r_set_precomputed["max_size_r_set_per_shard"] + + # Get indices of R-set elements for shard i + shard_mask = r_set_shard_idx == i # [size_r_set_valid] + r_set_in_shard = r_set_valid[shard_mask] # [count_i, N_atom_global] + count_i = r_set_in_shard.shape[0] + + # Create output: [max_size_r_set_per_shard, max_atoms_per_shard] + res = torch.zeros( + (max_size_r_set_per_shard, max_atoms_per_shard), dtype=v.dtype, device=device + ) + if count_i > 0 and shard_atoms_in_range > 0: + # Slice atoms for this token shard i (diagonal block) + res[:count_i, :shard_atoms_in_range] = r_set_in_shard[ + :, shard_atom_start:shard_atom_end + ] + scatter_list[mesh_coord_to_idx_scatter_list[i, j].item()] = res + + else: + scatter_list = None + + # broadcast the metadata + # Assumption: all shards in scatter_list have the same shape, which is the assumption made by the + # code blocks above + l_metadata = [TensorMetadata(dtype=v.dtype, shape=scatter_list[0].shape)] if is_src_rank else [None] + torch.distributed.broadcast_object_list(l_metadata, src=src_rank_global, group=group, device=device) + + # scatter the tensor + local_shard = torch.empty(l_metadata[0].shape, dtype=l_metadata[0].dtype, device=device) + torch.distributed.scatter(local_shard, scatter_list, src=src_rank_global, group=group) + + # create the DTensor from local shard + # Due to the padding, we need to recompute the global shape with padding applied + shape_global = list(l_metadata[0].shape) + for i_dim_mesh, p in enumerate(placement): + if isinstance(p, Shard): + shape_global[p.dim] *= device_mesh.shape[i_dim_mesh] + shape_global = tuple(shape_global) + # Due to the local_shard from torch.empty and data shared by torch.distributed.scatter, + # the stride is guaranteed to be LayoutRight, i.e., torch's 'contiguous' memory layout + stride_global = LayoutRightMap(shape=shape_global).strides + dtensor = DTensor.from_local( + local_shard, device_mesh, placements=placement, shape=shape_global, stride=stride_global + ) + result[k] = dtensor + + return result + + +def pack_atom_features( + feats: dict[str, DTensor], + keys_subset: set[str], + W: int, +) -> OrderedDict[str, DTensor]: + """Pack and pad atom features using distributed_pack_and_pad. + + This removes per-shard trailing padding from pad_and_scatter_atom_features_dtensor + and creates a packed DTensor with global trailing padding (multiple of W * size_cp). + + The function handles keys in keys_subset as follows: + - "atom_to_token": Special handling - converts shard-local indices to global before packing, + and stores the global indices as "atom_to_token_ids_global". NOTE: The original + 'atom_to_token' one-hot matrix is NOT packed. Directly packing 'atom_to_token' would + give inconsistent sharding scheme between atom and token dimensions, making the packed + 'atom_to_token' not useful in practice. Only the global indices are packed and returned. + - All other keys: Treated as generic atom features and packed directly with the mask + + Parameters + ---------- + feats : dict[str, DTensor] + Dictionary of atom features as DTensors. Must contain "atom_pad_mask" key + with shape (B, N_atoms) to use as the packing mask. + keys_subset : set[str] + Set of keys to pack from feats. Must contain "atom_pad_mask". + Only keys present in both feats and keys_subset will be processed. + All keys (except "atom_to_token" and "atom_pad_mask") are treated as + generic atom features with N_atoms axis at position 1. + W : int + Window size for packing (atoms per window for queries). + The packed output will have length that is a multiple of W * size_cp. + + Returns + ------- + OrderedDict[str, DTensor] + OrderedDict of packed atom features with keys in sorted order. Contains: + - All keys from keys_subset that were in feats, with packed values + - "atom_to_token_ids_global" if "atom_to_token" was in keys_subset + Keys are sorted to ensure consistent iteration order across distributed ranks. + + Raises + ------ + ValueError + If feats does not contain "atom_pad_mask" key. + If feats["atom_pad_mask"] does not have ndim=2 (expected shape: B, N_atoms). + If keys_subset does not contain "atom_pad_mask". + If keys_subset contains keys not present in feats. + NotImplementedError + If keys_subset contains "coords" (packing coords is not supported). + """ + if "atom_pad_mask" not in feats: + raise ValueError("feats must contain 'atom_pad_mask' key") + if feats["atom_pad_mask"].ndim != 2: + raise ValueError( + f"feats['atom_pad_mask'] must have ndim=2 (B, N_atoms), got ndim={feats['atom_pad_mask'].ndim}" + ) + if "atom_pad_mask" not in keys_subset: + raise ValueError("keys_subset must contain 'atom_pad_mask'") + if "coords" in keys_subset: + raise NotImplementedError("packing 'coords' is not supported") + + # Verify all keys in keys_subset are present in feats + missing_keys = keys_subset - feats.keys() + if missing_keys: + raise ValueError(f"keys_subset contains keys not in feats: {missing_keys}") + + # Sort keys to ensure consistent iteration order across all ranks + # This is critical because distributed collective operations must be called in the same order + keys_sorted = sorted(keys_subset) + + # Pack and pad each atom feature in keys_subset + # Use no_grad to prevent backprop - these are input features not supposed to receive gradients + feats_dt_packed: OrderedDict[str, DTensor] = OrderedDict() + with torch.no_grad(): + # Get atom_mask for pack_and_pad (shape: B, N_atoms) + # Must convert to bool() to match pack_and_pad behavior and avoid + # precision issues when summing bfloat16 mask values + atom_mask_dt = feats["atom_pad_mask"].bool() + + for key in keys_sorted: + # Verify feature shape matches mask shape in first two dimensions (B, N_atoms) + if feats[key].shape[:2] != atom_mask_dt.shape: + raise ValueError( + f"feats['{key}'].shape[:2]={feats[key].shape[:2]} must match atom_mask_dt.shape={atom_mask_dt.shape}" + ) + + if key == "atom_to_token": + # Special handling: convert shard-local indices to global BEFORE packing. + # This is necessary because distributed_pack_and_pad may move atoms between shards, + # which would result in different sharding schemes between token and atom. + # On the other hand, shardwise_offset below assumes consistent sharding scheme + # between token and atom, i.e., any rank must own all atoms from all its own + # tokens, or any token's atom collection is not sharded. Therefore, + # shardwise_offset must be applied before distributed_pack_and_pad. + # 1. Get shard-local token indices from block-diagonal atom_to_token + atom_to_token_dt = feats[key] # (B, N_atoms, N_tokens_per_shard) + if atom_to_token_dt.ndim != 3: + raise ValueError( + f"feats['atom_to_token'] must have ndim=3 (B, N_atoms, N_tokens_per_shard), " + f"got ndim={atom_to_token_dt.ndim}" + ) + n_tokens_per_shard = atom_to_token_dt.to_local().shape[2] + atom_to_token_ids_local = shardwise_argmax(atom_to_token_dt, dim=-1, keepdim=False) + # 2. Convert to global indices + atom_to_token_ids_global = shardwise_offset( + atom_to_token_ids_local, dim=1, offset_per_rank=n_tokens_per_shard + ) + # 3. Pack the global indices (not the one-hot matrix) + mask_for_ids = atom_mask_dt + packed_ids, _ = distributed_pack_and_pad(atom_to_token_ids_global, mask_for_ids, W, axis=1) + feats_dt_packed["atom_to_token_ids_global"] = packed_ids + feats_dt_packed["atom_to_token_local_onehot"] = atom_to_token_dt + if not getattr(pack_atom_features, "_warned_a2t", False): + pack_atom_features._warned_a2t = True + warnings.warn( + "pack_atom_features: 'atom_to_token' one-hot matrix is NOT packed but returned " + "as-is together with 'atom_to_token_ids_global' for window batching AtomAttentionEncoder. " + "Directly packing 'atom_to_token' would give inconsistent sharding scheme " + "between atom and token, making it not useful in practice.", + stacklevel=2, + ) + else: + # Generic atom feature: expand mask and pack + feat = feats[key] + mask_for_feat = atom_mask_dt + # Add trailing dimensions to mask for broadcasting with feat + while mask_for_feat.ndim < feat.ndim: + mask_for_feat = shardwise_unsqueeze(mask_for_feat, -1) + packed_feat, _packed_mask = distributed_pack_and_pad(feat, mask_for_feat, W, axis=1) + feats_dt_packed[key] = packed_feat + + return feats_dt_packed diff --git a/src/boltz/distributed/data/feature/featurizer_utils.py b/src/boltz/distributed/data/feature/featurizer_utils.py new file mode 100644 index 000000000..535e00e6d --- /dev/null +++ b/src/boltz/distributed/data/feature/featurizer_utils.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from functools import partial + +import torch +from torch import Tensor + +from boltz.data.types import Tokenized +from boltz.model.modules.encodersv2 import get_indexing_matrix, single_to_keys + +# Fields to ignore during context parallel prediction +_ATOM_IGNORE_FIELDS_FOR_CONTEXT_PARALLEL_PREDICT: set[str] = set() + +# Features whose values are atom indices into a padded global atom array. +# Both pad_and_scatter_atom_features_dtensor and CollateDTensor must remap +# these values whenever atom padding changes. +ATOM_INDEX_FEATURES = {"frames_idx"} + + +def remap_atom_indices_unpadded_to_padded( + indices: Tensor, + shard_atom_counts: Tensor, + padded_atoms_per_shard: int, +) -> Tensor: + """Remap atom indices from an unpadded global atom array to a padded layout. + + In the padded layout each CP shard occupies exactly ``padded_atoms_per_shard`` + positions, with trailing zeros filling unused slots. This shifts atom + positions relative to the dense (unpadded) layout. + + Parameters + ---------- + indices : Tensor + Atom indices into the unpadded global atom array (arbitrary shape). + shard_atom_counts : Tensor + Actual (unpadded) atom count per shard, shape ``(n_shards,)``. + padded_atoms_per_shard : int + Fixed size of each shard in the padded layout. + + Returns + ------- + Tensor + Indices remapped to the padded layout, same shape as *indices*. + """ + dtype = indices.dtype + # cumsum promotes int32→int64 to avoid overflow; cast back explicitly. + shard_starts = torch.cat( + [ + torch.zeros(1, device=shard_atom_counts.device, dtype=dtype), + shard_atom_counts.cumsum(dim=0).to(dtype), + ] + ) + pad_offsets = ( + torch.arange(len(shard_atom_counts), device=shard_atom_counts.device, dtype=dtype) * padded_atoms_per_shard + - shard_starts[:-1] + ) + shard_idx = torch.bucketize(indices, shard_starts[1:], right=True) + return indices + pad_offsets[shard_idx] + + +def remap_atom_indices_repad( + indices: Tensor, + old_atoms_per_shard: int, + new_atoms_per_shard: int, +) -> Tensor: + """Remap atom indices when per-shard padding changes. + + After ``pad_and_scatter_atom_features_dtensor``, atom-index features + reference a padded layout with stride ``old_atoms_per_shard``. When + collation pads the atom dimension further (to align a batch or across DP + ranks), the stride grows to ``new_atoms_per_shard`` and every stored index + must be adjusted. + + Parameters + ---------- + indices : Tensor + Atom indices in the old padded layout (arbitrary shape). + old_atoms_per_shard : int + Per-shard atom count in the current (old) padded layout. + new_atoms_per_shard : int + Per-shard atom count in the target (new) padded layout. + + Returns + ------- + Tensor + Indices remapped to the new padded layout, same shape as *indices*. + """ + if old_atoms_per_shard == new_atoms_per_shard: + return indices + shard_of_atom = indices // old_atoms_per_shard + offset_in_shard = indices % old_atoms_per_shard + return shard_of_atom * new_atoms_per_shard + offset_in_shard + + +def get_pair_mask(N_atoms: int, W: int = 32, H: int = 128) -> Tensor: + """Get the pair mask for the atom transformer. + + Parameters + ---------- + N_atoms : int + The number of atoms. + W : int, optional + The attention window queries, by default 32. + H : int, optional + The attention window keys, by default 128. + + Returns + ------- + Tensor + The pair mask. + + """ + mask = torch.zeros(N_atoms, N_atoms) + + if N_atoms % W == 0: + max_atoms = N_atoms + else: + # pad to the next multiple of W + max_atoms = ((N_atoms // W) + 1) * W + + # construct pair mask through indexing matrices + # TODO construct pair mask directly from AF3 appendix + index = torch.arange(1, max_atoms + 1) + index[N_atoms:] = 0 + index = index.unsqueeze(0) + + K = max_atoms // W + keys_indexing_matrix = get_indexing_matrix(K, W, H, index.device) + to_keys = partial(single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H) + + index_queries = index.view(K, W) + index_keys = to_keys(index.unsqueeze(-1).float()).view(K, H).long() + + for index_query, index_key in zip(index_queries, index_keys): + index_query = index_query[index_query != 0] + index_key = index_key[index_key != 0] + mask[index_query.min() - 1 : index_query.max(), index_key.min() - 1 : index_key.max()] = 1 + + return mask + + +def tokenized_stats( + tokenized: Tokenized, +) -> dict[str, int]: + """Get statistics about the tokenized data. + + Parameters + ---------- + tokenized : Tokenized + The tokenized data. + + Returns + ------- + dict[str, int] + Dictionary containing: + - num_atoms_total: Total number of atoms across all tokens + - num_tokens: Number of tokens + - num_atoms_max: Maximum atoms in any single token + - num_atoms_min: Minimum atoms in any single token + """ + num_atoms_total = sum([token["atom_num"] for token in tokenized.tokens]) + num_tokens = len(tokenized.tokens) + num_atoms_max = max([token["atom_num"] for token in tokenized.tokens]) + num_atoms_min = min([token["atom_num"] for token in tokenized.tokens]) + + return { + "num_atoms_total": num_atoms_total, + "num_tokens": num_tokens, + "num_atoms_max": num_atoms_max, + "num_atoms_min": num_atoms_min, + } + + +def get_num_atoms_tokens(tokenized: Tokenized) -> tuple[int, int]: + """Get the number of atoms and tokens from tokenized data. + + Parameters + ---------- + tokenized : Tokenized + The tokenized data. + + Returns + ------- + tuple[int, int] + Tuple of (num_atoms_total, num_tokens). + """ + stats = tokenized_stats(tokenized) + return stats["num_atoms_total"], stats["num_tokens"] diff --git a/src/boltz/distributed/data/feature/symmetry.py b/src/boltz/distributed/data/feature/symmetry.py new file mode 100644 index 000000000..13ebec2bd --- /dev/null +++ b/src/boltz/distributed/data/feature/symmetry.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from __future__ import annotations + +from itertools import chain +from numbers import Integral + +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.model.loss.triton.cdist_lddt import cdist_lddt + + +def minimum_lddt_symmetry_coords( + coords: DTensor, + feats: dict, + index_batch_local: int, + i_batch_multiplicity_local: int = 0, +): + """Find coordinates with best lDDT under symmetry transformations (Boltz-2 semantics). + + This function handles the distributed case where: + - coords is a DTensor sharded across (DP, CP_axis_0, CP_axis_1) device mesh + - symmetry features are plain Tensors with only the local batch for this DP rank + + Unlike the Boltz-1 version, this does NOT perform RMSD alignment (matching + the Boltz-2 serial ``minimum_lddt_symmetry_coords`` from ``boltz.data.mol``). + + Local Batch Indexing Semantics + ------------------------------ + In distributed training with DP (data parallelism) and CP (context parallelism): + + 1. **coords (DTensor)**: Has global shape [B_global * multiplicity, N_atoms_padded, 3] + - Sharded along batch dim (DP axis) and atom dim (CP axes) + - Contains interspersed padding from pad_and_scatter_atom_features_dtensor + - After CP gather: local shape is [B_local * multiplicity, N_atoms_padded, 3] + - B_local = B_global / num_DP_ranks (each DP rank owns different samples) + - Use ``i_batch_multiplicity_local`` to index: coords_local[i_batch_multiplicity_local] + + 2. **symmetry features (plain Tensors and nested iterables broadcasted among all ranks)**: + - The tensors have shape [B_local, N_atoms_no_pad, ...] + - NOT sharded - each DP rank only has its own local batch + - NO interspersed padding - contiguous atoms without CP shard padding + - Use ``index_batch_local`` to index: feats["all_coords"][index_batch_local] + + 3. **atom_pad_mask (DTensor in feats)**: Has global shape [B_global, N_atoms_padded] + - Indicates valid atoms (True/1.0) vs padding (False/0.0) + - Used to remove interspersed padding from coords for symmetry resolution + - Then re-add padding before returning as DTensor + + Index Relationship (with multiplicity M): + - index_batch_local = i_batch_multiplicity_local // M + - i_batch_multiplicity_local = index_batch_local * M + rep (where rep in [0, M)) + + Parameters + ---------- + coords : DTensor + Predicted coordinates with shape [B_global * multiplicity, N_atoms_padded, 3]. + The batch dimension includes diffusion_samples multiplicity and is sharded along DP axis. + Contains interspersed padding from distributed featurization. + feats : dict + Dictionary containing (symmetry features are per-DP-rank local tensors): + - all_coords: Tensor [B_local, N_all, 3] + - all_resolved_mask: Tensor [B_local, N_all] + - crop_to_all_atom_map: Tensor [B_local, N_crop] + - chain_swaps: List[B_local] of swap combinations + - amino_acids_symmetries: List[B_local] of symmetry groups + - ligand_symmetries: List[B_local] of symmetry groups + - atom_pad_mask: DTensor [B_global, N_atoms_padded] - atom padding mask (required) + index_batch_local : int + Local batch index into symmetry features (range: [0, B_local)). + i_batch_multiplicity_local : int + Local index into coords (range: [0, B_local * multiplicity)). + + Returns + ------- + tuple[DTensor, DTensor] + (true_coords_dtensor, true_resolved_mask_dtensor) as DTensors + with the same placements as the input coords. + + """ + if not isinstance(coords, DTensor): + raise TypeError(f"coords must be a DTensor, got {type(coords).__name__}.") + + if "atom_pad_mask" not in feats: + raise KeyError("feats must contain 'atom_pad_mask' key for handling interspersed padding.") + atom_pad_mask = feats["atom_pad_mask"] + if not isinstance(atom_pad_mask, DTensor): + raise TypeError(f"feats['atom_pad_mask'] must be a DTensor, got {type(atom_pad_mask).__name__}.") + + coords_mesh = coords.device_mesh + coords_placements = coords.placements + if coords_placements != (Shard(0), Shard(1), Replicate()): + raise ValueError(f"Expected coords placements (Shard(0), Shard(1), Replicate()), got {coords_placements}") + if atom_pad_mask.device_mesh != coords_mesh: + raise ValueError("atom_pad_mask.device_mesh does not match coords.device_mesh") + if atom_pad_mask.placements != coords_placements: + raise ValueError(f"Expected atom_pad_mask placements {coords_placements}, got {atom_pad_mask.placements}") + + all_coords = feats["all_coords"] + all_resolved_mask = feats["all_resolved_mask"] + crop_to_all_atom_map = feats["crop_to_all_atom_map"] + + # CollateDTensor collates NON_SHARDED_FEATURES_V2 as Python lists of + # tensors (one per sample, no batch dim). Stack to add the batch dim. + if isinstance(all_coords, list): + all_coords = torch.stack(all_coords, dim=0) + if isinstance(all_resolved_mask, list): + all_resolved_mask = torch.stack(all_resolved_mask, dim=0) + if isinstance(crop_to_all_atom_map, list): + crop_to_all_atom_map = torch.stack(crop_to_all_atom_map, dim=0) + + for key, val in [ + ("all_coords", all_coords), + ("all_resolved_mask", all_resolved_mask), + ("crop_to_all_atom_map", crop_to_all_atom_map), + ]: + if not isinstance(val, torch.Tensor): + raise TypeError(f"feats['{key}'] must be a plain torch.Tensor, got {type(val).__name__}.") + + if coords.ndim != 3 or coords.shape[2] != 3: + raise ValueError("coords must have shape [B, N, 3].") + if all_coords.ndim != 3 or all_coords.shape[2] != 3: + raise ValueError("feats['all_coords'] must have shape [B, N_all, 3].") + + chain_swaps_all = feats.get("chain_swaps") + if not isinstance(chain_swaps_all, (list, tuple)): + raise TypeError("feats['chain_swaps'] must be a list/tuple of swap combinations.") + chain_swaps = chain_swaps_all[index_batch_local] + if not isinstance(chain_swaps, (list, tuple)): + raise TypeError("chain_swaps must be a list/tuple of swap combinations.") + if not all(isinstance(combo, (list, tuple)) for combo in chain_swaps): + raise TypeError("chain_swaps entries must be list/tuple of swaps.") + if not all( + isinstance(swap, (list, tuple)) and len(swap) == 6 and all(isinstance(v, Integral) for v in swap) + for swap in chain.from_iterable(chain_swaps) + ): + raise ValueError("chain_swaps swaps must be 6-int tuples: (start1, end1, start2, end2, chainidx1, chainidx2).") + + amino_acids_symmetries_all = feats.get("amino_acids_symmetries") + if not isinstance(amino_acids_symmetries_all, (list, tuple)): + raise TypeError("feats['amino_acids_symmetries'] must be a list/tuple of symmetry groups.") + amino_acids_symmetries = amino_acids_symmetries_all[index_batch_local] + + ligand_symmetries_all = feats.get("ligand_symmetries") + if not isinstance(ligand_symmetries_all, (list, tuple)): + raise TypeError("feats['ligand_symmetries'] must be a list/tuple of symmetry groups.") + ligand_symmetries = ligand_symmetries_all[index_batch_local] + + # --- Validate structure: sym_groups[residue][sym_op] = [(i,j), ...] --- + for sym_name, sym_groups in [ + ("amino_acids_symmetries", amino_acids_symmetries), + ("ligand_symmetries", ligand_symmetries), + ]: + for group in sym_groups: + for option in group: + if not isinstance(option, (list, tuple)): + raise ValueError( + f"{sym_name} symmetry operation must be a list of (i, j) pairs, got {type(option)}" + ) + for swap_pair in option: + if not isinstance(swap_pair, (list, tuple)) or len(swap_pair) != 2: + raise ValueError(f"{sym_name} entries must be 2-tuples (i, j), got {swap_pair}") + + # --- Gather coords along CP axes, keep DP sharding --- + cp_gathered_placements = tuple( + coords_placements[0] if i == 0 else Replicate() for i in range(len(coords_placements)) + ) + coords_cp_gathered = coords.redistribute(coords_mesh, cp_gathered_placements) + coords_local = coords_cp_gathered.to_local() # [B_local * mul, N_atoms_padded, 3] + + coords_single_padded = coords_local[i_batch_multiplicity_local : i_batch_multiplicity_local + 1] + + atom_pad_mask_cp_gathered = atom_pad_mask.redistribute(coords_mesh, cp_gathered_placements) + atom_pad_mask_local = atom_pad_mask_cp_gathered.to_local() + mask_single = atom_pad_mask_local[index_batch_local].bool() # [N_atoms_padded] + + # Remove interspersed padding: [1, N_atoms_padded, 3] -> [1, N_atoms_no_pad, 3] + coords_single = coords_single_padded[:, mask_single, :] + + # Index symmetry features (plain tensors, no multiplicity) + all_coords_indexed = all_coords[index_batch_local].unsqueeze(0).to(coords_single) + all_resolved_mask_indexed = all_resolved_mask[index_batch_local].to(coords_single).to(torch.bool) + crop_to_all_atom_map_indexed = crop_to_all_atom_map[index_batch_local].to(coords_single).to(torch.long) + + n_crop = int(crop_to_all_atom_map_indexed.numel()) + pred_coords_crop = coords_single[:, :n_crop] + + # --- Chain swap selection (Boltz-2 semantics) --- + best_true_coords = all_coords_indexed[:, crop_to_all_atom_map_indexed].clone() + best_true_resolved_mask = all_resolved_mask_indexed[crop_to_all_atom_map_indexed].clone() + best_lddt = -1.0 + + for c in chain_swaps: + true_all_coords = all_coords_indexed.clone() + true_all_resolved_mask = all_resolved_mask_indexed.clone() + for start1, end1, start2, end2, _chainidx1, _chainidx2 in c: + true_all_coords[:, start1:end1] = all_coords_indexed[:, start2:end2] + true_all_resolved_mask[start1:end1] = all_resolved_mask_indexed[start2:end2] + + true_coords = true_all_coords[:, crop_to_all_atom_map_indexed] + true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map_indexed] + + mask_row = true_resolved_mask.unsqueeze(0) + mask_col = true_resolved_mask.unsqueeze(0) + lddt = cdist_lddt( + pred_coords_row=pred_coords_crop, + pred_coords_col=pred_coords_crop, + true_coords_row=true_coords, + true_coords_col=true_coords, + mask_row=mask_row, + mask_col=mask_col, + multiplicity=1, + cutoff=15.0, + per_atom=False, + )[0].item() + + if lddt > best_lddt and torch.sum(true_resolved_mask) > 3: + best_lddt = lddt + best_true_coords = true_coords + best_true_resolved_mask = true_resolved_mask + + # --- Atom-level symmetries (Boltz-2 semantics: best improvement) --- + true_coords = best_true_coords.clone() + true_resolved_mask = best_true_resolved_mask.clone() + for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries: + best_lddt_improvement = 0.0 + + # Precompute all unique indices across all options in this group + indices_set: set[int] = set() + for c in symmetric_amino_or_lig: + for i, j in c: + indices_set.add(i) + if len(indices_set) == 0: + continue + indices = sorted(indices_set) + indices = torch.as_tensor(indices, device=pred_coords_crop.device, dtype=torch.long) + pred_coords_subset = pred_coords_crop[:, indices] + + for c in symmetric_amino_or_lig: + new_true_coords = true_coords.clone() + new_true_resolved_mask = true_resolved_mask.clone() + for i, j in c: + new_true_coords[:, i] = true_coords[:, j] + new_true_resolved_mask[i] = true_resolved_mask[j] + + true_coords_subset = true_coords[:, indices] + new_true_coords_subset = new_true_coords[:, indices] + + mask_row = true_resolved_mask.unsqueeze(0) + mask_col = true_resolved_mask[indices].unsqueeze(0) + indices_batch = indices.unsqueeze(0) + lddt = cdist_lddt( + pred_coords_row=pred_coords_crop, + pred_coords_col=pred_coords_subset, + true_coords_row=true_coords, + true_coords_col=true_coords_subset, + mask_row=mask_row, + mask_col=mask_col, + multiplicity=1, + atom_indices_col=indices_batch, + cutoff=15.0, + per_atom=False, + )[0].item() + + new_mask_row = new_true_resolved_mask.unsqueeze(0) + new_mask_col = new_true_resolved_mask[indices].unsqueeze(0) + new_lddt = cdist_lddt( + pred_coords_row=pred_coords_crop, + pred_coords_col=pred_coords_subset, + true_coords_row=new_true_coords, + true_coords_col=new_true_coords_subset, + mask_row=new_mask_row, + mask_col=new_mask_col, + multiplicity=1, + atom_indices_col=indices_batch, + cutoff=15.0, + per_atom=False, + )[0].item() + + lddt_improvement = new_lddt - lddt + if lddt_improvement > best_lddt_improvement: + best_true_coords = new_true_coords + best_true_resolved_mask = new_true_resolved_mask + best_lddt_improvement = lddt_improvement + + true_coords = best_true_coords.clone() + true_resolved_mask = best_true_resolved_mask.clone() + + # --- Re-add interspersed padding and wrap as DTensors --- + # Shape consistency check (same as Boltz-1 CP): boolean masking at line 167 removes + # ALL padding (both trailing batch padding and interspersed CP padding) from + # coords_single, so coords_single.shape[1] = sum(atom_pad_mask) = n_crop. + # true_coords.shape[1] = len(crop_to_all_atom_map) = n_crop. They must be equal. + n_atoms_padded = coords_single_padded.shape[1] + n_real = int(mask_single.sum()) + if true_coords.shape[1] != n_real: + raise ValueError( + f"Shape mismatch: true_coords.shape[1]={true_coords.shape[1]} != " + f"sum(atom_pad_mask)={n_real}. Both should equal the number of crop atoms." + ) + if true_resolved_mask.shape[0] != n_real: + raise ValueError( + f"Shape mismatch: true_resolved_mask.shape[0]={true_resolved_mask.shape[0]} != " + f"sum(atom_pad_mask)={n_real}. Both should equal the number of crop atoms." + ) + + true_coords_padded = torch.zeros((1, n_atoms_padded, 3), dtype=true_coords.dtype, device=true_coords.device) + true_coords_padded[:, mask_single, :] = true_coords + + true_resolved_mask_padded = torch.zeros( + (n_atoms_padded,), dtype=true_resolved_mask.dtype, device=true_resolved_mask.device + ) + true_resolved_mask_padded[mask_single] = true_resolved_mask + + device_mesh = coords_mesh + placements = coords_placements + + # Broadcast from CP rank 0 to ensure bitwise-identical results across CP ranks + # Distribute with the target CP placements directly + # so each rank only receives the shard it needs. + cp_mesh = device_mesh["cp_axis_0", "cp_axis_1"] + cp_shard_placements = (placements[1], placements[2]) # placements: (dp, cp_axis_0, cp_axis_1) + + true_coords_cp = distribute_tensor( + true_coords_padded, device_mesh=cp_mesh, placements=cp_shard_placements, src_data_rank=0 + ) + _coords_global_shape = true_coords_padded.shape + true_coords_dtensor = DTensor.from_local( + true_coords_cp.to_local(), + device_mesh, + placements=placements, + shape=_coords_global_shape, + stride=true_coords_padded.stride(), + ) + + true_mask_unsqueezed = true_resolved_mask_padded.unsqueeze(0) + true_mask_cp = distribute_tensor( + true_mask_unsqueezed, device_mesh=cp_mesh, placements=cp_shard_placements, src_data_rank=0 + ) + _mask_global_shape = true_mask_unsqueezed.shape + true_resolved_mask_dtensor = DTensor.from_local( + true_mask_cp.to_local(), + device_mesh, + placements=placements, + shape=_mask_global_shape, + stride=true_mask_unsqueezed.stride(), + ) + + return true_coords_dtensor, true_resolved_mask_dtensor diff --git a/src/boltz/distributed/data/module/inferencev2.py b/src/boltz/distributed/data/module/inferencev2.py new file mode 100644 index 000000000..f6a304d33 --- /dev/null +++ b/src/boltz/distributed/data/module/inferencev2.py @@ -0,0 +1,441 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# TODO: v2 features not yet supported for distributed sharding: +# - Constraint features (requires compute_constraint_features=True) +# - Template features +# - Affinity module features +# +# NOTE: The following features are produced by the featurizer but not consumed by the model. +# They are silently dropped during distribution and this is intentional: +# - token_to_center_atom: produced but unused by model +# - ensemble_ref_idxs: consumed during featurization only, unused by model + +import math +import warnings +from pathlib import Path +from typing import Dict, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.utils.data import DataLoader, DistributedSampler + +from boltz.data import const +from boltz.data.feature.featurizerv2 import Boltz2Featurizer +from boltz.data.module.inferencev2 import load_input +from boltz.data.mol import load_canonicals, load_molecules +from boltz.data.pad import pad_dim +from boltz.data.tokenize.boltz2 import Boltz2Tokenizer +from boltz.data.types import Manifest +from boltz.distributed.data.feature.featurizer import pad_and_scatter_atom_features_dtensor +from boltz.distributed.data.feature.featurizer_utils import get_pair_mask +from boltz.distributed.data.module.placements import INFERENCE_FEATURE_PLACEMENTS_V2 +from boltz.distributed.data.types import PairMaskMode +from boltz.distributed.data.utils import ( + ATOM_FEATURES_V2, + CollateDTensor, + distribute_features, + get_flattened_group, +) + + +class PredictionDatasetCPWithDTensorV2(torch.utils.data.Dataset): + """Prediction dataset with DTensor context parallelism for Boltz2.""" + + def __init__( + self, + manifest: Manifest, + target_dir: Path, + msa_dir: Path, + mol_dir: Path, + device_mesh: DeviceMesh, + device_mesh_cpu: DeviceMesh, + constraints_dir: Optional[Path] = None, + template_dir: Optional[Path] = None, + extra_mols_dir: Optional[Path] = None, + max_msa_seqs: int = const.max_msa_seqs, + msa_pad_to_max_seqs: bool = False, + max_data_retries: int = 5, + pair_mask_mode: PairMaskMode = PairMaskMode.NONE, + atoms_per_window_queries: Optional[int] = 32, + atoms_per_window_keys: Optional[int] = 128, + num_ensembles: int = 1, + per_shard_token_multiple: int = 1, + ) -> None: + super().__init__() + + self.manifest = manifest + self.target_dir = target_dir + self.msa_dir = msa_dir + self.mol_dir = mol_dir + + if constraints_dir is not None: + raise NotImplementedError("Constraints are not supported for CP") + + self.constraints_dir = constraints_dir + self.template_dir = template_dir + self.extra_mols_dir = extra_mols_dir + self.max_msa_seqs = max_msa_seqs + self.msa_pad_to_max_seqs = msa_pad_to_max_seqs + self.device_mesh = device_mesh + self.device_mesh_cpu = device_mesh_cpu + self.tokenizer = Boltz2Tokenizer() + self.featurizer = Boltz2Featurizer() + self.canonicals = load_canonicals(self.mol_dir) + self.max_data_retries = max_data_retries + self.num_ensembles = num_ensembles + self.per_shard_token_multiple = per_shard_token_multiple + + if (atoms_per_window_queries is None) != (atoms_per_window_keys is None): + raise ValueError( + "atoms_per_window_queries and atoms_per_window_keys must be either both None or both not None" + ) + if pair_mask_mode == PairMaskMode.SEQUENCE_LOCAL_ATTENTION and atoms_per_window_queries is None: + raise ValueError("atoms_per_window_queries must not be None if pair_mask_mode is SequenceLocalAttention") + self.pair_mask_mode = pair_mask_mode + self.atoms_per_window_queries = atoms_per_window_queries + self.atoms_per_window_keys = atoms_per_window_keys + + self._cp_submesh = device_mesh_cpu[("cp_axis_0_cpu", "cp_axis_1_cpu")] + self._cp_submesh_group = get_flattened_group(self._cp_submesh, backend="gloo") + self.is_cp_rank_zero = tuple(device_mesh.get_coordinate()[1:]) == (0, 0) + + n_shards_axis_0 = self.device_mesh.shape[1] + if self.max_msa_seqs % n_shards_axis_0 != 0: + if self.msa_pad_to_max_seqs is False: + warnings.warn( + f"Number CP ranks along process group grid axis 0 {n_shards_axis_0} is not " + f"a integer divisor of max_msa_seqs {self.max_msa_seqs}. Will modify max_msa_seqs " + f"to a multiple of {n_shards_axis_0} and pad the MSA number of sequences to it" + ) + self.msa_pad_to_max_seqs = True + + self.FEATURE_TO_DTENSOR_PLACEMENT = INFERENCE_FEATURE_PLACEMENTS_V2 + self._fallback_depth = 0 + + @property + def use_window_batching(self) -> bool: + return self.pair_mask_mode == PairMaskMode.NONE + + def _raise_or_return_item_0(self, e: Exception) -> None: + if self.max_data_retries <= 0: + raise e + if self._fallback_depth >= self.max_data_retries: + raise RuntimeError( + f"Data loading failed {self.max_data_retries} consecutive times. " f"Last error: {e}" + ) from e + self._fallback_depth += 1 + try: + fallback_idx = np.random.randint(0, len(self)) + return self.__getitem__(fallback_idx) + finally: + self._fallback_depth -= 1 + + def __getitem__(self, idx: int) -> Dict[str, DTensor]: + if self.is_cp_rank_zero: + record = self.manifest.records[idx] + + try: + input_data = load_input( + record, + self.target_dir, + self.msa_dir, + constraints_dir=self.constraints_dir, + template_dir=self.template_dir, + extra_mols_dir=self.extra_mols_dir, + ) + except Exception as e: # noqa: BLE001 + print(f"Data loading failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) + + try: + tokenized = self.tokenizer.tokenize(input_data) + except Exception as e: # noqa: BLE001 + print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) + + try: + molecules = {} + molecules.update(self.canonicals) + if input_data.extra_mols: + molecules.update(input_data.extra_mols) + mol_names = set(tokenized.tokens["res_name"].tolist()) + mol_names = mol_names - set(molecules.keys()) + molecules.update(load_molecules(self.mol_dir, mol_names)) + except Exception as e: # noqa: BLE001 + print(f"Molecule loading failed for {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) + + seed = 42 + random = np.random.default_rng(seed) + + # Pad dimensions to be divisible by the CP shard dimension (and atoms_per_window_queries for atoms) + n_shards_axis_0 = self.device_mesh.shape[1] + W = self.atoms_per_window_queries or 32 + + max_tokens = tokenized.tokens.shape[0] + max_seqs = self.max_msa_seqs + pad_to_max_seqs = self.msa_pad_to_max_seqs + + token_align = n_shards_axis_0 * self.per_shard_token_multiple + if max_tokens % token_align != 0: + max_tokens = ((max_tokens + token_align - 1) // token_align) * token_align + + if self.use_window_batching: + max_atoms = None + else: + max_atoms = int(np.sum(tokenized.tokens["atom_num"])) if len(tokenized.tokens) > 0 else 0 + # Must be divisible by both atoms_per_window_queries and n_shards_axis_0 + atom_align = math.lcm(W, n_shards_axis_0) + if max_atoms % atom_align != 0: + max_atoms = ((max_atoms + atom_align - 1) // atom_align) * atom_align + + if max_seqs % n_shards_axis_0 != 0: + max_seqs = max_seqs + n_shards_axis_0 - max_seqs % n_shards_axis_0 + + try: + features = self.featurizer.process( + tokenized, + molecules=molecules, + random=random, + training=False, + max_atoms=max_atoms, + max_tokens=max_tokens, + max_seqs=max_seqs, + pad_to_max_seqs=pad_to_max_seqs, + atoms_per_window_queries=self.atoms_per_window_queries, + num_ensembles=self.num_ensembles, + fix_single_ensemble=self.num_ensembles == 1, + compute_frames=True, + compute_constraint_features=False, + ) + # Distributed-specific features not produced by the base featurizer + mask = features["token_pad_mask"] + features["token_pair_pad_mask"] = mask[:, None] * mask[None, :] + if self.pair_mask_mode == PairMaskMode.GLOBAL_ATOM_ATTENTION: + N_atoms = len(features["ref_pos"]) + features["pair_mask"] = torch.ones(N_atoms, N_atoms, dtype=torch.float) + elif self.pair_mask_mode == PairMaskMode.SEQUENCE_LOCAL_ATTENTION: + features["pair_mask"] = get_pair_mask( + N_atoms=len(features["ref_pos"]), W=self.atoms_per_window_queries + ) + except Exception as e: # noqa: BLE001 + print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201 + return self._raise_or_return_item_0(e) + + if not pad_to_max_seqs: + num_seqs_actual = features["msa_mask"].shape[0] + target_seqs = max(1, num_seqs_actual) + if target_seqs % n_shards_axis_0 != 0: + target_seqs = target_seqs + n_shards_axis_0 - target_seqs % n_shards_axis_0 + if num_seqs_actual < target_seqs: + pad_len = target_seqs - num_seqs_actual + msa_feature_keys = ("msa", "msa_paired", "deletion_value", "has_deletion", "msa_mask") + for key in msa_feature_keys: + if key in features: + features[key] = ( + pad_dim(features[key], 0, pad_len, const.token_ids["-"]) + if key == "msa" + else pad_dim(features[key], 0, pad_len) + ) + + record_list = [record] + atom_placements = { + key: self.FEATURE_TO_DTENSOR_PLACEMENT[key] + for key in ATOM_FEATURES_V2 + if key in self.FEATURE_TO_DTENSOR_PLACEMENT + } + atom_features = {key: features[key] for key in features if key in atom_placements} + token_and_msa_features = { + key: features[key] + for key in features + if key not in ATOM_FEATURES_V2 and key in self.FEATURE_TO_DTENSOR_PLACEMENT + } + token_and_msa_placements = { + key: self.FEATURE_TO_DTENSOR_PLACEMENT[key] + for key in self.FEATURE_TO_DTENSOR_PLACEMENT + if key not in ATOM_FEATURES_V2 + } + + else: + features = None + record_list = [None] + atom_features = None + atom_placements = { + key: self.FEATURE_TO_DTENSOR_PLACEMENT[key] + for key in ATOM_FEATURES_V2 + if key in self.FEATURE_TO_DTENSOR_PLACEMENT + } + token_and_msa_features = None + token_and_msa_placements = { + key: self.FEATURE_TO_DTENSOR_PLACEMENT[key] + for key in self.FEATURE_TO_DTENSOR_PLACEMENT + if key not in ATOM_FEATURES_V2 + } + + if self.use_window_batching: + atom_placements.pop("pair_mask") + + cp_submesh = self._cp_submesh + cp_submesh_group = self._cp_submesh_group + cp_group_src_rank_global = min(torch.distributed.get_process_group_ranks(cp_submesh_group)) + + atom_features_dtensor = pad_and_scatter_atom_features_dtensor( + features=atom_features, + placements=atom_placements, + group=cp_submesh_group, + src_rank_global=cp_group_src_rank_global, + device_mesh=cp_submesh, + ) + + token_and_msa_features_dtensor = distribute_features( + features=token_and_msa_features, + placements=token_and_msa_placements, + group=cp_submesh_group, + src_rank_global=cp_group_src_rank_global, + device_mesh=cp_submesh, + ) + + features_dtensor = {**token_and_msa_features_dtensor, **atom_features_dtensor} + + torch.distributed.broadcast_object_list(record_list, src=cp_group_src_rank_global, group=cp_submesh_group) + features_dtensor["record"] = record_list[0] + + return features_dtensor + + def __len__(self) -> int: + return len(self.manifest.records) + + +class Boltz2InferenceDataModuleDTensor(pl.LightningDataModule): + """DataModule for Boltz2 distributed inference with DTensor CP.""" + + def __init__( + self, + manifest: Manifest, + target_dir: Path, + msa_dir: Path, + mol_dir: Path, + num_workers: int, + device_mesh: DeviceMesh, + device_mesh_cpu: DeviceMesh, + constraints_dir: Optional[Path] = None, + template_dir: Optional[Path] = None, + extra_mols_dir: Optional[Path] = None, + max_msa_seqs: int = const.max_msa_seqs, + msa_pad_to_max_seqs: bool = False, + max_data_retries: int = 5, + pair_mask_mode: PairMaskMode = PairMaskMode.NONE, + atoms_per_window_queries: int = 32, + atoms_per_window_keys: int = 128, + local_batch_size: int = 1, + num_ensembles: int = 1, + per_shard_token_multiple: int = 1, + ) -> None: + super().__init__() + if num_workers != 0: + raise NotImplementedError("num_workers != 0 is not supported for CP") + self.num_workers = num_workers + self.manifest = manifest + self.target_dir = target_dir + self.msa_dir = msa_dir + self.mol_dir = mol_dir + if constraints_dir is not None: + raise NotImplementedError("Constraints are not supported for CP") + self.constraints_dir = constraints_dir + self.template_dir = template_dir + self.extra_mols_dir = extra_mols_dir + self.max_msa_seqs = max_msa_seqs + self.msa_pad_to_max_seqs = msa_pad_to_max_seqs + self.device_mesh = device_mesh + self.device_mesh_cpu = device_mesh_cpu + self.max_data_retries = max_data_retries + self.pair_mask_mode = pair_mask_mode + self.atoms_per_window_queries = atoms_per_window_queries + self.atoms_per_window_keys = atoms_per_window_keys + self.dataset: Optional[PredictionDatasetCPWithDTensorV2] = None + self.local_batch_size = local_batch_size + self.num_ensembles = num_ensembles + self.per_shard_token_multiple = per_shard_token_multiple + + def setup(self, stage: Optional[str] = None) -> None: + if stage != "predict": + raise ValueError(f"Only predict stage is supported for inference but got {stage}") + + self.dataset = PredictionDatasetCPWithDTensorV2( + manifest=self.manifest, + target_dir=self.target_dir, + msa_dir=self.msa_dir, + mol_dir=self.mol_dir, + device_mesh=self.device_mesh, + device_mesh_cpu=self.device_mesh_cpu, + constraints_dir=self.constraints_dir, + template_dir=self.template_dir, + extra_mols_dir=self.extra_mols_dir, + max_msa_seqs=self.max_msa_seqs, + msa_pad_to_max_seqs=self.msa_pad_to_max_seqs, + max_data_retries=self.max_data_retries, + pair_mask_mode=self.pair_mask_mode, + num_ensembles=self.num_ensembles, + per_shard_token_multiple=self.per_shard_token_multiple, + ) + + def predict_dataloader(self) -> DataLoader: + sampler = DistributedSampler( + self.dataset, + num_replicas=self.device_mesh_cpu.shape[0], + rank=self.device_mesh_cpu.get_local_rank(0), + shuffle=False, + drop_last=False, + ) + custom_collate = CollateDTensor(self.device_mesh_cpu) + + return DataLoader( + self.dataset, + batch_size=self.local_batch_size, + num_workers=self.num_workers, + pin_memory=False, + shuffle=False, + collate_fn=custom_collate, + sampler=sampler, + ) + + def transfer_batch_to_device( + self, + batch: dict, + device: torch.device, + dataloader_idx: int, # noqa: ARG002 + ) -> dict: + for key in batch: + if key not in {"record"}: + batch_local = batch[key].to_local().to(device) + batch[key] = DTensor.from_local( + batch_local, + device_mesh=self.device_mesh, + placements=batch[key].placements, + shape=batch[key].shape, + stride=batch[key].stride(), + ) + + return batch diff --git a/src/boltz/distributed/data/module/placements.py b/src/boltz/distributed/data/module/placements.py new file mode 100644 index 000000000..f71983216 --- /dev/null +++ b/src/boltz/distributed/data/module/placements.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from boltz.distributed.data.utils import ( + PLACEMENT_TYPE_SHARD0_REPLICATE, + PLACEMENT_TYPE_SHARD0_SHARD1, + PLACEMENT_TYPE_SHARD1_REPLICATE, +) + +BASE_FEATURE_PLACEMENTS_V2: dict[str, tuple] = { + # Atom features + "ref_pos": PLACEMENT_TYPE_SHARD0_REPLICATE, + "ref_charge": PLACEMENT_TYPE_SHARD0_REPLICATE, + "atom_resolved_mask": PLACEMENT_TYPE_SHARD0_REPLICATE, + "ref_element": PLACEMENT_TYPE_SHARD0_REPLICATE, + "ref_atom_name_chars": PLACEMENT_TYPE_SHARD0_REPLICATE, + "ref_space_uid": PLACEMENT_TYPE_SHARD0_REPLICATE, + "coords": PLACEMENT_TYPE_SHARD1_REPLICATE, + "atom_counts_per_token": PLACEMENT_TYPE_SHARD0_REPLICATE, + "frame_resolved_mask": PLACEMENT_TYPE_SHARD1_REPLICATE, + "frames_idx": PLACEMENT_TYPE_SHARD1_REPLICATE, + "atom_pad_mask": PLACEMENT_TYPE_SHARD0_REPLICATE, + "atom_to_token": PLACEMENT_TYPE_SHARD0_REPLICATE, + "token_to_rep_atom": PLACEMENT_TYPE_SHARD0_REPLICATE, + "r_set_to_rep_atom": PLACEMENT_TYPE_SHARD0_REPLICATE, + "ref_chirality": PLACEMENT_TYPE_SHARD0_REPLICATE, + "atom_backbone_feat": PLACEMENT_TYPE_SHARD0_REPLICATE, + "bfactor": PLACEMENT_TYPE_SHARD0_REPLICATE, + "plddt": PLACEMENT_TYPE_SHARD0_REPLICATE, + # Token features + "token_index": PLACEMENT_TYPE_SHARD0_REPLICATE, + "residue_index": PLACEMENT_TYPE_SHARD0_REPLICATE, + "asym_id": PLACEMENT_TYPE_SHARD0_REPLICATE, + "entity_id": PLACEMENT_TYPE_SHARD0_REPLICATE, + "sym_id": PLACEMENT_TYPE_SHARD0_REPLICATE, + "mol_type": PLACEMENT_TYPE_SHARD0_REPLICATE, + "res_type": PLACEMENT_TYPE_SHARD0_REPLICATE, + "disto_center": PLACEMENT_TYPE_SHARD0_REPLICATE, + "disto_target": PLACEMENT_TYPE_SHARD0_SHARD1, + "disto_coords_ensemble": PLACEMENT_TYPE_SHARD1_REPLICATE, + "token_bonds": PLACEMENT_TYPE_SHARD0_SHARD1, + "type_bonds": PLACEMENT_TYPE_SHARD0_SHARD1, + "token_pad_mask": PLACEMENT_TYPE_SHARD0_REPLICATE, + "token_resolved_mask": PLACEMENT_TYPE_SHARD0_REPLICATE, + "token_disto_mask": PLACEMENT_TYPE_SHARD0_REPLICATE, + "token_pair_pad_mask": PLACEMENT_TYPE_SHARD0_SHARD1, + "pair_mask": PLACEMENT_TYPE_SHARD0_SHARD1, + "contact_conditioning": PLACEMENT_TYPE_SHARD0_SHARD1, + "contact_threshold": PLACEMENT_TYPE_SHARD0_SHARD1, + "cyclic_period": PLACEMENT_TYPE_SHARD0_REPLICATE, + "method_feature": PLACEMENT_TYPE_SHARD0_REPLICATE, + "modified": PLACEMENT_TYPE_SHARD0_REPLICATE, + # MSA features + "msa": PLACEMENT_TYPE_SHARD0_SHARD1, + "msa_paired": PLACEMENT_TYPE_SHARD0_SHARD1, + "deletion_value": PLACEMENT_TYPE_SHARD0_SHARD1, + "has_deletion": PLACEMENT_TYPE_SHARD0_SHARD1, + "msa_mask": PLACEMENT_TYPE_SHARD0_SHARD1, + "deletion_mean": PLACEMENT_TYPE_SHARD0_REPLICATE, + "profile": PLACEMENT_TYPE_SHARD0_REPLICATE, +} + +TRAINING_FEATURE_PLACEMENTS_V2: dict[str, tuple] = { + **BASE_FEATURE_PLACEMENTS_V2, + "temp_feature": PLACEMENT_TYPE_SHARD0_REPLICATE, + "ph_feature": PLACEMENT_TYPE_SHARD0_REPLICATE, +} + +INFERENCE_FEATURE_PLACEMENTS_V2: dict[str, tuple] = { + **BASE_FEATURE_PLACEMENTS_V2, + "affinity_token_mask": PLACEMENT_TYPE_SHARD0_REPLICATE, +} diff --git a/src/boltz/distributed/data/module/trainingv2.py b/src/boltz/distributed/data/module/trainingv2.py new file mode 100644 index 000000000..40188ef93 --- /dev/null +++ b/src/boltz/distributed/data/module/trainingv2.py @@ -0,0 +1,519 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import Any, Optional + +import pytorch_lightning as pl +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Shard +from torch.utils.data import DataLoader, DistributedSampler + +from boltz.data import const +from boltz.data.module.trainingv2 import ( + Boltz2TrainingDataModule as Boltz2TrainingDataModuleSerial, +) +from boltz.data.module.trainingv2 import DataConfigV2 +from boltz.data.pad import pad_dim +from boltz.distributed.data.feature.featurizer import pad_and_scatter_atom_features_dtensor +from boltz.distributed.data.module.placements import TRAINING_FEATURE_PLACEMENTS_V2 +from boltz.distributed.data.utils import ( + ATOM_FEATURES_V2, + NON_SHARDED_FEATURES_V2, + CollateDTensor, + distribute_features, + get_flattened_group, +) + + +class _BaseDatasetCPWithDTensorV2(torch.utils.data.Dataset): + """Wrap a serial Boltz2 dataset and distribute features as DTensors.""" + + def __init__( + self, + serial_dataset: torch.utils.data.Dataset, + device_mesh: DeviceMesh, + device_mesh_cpu: DeviceMesh, + ) -> None: + """Initialize the distributed dataset wrapper. + + Wraps a serial Boltz2 dataset and sets up the device mesh and placement + mapping used to distribute features as DTensors across context-parallel + ranks. + + Parameters + ---------- + serial_dataset : torch.utils.data.Dataset + The serial (single-rank) Boltz2 dataset to wrap. + device_mesh : DeviceMesh + Device mesh for distributed tensor operations on GPU. + device_mesh_cpu : DeviceMesh + Device mesh for distributed tensor operations on CPU, used for + data-loading collectives that run before GPU transfer. + + """ + super().__init__() + self.serial_dataset = serial_dataset + self.device_mesh = device_mesh + self.device_mesh_cpu = device_mesh_cpu + self._cp_submesh = device_mesh_cpu[("cp_axis_0_cpu", "cp_axis_1_cpu")] + self._cp_submesh_group = get_flattened_group(self._cp_submesh, backend="gloo") + self.is_cp_rank_zero = tuple(device_mesh.get_coordinate()[1:]) == (0, 0) + + self.feature_to_dtensor_placement = TRAINING_FEATURE_PLACEMENTS_V2 + + def __len__(self) -> int: + """Get the length of the dataset. + + Returns + ------- + int + The number of samples in the underlying serial dataset. + + """ + return len(self.serial_dataset) + + def _distribute_features(self, features: Optional[dict[str, Any]]) -> dict[str, Any]: + """Distribute serial features as DTensors across context-parallel ranks. + + CP rank zero pads tensor features to be evenly shardable, then broadcasts + tensor keys and non-sharded metadata to all CP ranks. Atom features are + distributed via ``pad_and_scatter_atom_features_dtensor`` and token/MSA + features via ``distribute_features``. + + Parameters + ---------- + features : dict[str, Any] or None + Feature dictionary produced by the serial dataset on CP rank zero. + Must be ``None`` on non-zero CP ranks. + + Returns + ------- + dict[str, Any] + Feature dictionary where tensor values are DTensors distributed + according to ``self.feature_to_dtensor_placement``, and non-sharded + values are broadcast copies. + + """ + cp_group_src_rank_global = min(torch.distributed.get_process_group_ranks(self._cp_submesh_group)) + + if self.is_cp_rank_zero: + # Synthesize token_pair_pad_mask if the serial featurizer did not + # produce it. The distributed model forward requires this feature + # (boltz2.py, trunkv2.py) but the serial v2 training featurizer + # does not generate it — only the inference featurizer does. + if "token_pair_pad_mask" not in features and "token_pad_mask" in features: + mask = features["token_pad_mask"] + features["token_pair_pad_mask"] = mask[:, None] * mask[None, :] + + unknown_tensor_keys = sorted( + key + for key, value in features.items() + if isinstance(value, torch.Tensor) + and key not in self.feature_to_dtensor_placement + and key not in NON_SHARDED_FEATURES_V2 + ) + if unknown_tensor_keys: + raise KeyError( + "Found tensor feature keys without DTensor placement mapping. " + f"Please add placements for: {unknown_tensor_keys}" + ) + + tensor_features_all = { + key: value + for key, value in features.items() + if isinstance(value, torch.Tensor) + and key in self.feature_to_dtensor_placement + and key not in NON_SHARDED_FEATURES_V2 + } + + n_shards_axis_0 = self.device_mesh.shape[1] + for key, tensor in tensor_features_all.items(): + placements = self.feature_to_dtensor_placement[key] + padded = tensor + for placement in placements: + if not isinstance(placement, Shard): + continue + shard_dim = placement.dim + if shard_dim >= padded.ndim: + continue + remainder = padded.shape[shard_dim] % n_shards_axis_0 + if remainder == 0: + continue + pad_len = n_shards_axis_0 - remainder + pad_value = const.token_ids["-"] if key == "msa" and shard_dim == 0 else 0 + padded = pad_dim(padded, shard_dim, pad_len, pad_value) + tensor_features_all[key] = padded + tensor_feature_keys = sorted(tensor_features_all.keys()) + keys_payload = [tensor_feature_keys] + torch.distributed.broadcast_object_list( + keys_payload, + src=cp_group_src_rank_global, + group=self._cp_submesh_group, + ) + tensor_feature_keys_shared = keys_payload[0] + + atom_placements = { + key: self.feature_to_dtensor_placement[key] + for key in ATOM_FEATURES_V2 + if key in tensor_feature_keys_shared + } + atom_features = {key: tensor_features_all[key] for key in atom_placements} + token_and_msa_features = { + key: value for key, value in tensor_features_all.items() if key not in ATOM_FEATURES_V2 + } + token_and_msa_placements = {key: self.feature_to_dtensor_placement[key] for key in token_and_msa_features} + non_sharded_features = {key: value for key, value in features.items() if key in NON_SHARDED_FEATURES_V2} + object_payload = [non_sharded_features] + else: + keys_payload = [None] + torch.distributed.broadcast_object_list( + keys_payload, + src=cp_group_src_rank_global, + group=self._cp_submesh_group, + ) + tensor_feature_keys_shared = keys_payload[0] + atom_features = None + token_and_msa_features = None + atom_placements = { + key: self.feature_to_dtensor_placement[key] + for key in ATOM_FEATURES_V2 + if key in tensor_feature_keys_shared + } + token_and_msa_placements = { + key: placement + for key, placement in self.feature_to_dtensor_placement.items() + if key in tensor_feature_keys_shared and key not in ATOM_FEATURES_V2 + } + object_payload = [None] + + atom_features_dtensor = pad_and_scatter_atom_features_dtensor( + features=atom_features, + placements=atom_placements, + group=self._cp_submesh_group, + src_rank_global=cp_group_src_rank_global, + device_mesh=self._cp_submesh, + ) + token_and_msa_features_dtensor = distribute_features( + features=token_and_msa_features, + placements=token_and_msa_placements, + group=self._cp_submesh_group, + src_rank_global=cp_group_src_rank_global, + device_mesh=self._cp_submesh, + ) + + torch.distributed.broadcast_object_list( + object_payload, + src=cp_group_src_rank_global, + group=self._cp_submesh_group, + ) + + features_dtensor = {**token_and_msa_features_dtensor, **atom_features_dtensor} + features_dtensor.update(object_payload[0] or {}) + return features_dtensor + + +class TrainingDatasetCPWithDTensorV2(_BaseDatasetCPWithDTensorV2): + """Training dataset with DTensor context parallelism for Boltz2.""" + + def __getitem__(self, idx: int) -> dict[str, Any]: + """Fetch and distribute a single training sample. + + CP rank zero retrieves the sample from the serial dataset; all other + CP ranks receive the distributed DTensor features via collectives. + + Parameters + ---------- + idx : int + Sample index in the serial dataset. + + Returns + ------- + dict[str, Any] + Distributed feature dictionary with DTensor values. + + """ + features = self.serial_dataset[idx] if self.is_cp_rank_zero else None + return self._distribute_features(features) + + +class ValidationDatasetCPWithDTensorV2(_BaseDatasetCPWithDTensorV2): + """Validation dataset with DTensor context parallelism for Boltz2.""" + + def __init__( + self, + serial_dataset: torch.utils.data.Dataset, + device_mesh: DeviceMesh, + device_mesh_cpu: DeviceMesh, + val_skip_sample_threshold_tokens: Optional[int] = None, + val_skip_sample_threshold_atoms: Optional[int] = None, + val_skip_sample_threshold_seqs: Optional[int] = None, + ) -> None: + """Initialize the distributed validation dataset. + + Parameters + ---------- + serial_dataset : torch.utils.data.Dataset + The serial (single-rank) Boltz2 validation dataset to wrap. + device_mesh : DeviceMesh + Device mesh for distributed tensor operations on GPU. + device_mesh_cpu : DeviceMesh + Device mesh for distributed tensor operations on CPU. + val_skip_sample_threshold_tokens : int, optional + Skip validation samples with more tokens than this threshold to + prevent OOM. + val_skip_sample_threshold_atoms : int, optional + Skip validation samples with more atoms than this threshold to + prevent OOM. + val_skip_sample_threshold_seqs : int, optional + Skip validation samples with more MSA sequences than this threshold + to prevent OOM. + + """ + super().__init__(serial_dataset=serial_dataset, device_mesh=device_mesh, device_mesh_cpu=device_mesh_cpu) + self.val_skip_sample_threshold_tokens = val_skip_sample_threshold_tokens + self.val_skip_sample_threshold_atoms = val_skip_sample_threshold_atoms + self.val_skip_sample_threshold_seqs = val_skip_sample_threshold_seqs + + def __getitem__(self, idx: int) -> dict[str, Any]: + """Fetch and distribute a single validation sample. + + On CP rank zero, iterates from ``idx`` through the dataset looking for + a sample that satisfies all ``val_skip_sample_threshold_*`` constraints. + If no sample passes, raises ``RuntimeError``. Other CP ranks receive + the distributed DTensor features via collectives. + + Parameters + ---------- + idx : int + Starting sample index in the serial dataset. + + Returns + ------- + dict[str, Any] + Distributed feature dictionary with DTensor values. + + Raises + ------ + RuntimeError + If every sample in the dataset is filtered out by the thresholds. + + """ + if self.is_cp_rank_zero: + num_items = len(self.serial_dataset) + for shift in range(num_items): + curr_idx = (idx + shift) % num_items + features = self.serial_dataset[curr_idx] + + if self.val_skip_sample_threshold_tokens is not None and "token_pad_mask" in features: + tokens = int(features["token_pad_mask"].sum().item()) + if tokens > self.val_skip_sample_threshold_tokens: + continue + if self.val_skip_sample_threshold_atoms is not None and "atom_pad_mask" in features: + atoms = int(features["atom_pad_mask"].sum().item()) + if atoms > self.val_skip_sample_threshold_atoms: + continue + if self.val_skip_sample_threshold_seqs is not None and "msa_mask" in features: + msa_mask = features["msa_mask"] + seqs = int((msa_mask.sum(dim=1) > 0).sum().item()) if msa_mask.ndim == 2 else int(msa_mask.shape[0]) + if seqs > self.val_skip_sample_threshold_seqs: + continue + break + else: + raise RuntimeError("All validation samples were filtered out by val_skip_sample_threshold_*") + else: + features = None + + return self._distribute_features(features) + + +class Boltz2TrainingDataModule(pl.LightningDataModule): + """DataModule for Boltz2 distributed training with DTensor CP.""" + + def __init__( + self, + cfg: DataConfigV2, + device_mesh: DeviceMesh, + device_mesh_cpu: DeviceMesh, + ) -> None: + """Initialize the distributed training data module. + + Wraps the serial ``Boltz2TrainingDataModule`` and its datasets with + DTensor context-parallel distribution. Internally creates + :class:`TrainingDatasetCPWithDTensorV2` and + :class:`ValidationDatasetCPWithDTensorV2` from the serial module's + datasets. + + Parameters + ---------- + cfg : DataConfigV2 + The data configuration. + device_mesh : DeviceMesh + Device mesh for distributed tensor operations on GPU. + device_mesh_cpu : DeviceMesh + Device mesh for distributed tensor operations on CPU. + + Raises + ------ + NotImplementedError + If ``cfg.num_workers != 0``, since multi-worker loading is + incompatible with DTensor CP collectives in the dataset. + + """ + super().__init__() + if cfg.num_workers != 0: + raise NotImplementedError("num_workers != 0 is not supported for CP") + + self.cfg = cfg + self.device_mesh = device_mesh + self.device_mesh_cpu = device_mesh_cpu + self._serial_module = Boltz2TrainingDataModuleSerial(cfg=cfg) + self.val_group_mapper = self._serial_module.val_group_mapper + + self._train_set = TrainingDatasetCPWithDTensorV2( + serial_dataset=self._serial_module._train_set, + device_mesh=self.device_mesh, + device_mesh_cpu=self.device_mesh_cpu, + ) + self._val_set = ValidationDatasetCPWithDTensorV2( + serial_dataset=self._serial_module._val_set, + device_mesh=self.device_mesh, + device_mesh_cpu=self.device_mesh_cpu, + val_skip_sample_threshold_tokens=cfg.val_skip_sample_threshold_tokens, + val_skip_sample_threshold_atoms=cfg.val_skip_sample_threshold_atoms, + val_skip_sample_threshold_seqs=cfg.val_skip_sample_threshold_seqs, + ) + + def setup(self, stage: Optional[str] = None) -> None: # noqa: ARG002 + """Run the setup for the DataModule. + + No-op because the serial module and CP-wrapped datasets are fully + initialized in ``__init__``. + + Parameters + ---------- + stage : str, optional + The stage, one of 'fit', 'validate', 'test'. + + """ + return + + def train_dataloader(self) -> DataLoader: + """Get the training dataloader. + + Returns + ------- + DataLoader + The training dataloader with a ``DistributedSampler`` partitioned + across data-parallel replicas and a ``CollateDTensor`` collate + function. + + """ + sampler = DistributedSampler( + self._train_set, + num_replicas=self.device_mesh_cpu.shape[0], + rank=self.device_mesh_cpu.get_local_rank(0), + shuffle=False, + drop_last=False, + ) + custom_collate = CollateDTensor(self.device_mesh_cpu) + return DataLoader( + self._train_set, + sampler=sampler, + batch_size=self.cfg.batch_size, + num_workers=self.cfg.num_workers, + pin_memory=self.cfg.pin_memory, + shuffle=False, + collate_fn=custom_collate, + ) + + def val_dataloader(self) -> DataLoader: + """Get the validation dataloader. + + Returns + ------- + DataLoader + The validation dataloader with a ``DistributedSampler`` partitioned + across data-parallel replicas and a ``CollateDTensor`` collate + function. + + """ + sampler = DistributedSampler( + self._val_set, + num_replicas=self.device_mesh_cpu.shape[0], + rank=self.device_mesh_cpu.get_local_rank(0), + shuffle=False, + drop_last=False, + ) + custom_collate = CollateDTensor(self.device_mesh_cpu) + return DataLoader( + self._val_set, + sampler=sampler, + batch_size=self.cfg.val_batch_size, + num_workers=self.cfg.num_workers, + pin_memory=self.cfg.pin_memory, + shuffle=False, + collate_fn=custom_collate, + ) + + def transfer_batch_to_device( + self, + batch: dict, + device: torch.device, + dataloader_idx: int, # noqa: ARG002 + ) -> dict: + """Transfer a batch from CPU DTensors to the target device. + + DTensor values are moved by extracting the local shard, transferring it + to ``device``, and re-wrapping with the GPU ``device_mesh``. Plain + tensors and lists of tensors are transferred directly. + + Parameters + ---------- + batch : dict + The batch to transfer. + device : torch.device + The target device (typically a CUDA device). + dataloader_idx : int + The dataloader index (unused). + + Returns + ------- + dict + The batch with all tensor values on ``device``. + + """ + for key, value in batch.items(): + if isinstance(value, DTensor): + batch_local = value.to_local().to(device) + batch[key] = DTensor.from_local( + batch_local, + device_mesh=self.device_mesh, + placements=value.placements, + shape=value.shape, + stride=value.stride(), + ) + elif isinstance(value, list): + batch[key] = [item.to(device) if isinstance(item, torch.Tensor) else item for item in value] + elif isinstance(value, torch.Tensor): + batch[key] = value.to(device) + + return batch diff --git a/src/boltz/distributed/data/types.py b/src/boltz/distributed/data/types.py new file mode 100644 index 000000000..54fffad66 --- /dev/null +++ b/src/boltz/distributed/data/types.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from enum import Enum + + +class PairMaskMode(Enum): + """Controls pair mask generation for atom features.""" + + NONE = "None" + GLOBAL_ATOM_ATTENTION = "GlobalAtomAttention" + SEQUENCE_LOCAL_ATTENTION = "SequenceLocalAttention" diff --git a/src/boltz/distributed/data/utils.py b/src/boltz/distributed/data/utils.py new file mode 100644 index 000000000..42ea7c659 --- /dev/null +++ b/src/boltz/distributed/data/utils.py @@ -0,0 +1,646 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from collections import OrderedDict +from typing import NamedTuple, Optional + +import torch +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.types import Tensor + +from boltz.data.pad import pad_to_max +from boltz.distributed.data.feature.featurizer_utils import ATOM_INDEX_FEATURES, remap_atom_indices_repad +from boltz.distributed.manager import DistributedManager +from boltz.distributed.utils import LayoutRightMap + +# Placement types used inside the DataSet class which only sees the NxN CP group +PLACEMENT_TYPE_SHARD0_REPLICATE = (Shard(0), Replicate()) # For single repr +PLACEMENT_TYPE_SHARD0_SHARD1 = (Shard(0), Shard(1)) # For pair repr +PLACEMENT_TYPE_SHARD1_REPLICATE = (Shard(1), Replicate()) # For ensemble-aware features + + +ATOM_FEATURES = { + "ref_pos", + "atom_resolved_mask", + "ref_element", + "ref_charge", + "ref_atom_name_chars", + "ref_space_uid", + "coords", + "atom_pad_mask", + "atom_to_token", + "atom_counts_per_token", + "pair_mask", + "token_to_rep_atom", + "r_set_to_rep_atom", + "frames_idx", +} + +ATOM_FEATURES_V2 = ATOM_FEATURES | { + "ref_chirality", + "atom_backbone_feat", + "bfactor", + "plddt", +} + +SYMMETRY_FEATURES = { + "all_coords", + "all_resolved_mask", + "crop_to_all_atom_map", + "chain_symmetries", + "amino_acids_symmetries", + "ligand_symmetries", +} +SYMMETRY_FEATURES_V2 = SYMMETRY_FEATURES | { + "chain_swaps", +} + +RESIDUE_CONSTRAINT_FEATURES = { + "rdkit_bounds_index", + "rdkit_bounds_bond_mask", + "rdkit_bounds_angle_mask", + "rdkit_upper_bounds", + "rdkit_lower_bounds", + "chiral_atom_index", + "chiral_reference_mask", + "chiral_atom_orientations", + "stereo_bond_index", + "stereo_reference_mask", + "stereo_bond_orientations", + "planar_bond_index", + "planar_ring_5_index", + "planar_ring_6_index", +} + +CHAIN_CONSTRAINT_FEATURES = { + "connected_chain_index", + "connected_atom_index", + "symmetric_chain_index", +} + +CONTACT_CONSTRAINT_FEATURES = { + "contact_pair_index", + "contact_union_index", + "contact_negation_mask", + "contact_thresholds", +} + +NON_SHARDED_FEATURES = {"record"} | SYMMETRY_FEATURES | RESIDUE_CONSTRAINT_FEATURES | CHAIN_CONSTRAINT_FEATURES +TRAINING_METADATA_FEATURES = { + "chain_swaps", + "activity_name", + "activity_qualifier", + "idx_dataset", + "sid", + "cid", + "normalized_protein_accession", + "pair_id", + "pdb_id", +} + +LIGAND_GEOMETRY_FEATURES = { + "ligand_edge_index", + "ligand_edge_lower_bounds", + "ligand_edge_upper_bounds", + "ligand_edge_bond_mask", + "ligand_edge_angle_mask", + "connections_edge_index", + "ligand_chiral_atom_index", + "ligand_chiral_check_mask", + "ligand_chiral_atom_orientations", + "ligand_stereo_bond_index", + "ligand_stereo_check_mask", + "ligand_stereo_bond_orientations", + "ligand_aromatic_5_ring_index", + "ligand_aromatic_6_ring_index", + "ligand_planar_double_bond_index", +} + +NON_SHARDED_FEATURES_V2 = ( + {"record"} + | SYMMETRY_FEATURES_V2 + | RESIDUE_CONSTRAINT_FEATURES + | CHAIN_CONSTRAINT_FEATURES + | { + "affinity_mw", + "ensemble_ref_idxs", + "template_force", + "template_force_threshold", + } + | CONTACT_CONSTRAINT_FEATURES + | TRAINING_METADATA_FEATURES + | LIGAND_GEOMETRY_FEATURES +) + + +class TensorMetadata(NamedTuple): + dtype: torch.dtype + shape: torch.Size + + +def broadcast_feature_tensors_metadata( + features: dict[str, Tensor] | None, group: ProcessGroup, src_rank_global: int +) -> OrderedDict[str, tuple[torch.dtype, torch.Size]]: + """Broadcast tensor metadata from source rank to all ranks in a process group. + + This function extracts metadata (dtype and shape) from feature tensors on the source + rank, broadcasts this information to all other ranks in the process group, and returns + it as an OrderedDict to preserve feature ordering consistency across ranks. This allows + non-source ranks to know the structure of tensors without having the actual tensor data, + and ensures all ranks iterate over features in the same order. + + Parameters + ---------- + features : dict[str, Tensor] | None + Dictionary mapping feature names to tensors. Must be a dict on the source rank + and None on all other ranks. Feature names can include random prefixes to test + ordering consistency across distributed ranks. + group : ProcessGroup + The distributed process group to broadcast within. + src_rank_global : int + The global rank that serves as the source for broadcasting metadata. + This rank must be included in the process group. + + Returns + ------- + OrderedDict[str, tuple[torch.dtype, torch.Size]] + Ordered dictionary mapping feature names to tuples of (dtype, shape) where dtype + is the PyTorch data type and shape is the tensor's size. The order is preserved + from the source rank to ensure consistent iteration order across all ranks. + + Raises + ------ + ValueError + If source rank doesn't provide a dict input, non-source ranks provide non-None + input, or feature values are not tensors. + + Notes + ----- + - Returns an OrderedDict to guarantee consistent feature iteration order across ranks + - This is critical for distributed processing where all ranks must process features + in the same sequence to maintain synchronization + - The metadata includes TensorMetadata named tuples containing dtype and shape info + - Feature names may include random prefixes for testing ordering robustness + """ + rank_global = torch.distributed.get_rank() + is_src_rank = rank_global == src_rank_global + + # Check that src_rank_global is in the process group + ranks_in_group = torch.distributed.get_process_group_ranks(group) + if src_rank_global not in ranks_in_group: + raise ValueError(f"Source rank {src_rank_global} not in group {ranks_in_group}") + + if is_src_rank and not isinstance(features, dict): + raise ValueError(f"Source rank {src_rank_global} must have a dict input, got {type(features)}") + elif not is_src_rank and features is not None: + raise ValueError(f"Non-source rank (not {src_rank_global}) must have None input, this rank is {rank_global}") + if is_src_rank: + metadata = OrderedDict() + for k, v in features.items(): + if not isinstance(v, Tensor): + raise ValueError(f"Feature {k} is not a tensor, got {type(v)}") + metadata[k] = TensorMetadata(dtype=v.dtype, shape=v.shape) + else: + metadata = None + l_metadata = [metadata] + torch.distributed.broadcast_object_list(l_metadata, src=src_rank_global, group=group) + metadata = l_metadata[0] + return metadata + + +def distribute_features( + features: dict[str, Tensor], + placements: dict[str, Placement], + group: ProcessGroup, + src_rank_global: int, + device_mesh: DeviceMesh, +) -> dict[str, DTensor]: + """Distribute feature tensors from source rank to all ranks as DTensors. + + This function takes feature tensors from a source rank and distributes them across + a device mesh according to specified placements. The metadata (dtype and shape) is + first broadcast to all ranks using broadcast_feature_tensors_metadata (which preserves + ordering), then each rank creates or uses existing tensors to form DTensors with the + specified placements. All ranks process features in the same order for consistency. + + Parameters + ---------- + features : dict[str, Tensor] + Dictionary mapping feature names to tensors. Must contain tensors on the source + rank and can be None on non-source ranks (only validated on source rank). + Feature names may include random prefixes for testing ordering robustness. + placements : dict[str, Placement] + Dictionary mapping feature names to their desired tensor placements (e.g., Shard, + Replicate). Must have the same keys as features on the source rank. + group : ProcessGroup + The distributed process group that contains all ranks in the device mesh. + src_rank_global : int + The global rank that serves as the source for the original tensors. + This rank must be included in the process group. + device_mesh : DeviceMesh + The device mesh defining the distributed tensor layout. Its ranks must match + the ranks in the process group. + + Returns + ------- + dict[str, DTensor] + Dictionary mapping feature names to distributed tensors (DTensors) with the + specified placements across the device mesh. While returned as a regular dict, + the internal processing ensures consistent ordering across ranks. + + Raises + ------ + ValueError + If features and placements don't have the same keys (validated only on source rank), + if the ranks in the process group don't match the ranks in the device mesh, + or if the source rank is not in the process group. + + Notes + ----- + - Uses broadcast_feature_tensors_metadata to ensure consistent feature ordering across ranks + - Non-source ranks create empty tensors with the correct dtype and shape from metadata + - Each tensor is distributed using PyTorch's distribute_tensor function with specified placements + - The resulting DTensors are ready for distributed computation with consistent sharding + - Order preservation is critical for synchronization during distributed training + - Returns regular dict for backward compatibility, but internal ordering is guaranteed + + Examples + -------- + >>> features = {"abc123_feature_0": torch.randn(10, 20)} # Only on source rank + >>> placements = {"abc123_feature_0": Shard(0)} + >>> dtensors = distribute_features(features, placements, group, 0, mesh) + """ + rank_global = torch.distributed.get_rank() + is_src_rank = rank_global == src_rank_global + + if is_src_rank and features.keys() != placements.keys(): + raise ValueError( + f"Features and placements must have the same keys, got {sorted(features.keys())} and {sorted(placements.keys())}" + ) + ranks_in_group = torch.distributed.get_process_group_ranks(group) + ranks_in_mesh = device_mesh.mesh.flatten().tolist() + if ranks_in_group != ranks_in_mesh: + raise ValueError( + f"Ranks in group {ranks_in_group} do not match ranks in mesh {ranks_in_mesh}, got {ranks_in_group} and {ranks_in_mesh}" + ) + + if src_rank_global not in ranks_in_group: + raise ValueError(f"Source rank {src_rank_global} not in group {ranks_in_group}") + + # this guarantees only src_rank_global have the features so the returned metadata is ordered by the iteration order + # of the features dictionary in the src_rank_global. This is important for the later for loop to iterate over the features + # in the same order among all ranks + metadata: OrderedDict[str, TensorMetadata] = broadcast_feature_tensors_metadata(features, group, src_rank_global) + + # on the other hand, to stay backward compatible with other usage in the data processing code, we return + # regular dict and it's up to the caller to make sure the iteration order is consistent among ranks + ans = {} + for name_feature, m in metadata.items(): + dtype, shape = m + placement = placements[name_feature] + if is_src_rank: + t = features[name_feature].to(device=device_mesh.device_type) + else: + t = torch.empty(shape, dtype=dtype, device=device_mesh.device_type) + ans[name_feature] = distribute_tensor(t, device_mesh, placements=placement) + return ans + + +def broadcast_tensors( + features: dict[str, Tensor] | None, + group: ProcessGroup, + src_rank_global: int, + device: str = "cpu", +) -> OrderedDict[str, Tensor]: + """Broadcast tensors from source rank to all ranks in a process group. + + This function broadcasts tensor data from the source rank to all other ranks + in the process group. The metadata (dtype and shape) is first broadcast using + broadcast_feature_tensors_metadata, then each tensor is broadcast individually. + All ranks process features in the same order (determined by the source rank's + iteration order) to ensure consistent communication and avoid deadlocks. + + Parameters + ---------- + features : dict[str, Tensor] | None + Dictionary mapping feature names to tensors. Must be a dict containing tensors + on the source rank and None on all other ranks. + group : ProcessGroup + The distributed process group to broadcast within. + src_rank_global : int + The global rank that serves as the source for broadcasting. + This rank must be included in the process group. + device : str + Device to use for broadcasting. For NCCL backend, this should be "cuda". + For gloo backend, this can be "cpu". Default is "cpu". + + Returns + ------- + OrderedDict[str, Tensor] + Ordered dictionary mapping feature names to broadcast tensors on the specified + device. The order is preserved from the source rank to ensure consistent + iteration order across all ranks. + + Raises + ------ + ValueError + If source rank doesn't provide a dict input, non-source ranks provide non-None + input, or feature values are not tensors (raised by broadcast_feature_tensors_metadata). + + Notes + ----- + - Uses broadcast_feature_tensors_metadata to ensure consistent feature ordering across ranks + - Non-source ranks create empty tensors with the correct dtype and shape from metadata + - Each tensor is broadcast using torch.distributed.broadcast + - Output tensors are on the specified device + - Order preservation is critical for synchronization during distributed processing + + Examples + -------- + >>> # On source rank (rank 0) + >>> features = {"coords": torch.randn(10, 3), "mask": torch.ones(10)} + >>> result = broadcast_tensors(features, group, src_rank_global=0, device="cuda") + >>> # On non-source ranks + >>> result = broadcast_tensors(None, group, src_rank_global=0, device="cuda") + >>> # All ranks now have the same tensors in result + """ + rank_global = torch.distributed.get_rank() + is_src_rank = rank_global == src_rank_global + + # broadcast_feature_tensors_metadata validates inputs and returns an OrderedDict + # with consistent ordering across all ranks (based on source rank's iteration order) + metadata: OrderedDict[str, TensorMetadata] = broadcast_feature_tensors_metadata(features, group, src_rank_global) + + # Iterate over metadata in consistent order across all ranks + ans = OrderedDict() + for name_feature, m in metadata.items(): + dtype, shape = m + if is_src_rank: + tensor = features[name_feature].to(device=device).contiguous() + else: + tensor = torch.empty(shape, dtype=dtype, device=device) + torch.distributed.broadcast(tensor, src=src_rank_global, group=group) + ans[name_feature] = tensor + + return ans + + +class CollateDTensor: + def __init__(self, output_device_mesh: DeviceMesh): + # Check that shape is like (dp, cp_axis_0, cp_axis_1) + if output_device_mesh.ndim != 3: + raise ValueError(f"CollateDTensor expects a DP-CP-CP device mesh but got ndim {output_device_mesh.ndim}") + + self._output_device_mesh = output_device_mesh + + def __call__(self, data: list[dict[str, DTensor]]) -> dict[str, DTensor]: + """Collate the data. + + Parameters + ---------- + data : List[Dict[str, DTensor]] + The data to collate. + + Returns + ------- + Dict[str, DTensor] + The collated data. + + """ + # Get the keys + keys = data[0].keys() + + # Pre-scan: determine final atom dim for atom-index remapping. + # Atom-index features store indices into a padded global atom array + # whose stride is max_atoms_per_shard. When collation pads the atom + # dimension (because samples differ or DP ranks differ), that stride + # changes and every stored index must be adjusted. + _ATOM_DIM_REFERENCE_KEY = "atom_pad_mask" + has_atom_index_features = any(k in ATOM_INDEX_FEATURES for k in keys) + if has_atom_index_features and _ATOM_DIM_REFERENCE_KEY in keys: + ref_locals = [d[_ATOM_DIM_REFERENCE_KEY].to_local() for d in data] + old_atoms_per_sample = [v.shape[0] for v in ref_locals] + local_max_atoms = max(old_atoms_per_sample) + + global_max_atoms_t = torch.tensor([local_max_atoms], device=self._output_device_mesh.device_type) + group = self._output_device_mesh.get_group(0) + torch.distributed.all_reduce(global_max_atoms_t, op=torch.distributed.ReduceOp.MAX, group=group) + final_atoms_per_shard = int(global_max_atoms_t.item()) + else: + old_atoms_per_sample = None + final_atoms_per_shard = None + + # Collate the data + collated = {} + for key in keys: + # special handling for non-DTensor features + if key in NON_SHARDED_FEATURES_V2: + collated[key] = [d[key] for d in data] + continue + + # change batch dim to shard, special handling for coords since it has a leading singleton dim + values = [d[key] for d in data] + placements = values[0].placements + + values_local = [value.to_local() for value in values] + + # Remap atom-index features before padding so indices reflect the + # final (post-collation) padded atom layout. + if key in ATOM_INDEX_FEATURES and old_atoms_per_sample is not None: + for i in range(len(values_local)): + values_local[i] = remap_atom_indices_repad( + values_local[i], old_atoms_per_sample[i], final_atoms_per_shard + ) + + # local collate + values_local, _ = pad_to_max(values_local, 0) # internally will stack if shapes are the same + values_local = values_local.contiguous() # contiguous implies layout right + + # global collate + local_shape_max = torch.tensor( + values_local.shape, + device=self._output_device_mesh.device_type, + ) + group = self._output_device_mesh.get_group(0) + torch.distributed.all_reduce(local_shape_max, op=torch.distributed.ReduceOp.MAX, group=group) + + # Pad local tensor to match global shape if needed + if values_local.shape != tuple(local_shape_max.tolist()): + # Calculate padding needed for each dimension + current_shape = values_local.shape + target_shape = tuple(local_shape_max.tolist()) + num_dims = len(current_shape) + + # Create padding tuple (reverse order for F.pad) + padding = [] + for i in range(num_dims): + dim_idx = num_dims - 1 - i # Reverse order for F.pad + pad_needed = target_shape[dim_idx] - current_shape[dim_idx] + padding.extend([0, pad_needed]) # [left, right] for each dimension + + # Pad the tensor + values_local = torch.nn.functional.pad(values_local, tuple(padding), value=0) + + # expand to get global shape and strides + shape_scaling = torch.ones_like(local_shape_max) + shape_scaling[0] = self._output_device_mesh.shape[0] # dp dimension + + new_placements = [Shard(0)] + for placement in placements: + if isinstance(placement, Shard): + shape_scaling[placement.dim + 1] = self._output_device_mesh.shape[ + placement.dim + 1 + ] # +1 because input placement lacks dp dimension + new_placements.append(Shard(placement.dim + 1)) + elif isinstance(placement, Replicate): + new_placements.append(placement) + else: + raise ValueError(f"Unsupported placement: {placement}") + + global_shape = local_shape_max * shape_scaling + strides = LayoutRightMap(tuple(global_shape.tolist())).strides # coherent with contiguous layout + + collated[key] = DTensor.from_local( + values_local, + device_mesh=self._output_device_mesh, + placements=new_placements, + shape=torch.Size(global_shape.tolist()), + stride=strides, + ) + + return collated + + +def map_subgroup_mesh_to_cpu(dist_manager: "DistributedManager") -> DeviceMesh: + """Map the subgroup mesh to the CPU device mesh. + + Parameters + ---------- + manager : DistributedManager + The distributed manager. + """ + device_mesh = dist_manager.device_mesh_subgroups + if device_mesh.device_type == "cpu": + return DeviceMesh.from_group( + group=[ + dist_manager.group["dp"], + dist_manager.group["cp_axis_0"], + dist_manager.group["cp_axis_1"], + ], + device_type="cpu", + mesh=device_mesh.mesh.clone(), + mesh_dim_names=( + "dp_cpu", + "cp_axis_0_cpu", + "cp_axis_1_cpu", + ), + ) + elif device_mesh.device_type == "cuda": + if "dp_cpu" not in dist_manager.group_ranks: + dist_manager.create_group( + "dp_cpu", + dist_manager.group_ranks["dp"], + backend="gloo", + use_local_synchronization=True, + ) + dist_manager.create_group( + "cp_axis_0_cpu", + dist_manager.group_ranks["cp_axis_0"], + backend="gloo", + use_local_synchronization=True, + ) + dist_manager.create_group( + "cp_axis_1_cpu", + dist_manager.group_ranks["cp_axis_1"], + backend="gloo", + use_local_synchronization=True, + ) + + device_mesh_cpu = DeviceMesh.from_group( + group=[ + dist_manager.group["dp_cpu"], + dist_manager.group["cp_axis_0_cpu"], + dist_manager.group["cp_axis_1_cpu"], + ], + device_type="cpu", + mesh=device_mesh.mesh.clone(), + mesh_dim_names=( + "dp_cpu", + "cp_axis_0_cpu", + "cp_axis_1_cpu", + ), + ) + + # Check if cpu and cuda group ranks match + cuda_group_ranks = ( + dist_manager.group_ranks["dp"], + dist_manager.subgroups_ranks["cp"][0], + dist_manager.subgroups_ranks["cp"][1], + ) + cpu_group_ranks = tuple( + torch.distributed.get_process_group_ranks(group) for group in device_mesh_cpu.get_all_groups() + ) + for cuda_group_ranks_, cpu_group_ranks_ in zip(cuda_group_ranks, cpu_group_ranks): + if set(cuda_group_ranks_) != set(cpu_group_ranks_): + raise ValueError( + f"New CPU group ranks {cpu_group_ranks_} do not match with existing CUDA group ranks {cuda_group_ranks_}" + ) + + return device_mesh_cpu + else: + raise ValueError(f"Unknown device type {device_mesh.device_type}") + + +def get_flattened_group(device_mesh: DeviceMesh, backend: Optional[str] = None) -> ProcessGroup: + """Get the flattened process group from a device mesh. + + The original _flatten method creates a new group using default parameters, which can lead to + inconsistent backend with the original mesh. This function creates an additional group with a + pre-specified backend. + + Examples: + mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1], [2, 3]]) + group = get_flattened_group(mesh, backend="gloo") + print(group) + # ProcessGroup(type: GLOO, backend: gloo, devices: [0, 1, 2, 3]) + + Parameters + ---------- + device_mesh : DeviceMesh + The device mesh. + + Returns + ------- + group : ProcessGroup + The flattened group. + """ + new_group = device_mesh._flatten().get_group() + if backend is None: + return new_group + + return torch.distributed.new_group( + torch.distributed.get_process_group_ranks(new_group), + backend=backend, + use_local_synchronization=True, + ) diff --git a/src/boltz/distributed/lightning_strategy.py b/src/boltz/distributed/lightning_strategy.py new file mode 100644 index 000000000..d3ffa5e9b --- /dev/null +++ b/src/boltz/distributed/lightning_strategy.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Lightning strategy helpers for DTensor context-parallel training.""" + +import logging +from pathlib import Path +from typing import Any, Mapping, Optional + +import torch +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.strategies import SingleDeviceStrategy +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_optimizer_state_dict, + set_optimizer_state_dict, +) +from torch.distributed.tensor import DTensor, Replicate +from typing_extensions import override + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.utils import ( + convert_distributed_checkpoint_to_serial_state_dict, + convert_dtensors_to_tensors, + convert_serial_checkpoint_to_distributed_state_dict, +) + +logger = logging.getLogger(__name__) + + +def _redistribute_optimizer_state_to_params(optimizer: torch.optim.Optimizer) -> None: + """Redistribute plain-tensor optimizer state to match DTensor parameters. + + After loading a checkpoint, optimizer state buffers (``exp_avg``, + ``exp_avg_sq``, etc.) are plain tensors. If the corresponding parameter + is a DTensor, the state must be re-distributed to the same mesh and + placements to avoid mixed-type errors on ``optimizer.step()``. + """ + for group in optimizer.param_groups: + for param in group["params"]: + if not isinstance(param, DTensor): + continue + param_state = optimizer.state.get(param) + if param_state is None: + continue + if not all(isinstance(p, Replicate) for p in param.placements): + raise ValueError( + f"Only Replicate placements are supported for optimizer state redistribution, " + f"got {param.placements}" + ) + for state_key, state_val in param_state.items(): + if isinstance(state_val, torch.Tensor) and not isinstance(state_val, DTensor): + state_val = state_val.to(device=param.device_mesh.device_type) + # All ranks load the same checkpoint, so state_val is + # already identical across ranks. from_local avoids the + # redundant all-gather that distribute_tensor would do. + param_state[state_key] = DTensor.from_local( + state_val, + device_mesh=param.device_mesh, + placements=param.placements, + shape=state_val.shape, + stride=state_val.stride(), + ) + + +class BoltzContextParallelStrategy(SingleDeviceStrategy): + """DTensor-aware strategy for context-parallel checkpoint handling. + + This strategy intentionally stays close to single-device semantics while + customizing checkpoint save/load behavior: + + - save: convert model DTensors to regular tensors for portable checkpoints + - load: map serial checkpoint tensors into the currently-materialized model + state_dict template (which may include DTensor parameters/buffers) + """ + + strategy_name = "boltz_context_parallel" + + def __init__(self, dist_manager: DistributedManager, *args: Any, **kwargs: Any) -> None: + self.dist_manager = dist_manager + super().__init__(*args, device=self.dist_manager.device, **kwargs) + self.global_rank = self.dist_manager.rank + self.local_rank = self.dist_manager.local_rank + self.world_size = self.dist_manager.world_size + + @property + @override + def is_global_zero(self) -> bool: + return self.global_rank == 0 + + @override + def barrier(self, *args: Any, **kwargs: Any) -> None: + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + @override + def save_checkpoint( + self, checkpoint: dict[str, Any], filepath: str | Path, storage_options: Optional[Any] = None + ) -> None: + """Save checkpoint with all DTensors converted to plain tensors. + + This ensures *every* DTensor in the checkpoint (model state_dict, + optimizer states, EMA shadow weights, etc.) is stored in a portable + serial format, not just the ``state_dict`` key. + """ + checkpoint = convert_dtensors_to_tensors(checkpoint) + super().save_checkpoint(checkpoint, filepath, storage_options=storage_options) + + @override + def load_checkpoint(self, checkpoint_path: str | Path) -> dict[str, Any]: + # Route through Lightning's checkpoint_io so strategy-level checkpoint + # backends and remapping behavior are respected. + # Use map_location="cpu" (string, not callable) to avoid loading onto a + # serialized CUDA device that may not exist on the current node. + return self.checkpoint_io.load_checkpoint(checkpoint_path, map_location="cpu") + + @override + def model_to_device(self) -> None: + # Distributed models are expected to be materialized on the target device + # before strategy setup; calling .to(...) here can cause context issues. + return None + + @override + def lightning_module_state_dict(self) -> dict[str, Any]: + if self.model is None: + raise RuntimeError( + "BoltzContextParallelStrategy.model is not set. " + "Attach the model before exporting a checkpoint state_dict." + ) + checkpoint = {"state_dict": self.model.state_dict()} + return convert_distributed_checkpoint_to_serial_state_dict(checkpoint) + + @override + def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = False) -> None: + if self.lightning_module is None: + raise RuntimeError( + "BoltzContextParallelStrategy.lightning_module is not set. " + "Attach the LightningModule before loading model state." + ) + state_template = self.lightning_module.state_dict() + distributed_state_dict = convert_serial_checkpoint_to_distributed_state_dict( + checkpoint=checkpoint, + strict=strict, + state_dict_template=state_template, + ) + self.lightning_module.load_state_dict(distributed_state_dict, strict=strict) + + @override + def optimizer_state(self, optimizer: Any) -> dict[str, Any]: + """Return optimizer state dict with FQN (fully qualified name) keys. + + Uses ``torch.distributed.checkpoint.state_dict.get_optimizer_state_dict`` + to produce parameter-name-keyed state (e.g. ``"distogram_module.weight"`` + instead of ``0``), making checkpoints portable across model topologies + regardless of parameter registration order. + """ + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + if self.lightning_module is None: + raise RuntimeError( + "BoltzContextParallelStrategy.lightning_module is not set. " + "Attach the LightningModule before saving optimizer state." + ) + return get_optimizer_state_dict(self.lightning_module, optimizer) + + @override + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + """Load optimizer state, handling both FQN-keyed and legacy int-keyed formats. + + Checkpoints saved by this strategy use FQN (fully qualified name) string + keys produced by ``get_optimizer_state_dict``. Legacy checkpoints (or + those saved by the default Lightning strategy) use integer keys. + + This method auto-detects the format by inspecting the first key in + ``state_dict["state"]``: + + - **FQN keys (str):** loaded via ``set_optimizer_state_dict`` which maps + parameter names back to the live optimizer's internal indices and + redistributes plain tensors to match DTensor parameter placements. + - **Integer keys (int):** loaded via ``optimizer.load_state_dict()`` + followed by manual DTensor redistribution (legacy path). + """ + if not self.optimizers: + return + + if "optimizer_states" not in checkpoint: + raise ValueError("Checkpoint is passed into load_optimizer_state_dict but no optimizer_states found") + + optimizer_states = checkpoint["optimizer_states"] + if not isinstance(optimizer_states, (list, tuple)): + raise TypeError(f"Checkpoint field 'optimizer_states' must be a list/tuple, got {type(optimizer_states)}") + if len(optimizer_states) != len(self.optimizers): + raise ValueError( + "Optimizer-state length mismatch: " + f"checkpoint has {len(optimizer_states)} state entries but strategy has {len(self.optimizers)} optimizers" + ) + + for index, optimizer in enumerate(self.optimizers): + optimizer_state = optimizer_states[index] + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + + state_keys = list(optimizer_state.get("state", {}).keys()) + uses_fqn_keys = state_keys and isinstance(state_keys[0], str) + + if uses_fqn_keys: + if self.lightning_module is None: + raise RuntimeError( + "BoltzContextParallelStrategy.lightning_module is not set. " + "Attach the LightningModule before loading optimizer state." + ) + set_optimizer_state_dict( + self.lightning_module, + optimizer, + optim_state_dict=optimizer_state, + options=StateDictOptions(full_state_dict=True), + ) + else: + logger.warning( + "Loading optimizer state with legacy integer keys. " + "This is expected when resuming from a checkpoint saved by " + "the default Lightning strategy or an older version of the " + "distributed trainer. Future checkpoints will use FQN keys." + ) + optimizer.load_state_dict(optimizer_state) + + # Ensure optimizer state buffers match DTensor parameter placements. + # For the FQN path, set_optimizer_state_dict may already redistribute; + # this is a safe no-op when state is already distributed correctly. + _redistribute_optimizer_state_to_params(optimizer) diff --git a/src/boltz/distributed/main.py b/src/boltz/distributed/main.py new file mode 100644 index 000000000..378b562ac --- /dev/null +++ b/src/boltz/distributed/main.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""CLI entrypoint for distributed Boltz-2 inference with DTensor context parallelism. + +Thin click wrapper that resolves checkpoint / CCD-molecule paths and forwards +every option to :func:`boltz.distributed.predict.run_predict`. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional + +import click + +from boltz.data import const +from boltz.distributed.data.types import PairMaskMode +from boltz.distributed.model.modules.utils import Precision, SDPAWithBiasBackend, TriAttnBackend +from boltz.main import download_boltz2, get_cache_path + + +@click.group() +def cli() -> None: + """Boltz-2 distributed inference.""" + return + + +@cli.command() +@click.argument("data", type=click.Path(exists=True)) +@click.option( + "--out_dir", + type=click.Path(exists=False), + help="The path where to save the predictions.", + default="./", +) +@click.option( + "--cache", + type=click.Path(exists=False), + help=("The directory where to download the data and model. " "Default is ~/.boltz, or $BOLTZ_CACHE if set."), + default=get_cache_path, +) +@click.option( + "--checkpoint", + type=click.Path(exists=True), + help="Path to a Boltz-2 model checkpoint. Downloaded automatically if not provided.", + default=None, +) +@click.option( + "--mol_dir", + type=click.Path(exists=True), + help="Directory containing per-residue CCD molecule pickle files. " "Resolved from --cache if not provided.", + default=None, +) +@click.option( + "--size_dp", + type=int, + help="Number of data-parallel ranks. Default is 1.", + default=1, +) +@click.option( + "--size_cp", + type=int, + help="Total number of context-parallel ranks (must be a perfect square). Default is 1.", + default=1, +) +@click.option( + "--accelerator", + type=click.Choice(["gpu", "cpu"]), + help="The accelerator to use for prediction. Default is gpu.", + default="gpu", +) +@click.option( + "--recycling_steps", + type=int, + help="The number of recycling steps. Default is 3.", + default=3, +) +@click.option( + "--sampling_steps", + type=int, + help="The number of diffusion sampling steps. Default is 200.", + default=200, +) +@click.option( + "--diffusion_samples", + type=int, + help="The number of independent diffusion samples per input. Default is 1.", + default=1, +) +@click.option( + "--max_parallel_samples", + type=int, + help="Max diffusion samples to run in parallel (None = all at once). Default is None.", + default=None, +) +@click.option( + "--step_scale", + type=float, + help=( + "Step scale for the diffusion schedule. Lower values increase diversity " + "among samples (recommended between 1 and 2). Default is 1.5." + ), + default=1.5, +) +@click.option( + "--output_format", + type=click.Choice(["pdb", "mmcif"]), + help="The output structure format. Default is mmcif.", + default="mmcif", +) +@click.option( + "--seed", + type=int, + help="Random seed for reproducibility. Default is None.", + default=None, +) +@click.option( + "--max_msa_seqs", + type=int, + help=f"Maximum number of MSA sequences. Default is {const.max_msa_seqs}.", + default=const.max_msa_seqs, +) +@click.option( + "--msa_pad_to_max_seqs", + is_flag=True, + help="Whether to pad MSA to max_msa_seqs. Default is False.", +) +@click.option( + "--input_format", + type=click.Choice(["preprocessed", "config_files"], case_sensitive=False), + help="Data format for input. If 'preprocessed', expects a folder with " + "manifest.json, msa/ folder and structures/ folder with preprocessed data. " + "If 'config_files', expects yaml/fasta files. Default is preprocessed.", + default="preprocessed", +) +@click.option( + "--timeout_nccl", + type=float, + help="NCCL timeout in minutes. Default is 30.", + default=30, +) +@click.option( + "--timeout_gloo", + type=float, + help="Gloo timeout in minutes. Default is 30.", + default=30, +) +@click.option( + "--precision", + type=click.Choice([Precision.BF16.value, Precision.BF16_MIXED.value, Precision.TF32.value, Precision.FP32.value]), + help="Model precision mode. Default is BF16_MIXED.", + default=Precision.BF16_MIXED.value, +) +@click.option( + "--atoms_per_window_queries_keys", + nargs=2, + type=int, + help="(queries, keys) window sizes for atom attention batching. Default is 32 128.", + default=(32, 128), +) +@click.option( + "--pair_mask_mode", + type=click.Choice( + [PairMaskMode.NONE.value, PairMaskMode.GLOBAL_ATOM_ATTENTION.value, PairMaskMode.SEQUENCE_LOCAL_ATTENTION.value] + ), + help="Pair mask mode. Default is None (window batching).", + default=PairMaskMode.NONE.value, +) +@click.option( + "--local_batch_size", + type=int, + help="Per-rank batch size. Default is 1.", + default=1, +) +@click.option( + "--num_ensembles", + type=int, + help="Number of ensemble members. Default is 1.", + default=1, +) +@click.option( + "--write_full_pae", + is_flag=True, + help="Whether to write full PAE matrices. Default is False.", +) +@click.option( + "--use_templates", + type=bool, + help="Whether to use template features. Default is True.", + default=True, +) +@click.option( + "--triattn_backend", + type=click.Choice([TriAttnBackend.CUEQ.value, TriAttnBackend.TRIFAST.value, TriAttnBackend.REFERENCE.value]), + help="Triangle attention backend to use. Default is cueq.", + default=TriAttnBackend.CUEQ.value, +) +@click.option( + "--sdpa_with_bias_backend", + type=click.Choice([SDPAWithBiasBackend.REFERENCE.value, SDPAWithBiasBackend.TORCH_FLEX_ATTN.value]), + help="SDPA backend for ring-attention AttentionPairBias layers. Default is torch_flex_attn.", + default=SDPAWithBiasBackend.TORCH_FLEX_ATTN.value, +) +@click.option( + "--sdpa_with_bias_shardwise_backend", + type=click.Choice( + [ + SDPAWithBiasBackend.REFERENCE.value, + SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION.value, + SDPAWithBiasBackend.TORCH_FLEX_ATTN.value, + ] + ), + help="SDPA backend for window-batched AttentionPairBiasShardwise layers. Default is torch_flex_attn.", + default=SDPAWithBiasBackend.TORCH_FLEX_ATTN.value, +) +@click.option( + "--auto_pad_tokens_for_sm100f/--no_auto_pad_tokens_for_sm100f", + default=True, + help="Pad token counts so each CP shard is a multiple of 8 for SM100f cuEq TriAttn. Default is True.", +) +@click.option( + "--cuda_memory_profile", + is_flag=True, + default=False, + help="Profile CUDA memory usage and dump a snapshot pickle per rank.", +) +@click.option( + "--override", + is_flag=True, + default=False, + help="Override existing predictions even if output already exists. Default is False.", +) +def predict( + data: str, + out_dir: str, + cache: str, + checkpoint: Optional[str], + mol_dir: Optional[str], + size_dp: int, + size_cp: int, + accelerator: str, + recycling_steps: int, + sampling_steps: int, + diffusion_samples: int, + max_parallel_samples: Optional[int], + step_scale: float, + output_format: str, + seed: Optional[int], + max_msa_seqs: int, + msa_pad_to_max_seqs: bool, + input_format: str, + timeout_nccl: float, + timeout_gloo: float, + precision: str, + atoms_per_window_queries_keys: tuple[int, int], + pair_mask_mode: str, + local_batch_size: int, + num_ensembles: int, + write_full_pae: bool, + use_templates: bool, + triattn_backend: str, + sdpa_with_bias_backend: str, + sdpa_with_bias_shardwise_backend: str, + auto_pad_tokens_for_sm100f: bool, + cuda_memory_profile: bool, + override: bool, +) -> None: + """Run distributed Boltz-2 structure prediction. + + DATA is the path to the input data directory. + """ + cache_path = Path(cache).expanduser() + cache_path.mkdir(parents=True, exist_ok=True) + + # Resolve checkpoint: download if not provided + if checkpoint is None: + download_boltz2(cache_path) + checkpoint = str(cache_path / "boltz2_conf.ckpt") + + # Resolve mol_dir: use cache default if not provided + if mol_dir is None: + download_boltz2(cache_path) + mol_dir = str(cache_path / "mols") + + from boltz.distributed.predict import run_predict + + run_predict( + data=data, + out_dir=out_dir, + mol_dir=mol_dir, + checkpoint=checkpoint, + size_dp=size_dp, + size_cp=size_cp, + accelerator=accelerator, + recycling_steps=recycling_steps, + sampling_steps=sampling_steps, + diffusion_samples=diffusion_samples, + max_parallel_samples=max_parallel_samples, + step_scale=step_scale, + output_format=output_format, + seed=seed, + max_msa_seqs=max_msa_seqs, + msa_pad_to_max_seqs=msa_pad_to_max_seqs, + input_format=input_format, + timeout_nccl=timeout_nccl, + timeout_gloo=timeout_gloo, + precision=Precision(precision), + atoms_per_window_queries_keys=atoms_per_window_queries_keys, + pair_mask_mode=PairMaskMode(pair_mask_mode), + local_batch_size=local_batch_size, + num_ensembles=num_ensembles, + write_full_pae=write_full_pae, + use_templates=use_templates, + triattn_backend=TriAttnBackend(triattn_backend), + sdpa_with_bias_backend=SDPAWithBiasBackend(sdpa_with_bias_backend), + sdpa_with_bias_shardwise_backend=SDPAWithBiasBackend(sdpa_with_bias_shardwise_backend), + auto_pad_tokens_for_sm100f=auto_pad_tokens_for_sm100f, + cuda_memory_profile=cuda_memory_profile, + override=override, + ) + + +if __name__ == "__main__": + cli() diff --git a/src/boltz/distributed/manager.py b/src/boltz/distributed/manager.py new file mode 100644 index 000000000..6433cbda2 --- /dev/null +++ b/src/boltz/distributed/manager.py @@ -0,0 +1,900 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import os +from copy import deepcopy +from math import prod +from typing import Any, Dict, Optional, OrderedDict, Union +from warnings import warn + +import torch + +from boltz.distributed.utils import LayoutMap, LayoutRightMap + +# grid_group_sizes objects must have +# (1) .values() attribute, (2) .items() attribute +_GridGroupSizesType = OrderedDict[str, Union[int, tuple[int, ...]]] + + +class DistributedManager: + """ + Borg-style singleton class for managing distributed state. + + Parameters + ---------- + None + + Attributes + ---------- + rank : int + Rank of the process. + world_size : int + Total number of processes. + local_rank : int + Local rank of the process. + device : torch.device + Device used by the process. + backend : str + Backend used for distributed communication. + device_mesh : torch.distributed.device_mesh + Device mesh used for distributed communication. + group : dict + Dictionary of groups used for distributed communication. + group_rank : dict + Dictionary of group ranks used for distributed communication. + group_ranks : dict + Dictionary of group ranks used for distributed communication. + method_init : str + Method used to initialize the distributed manager. + + Examples + -------- + >>> DistributedManager.initialize() + >>> manager = DistributedManager() + >>> manager.rank + 0 + >>> manager.world_size + 1 + """ + + _state = {} + + def __new__(cls): + """ + Creates a new instance of the DistributedManager class. + + Parameters + ---------- + None + Returns + ------- + instance : DistributedManager + New instance of the DistributedManager class. + """ + instance = super().__new__(cls) + instance.__dict__ = cls._state + # default the properties so that the default initialize() could work + if not hasattr(instance, "_initialized"): + instance._initialized = False + if not hasattr(instance, "_has_dist"): + instance._has_dist = False + if not hasattr(instance, "_rank"): + instance._rank = 0 + if not hasattr(instance, "_world_size"): + instance._world_size = 1 + if not hasattr(instance, "_local_rank"): + instance._local_rank = 0 + if not hasattr(instance, "_device"): + instance._device = torch.device("cpu") + if not hasattr(instance, "_backend"): + instance._backend = None + if not hasattr(instance, "_device_mesh"): + instance._device_mesh = None + if not hasattr(instance, "_layout_device_mesh"): + instance._layout_device_mesh = None + if not hasattr(instance, "_has_subgroups"): + instance._has_subgroups = False + if not hasattr(instance, "_device_mesh_subgroups"): + instance._device_mesh_subgroups = None + if not hasattr(instance, "_layout_device_mesh_subgroups"): + instance._layout_device_mesh_subgroups = None + if not hasattr(instance, "_group"): + instance._group = {} + if not hasattr(instance, "_group_rank"): + instance._group_rank = {} + if not hasattr(instance, "_group_ranks"): + instance._group_ranks = {} + if not hasattr(instance, "_subgroups"): + instance._subgroups = {} + if not hasattr(instance, "_subgroups_rank"): + instance._subgroups_rank = {} + if not hasattr(instance, "_subgroups_ranks"): + instance._subgroups_ranks = {} + if not hasattr(instance, "_layout_subgroups"): + instance._layout_subgroups = {} + if not hasattr(instance, "_method_init"): + instance._method_init = None + return instance + + @classmethod + def methods_init_available(cls) -> set[str]: + """ + Returns the set of available initialization methods for the DistributedManager. + + Returns + ------- + methods : set[str] + Set of available initialization methods. + """ + return {"ENV", "SLURM"} + + @classmethod + def backend_for_device(cls) -> Dict[str, Optional[str]]: + """ + Returns the mapping of device types to their default backend. + + Returns + ------- + backend_for_device : dict[str, str or None] + Mapping of device types to their default backend. + """ + backend_for_device = { + "cuda": "nccl" if torch.distributed.is_nccl_available() else None, + "cpu": "gloo" if torch.distributed.is_gloo_available() else None, + } + return backend_for_device + + @classmethod + def is_initialized(cls) -> bool: + """ + Checks if the DistributedManager singleton has been initialized. + + Parameters + ---------- + None + + Returns + ------- + initialized : bool + True if the DistributedManager singleton has been initialized, False otherwise. + """ + return cls._state.get("_initialized", False) + + def __init__(self): + """ + Initializes the DistributedManager instance. + + Parameters + ---------- + None + + Raises + ------ + RuntimeError + If the DistributedManager instance is being instantiated before the singleton class is initialized. + """ + if not self._initialized: + raise RuntimeError( + "A DistributedManager instance is being instantiated before " + "the singleton class is initialized, which can lead to communication " + "failure among processes. Please call DistributedManager.initialize() " + "before instantiating any `DistributedManager` instance. " + ) + super().__init__() + + def __getattr__(self, name: str) -> Any: + """ + Gets the attribute of the DistributedManager instance. + + Parameters + ---------- + name : str + Name of the attribute to get. + + Returns + ------- + attribute : Any + Attribute of the DistributedManager instance. + + Raises + ------ + AttributeError + If the attribute does not exist. + """ + # to enable read-only access to the shared _state data + key_state = f"_{name}" + has_key_shared_state = key_state in self.__dict__ + has_key = name in self.__dict__ + if has_key_shared_state: + return self.__dict__[key_state] + elif has_key: + return self.__dict__[name] + else: + raise AttributeError(f'Attribute "{name}" or "_{name}" not found.') + + def __str__(self): + """ + Returns a string representation of the DistributedManager instance. + + Parameters + ---------- + None + + Returns + ------- + str : str + String representation of the DistributedManager instance. + """ + output = ( + f"Initialized process {self.rank} of {self.world_size} using " + f"method '{self.method_init}'. Device set to {str(self.device)}. Backend is {self.backend}" + ) + return output + + @staticmethod + def _setup( + grid_group_sizes: Optional[_GridGroupSizesType] = None, + device_type: str = "cuda", + backend: Optional[str] = None, + rank: int = -1, + node_rank: int = -1, + world_size: int = -1, + local_rank: Optional[int] = None, + addr: str = "localhost", + port: str = "29500", + method_init: str = "ENV", + **kwargs_init_pg, + ): + """ + Sets up the DistributedManager instance. + + Parameters + ---------- + grid_group_sizes : OrderedDict, optional + OrderedDict of group sizes used for distributed communication. + See create_grid_group() for details and examples. + device_type : str, optional + Type of device used for distributed communication. + backend : str, optional + Backend used for distributed communication. + rank : int, optional + Rank of the process. + node_rank : int, optional + Node rank of the process. + world_size : int, optional + Total number of processes. + local_rank : int, optional + Local rank of the process. + addr : str, optional + Address used for distributed communication. + port : str, optional + Port used for distributed communication. + method_init : str, optional + Method used to initialize the distributed manager. + kwargs_init_pg: + kwargs to forward to torch.distributed.init_process_group call + + Returns + ------- + None + """ + # TODO: could relax this to allow, e.g., "cuda" for "gloo" + if device_type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError(f"Input device type {device_type} but torch.cuda is not available") + + if world_size != -1 and grid_group_sizes is not None: + total_size = 1 + assert hasattr(grid_group_sizes, "values") + for value in grid_group_sizes.values(): + if isinstance(value, tuple) and all(isinstance(v, int) for v in value): + total_size *= prod(value) + elif isinstance(value, int): + total_size *= value + else: + raise RuntimeError( + f"Values in grid_group_sizes must be either int or tuple[int, ...], got {type(value)}" + ) + + if world_size != total_size: + raise RuntimeError( + f"Non-default world_size {world_size} != product of grid_group_sizes values ({total_size})" + ) + + backend_for_device = DistributedManager.backend_for_device() + + if backend_for_device["cpu"] is None and backend_for_device["cuda"] is None: + raise RuntimeError(f"No backend available for the supported device types: {backend_for_device.keys()}") + + if device_type not in backend_for_device: + raise RuntimeError(f"Invalid input device type {device_type}: only supports {backend_for_device.keys()}") + + if backend is None: + backend = backend_for_device[device_type] + elif backend != backend_for_device[device_type]: + raise RuntimeError(f"Invalid input backend {backend} for input device type {device_type}") + + # set these in order to call torch.distributed.init_process_group + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = str(port) + + # instantiate the singleton + DistributedManager._state["_initialized"] = True + manager = DistributedManager() + + manager._has_dist = torch.distributed.is_available() + + manager._rank = rank + manager._world_size = world_size + manager._node_rank = node_rank + if device_type == "cuda": + if manager.world_size > torch.cuda.device_count() and manager.world_size % torch.cuda.device_count(): + warn( + "world_size is not a multiple of torch.cuda.device_count() so cuda devices could be shared by multiple ranks" + ) + # will try to guess a local_rank from GPU counts + if local_rank is None: + manager._local_rank = manager.rank % torch.cuda.device_count() + else: + manager._local_rank = local_rank + manager._device = torch.device(f"cuda:{manager.local_rank}") + else: + if local_rank is not None: + manager._local_rank = local_rank + manager._device = torch.device("cpu") + + if not manager.has_dist: + warn("DistributedManager initialized without torch.distributed package") + # TODO: triage the importance of having a default device according to the + # input backend + return + + if manager.device.type == "cuda": + # set device before init_process_group to avoid unintended + # cuda context and to avoid potential NCCL issues + torch.cuda.set_device(manager.device) + torch.cuda.device(manager.device) + torch.cuda.empty_cache() + + manager._backend = backend + + # initialize torch.distributed + if manager.device.type == "cuda" and backend == "nccl": + try: + # to prevent nccl hang and other potential issues: + # see e.g., https://github.com/pytorch/pytorch/issues/142356 + torch.distributed.init_process_group( + manager.backend, + rank=manager.rank, + world_size=manager.world_size, + device_id=manager.device, + **kwargs_init_pg, + ) + except TypeError: + torch.distributed.init_process_group( + manager.backend, rank=manager.rank, world_size=manager.world_size, **kwargs_init_pg + ) + else: + torch.distributed.init_process_group( + manager.backend, rank=manager.rank, world_size=manager.world_size, **kwargs_init_pg + ) + + manager._group["world"] = torch.distributed.group.WORLD + manager._group_rank["world"] = manager.rank + manager._group_ranks["world"] = torch.distributed.get_process_group_ranks(manager.group["world"]) + + manager._method_init = method_init + + if grid_group_sizes is not None: + DistributedManager.create_grid_group(grid_group_sizes) + + @staticmethod + def _create_device_mesh_and_groups(name: list[str], shape: list[int], suffix_mesh: Optional[str] = None) -> None: + """ + Creates a device mesh and associated process groups for distributed communication. + + Parameters + ---------- + name : list[str] + Names of the dimensions in the device mesh. + shape : list[int] + Shape of the device mesh, representing the sizes of each dimension. + suffix_mesh : str, optional + Suffix to append to the device mesh name for identification, defaults to None. + + Returns + ------- + None + + Raises + ------ + RuntimeError + If DistributedManager is not initialized. + RuntimeError + If torch.distributed package is not available. + RuntimeError + If method_init is invalid or not available. + RuntimeError + If backend is invalid or not available. + RuntimeError + If device type is invalid or not available. + RuntimeError + If world_size does not match the expected world size computed from shape. + """ + if not DistributedManager.is_initialized(): + raise RuntimeError("DistributedManager is not initialized upon calling _create_device_mesh_and_groups") + if not DistributedManager._state["_has_dist"] or not torch.distributed.is_available(): + raise RuntimeError( + "_create_device_mesh_and_groups requires torch.distributed package, which is not available" + ) + if ( + DistributedManager._state["_method_init"] is None + or DistributedManager._state["_method_init"] not in DistributedManager.methods_init_available() + ): + raise RuntimeError( + f"Invalid DistributedManager method_init {DistributedManager._state['_method_init']} " + "(most likely because it was default initialized)" + ) + if ( + DistributedManager._state["_backend"] is None + or DistributedManager._state["_backend"] not in DistributedManager.backend_for_device().values() + ): + raise RuntimeError( + f"Invalid DistributedManager backend {DistributedManager._state['_backend']} " + "(most likely because it was default initialized)" + ) + if ( + DistributedManager._state["_device"] is None + or DistributedManager._state["_device"].type not in DistributedManager.backend_for_device().keys() + ): + raise RuntimeError( + f"Invalid DistributedManager device type {DistributedManager._state['_device'].type} " + "(most likely because it was default initialized)" + ) + + world_size_expected = prod(shape) + + if world_size_expected != DistributedManager._state["_world_size"]: + raise RuntimeError( + f"world_size {DistributedManager._state['_world_size']} does not match the expected world size " + f"{world_size_expected} computed from the input shape {shape}" + ) + + device_type = DistributedManager._state["_device"].type + name_mesh = f"_device_mesh_{suffix_mesh}" if suffix_mesh is not None else "_device_mesh" + + # TODO: support arbitrary user-input layout + layout = LayoutRightMap(tuple(shape)) + DistributedManager._state[f"_layout{name_mesh}"] = layout + + grid2rank = torch.as_strided(torch.arange(world_size_expected), size=layout.shape, stride=layout.strides) + DistributedManager._state[name_mesh] = torch.distributed.device_mesh.DeviceMesh( + device_type, grid2rank, mesh_dim_names=tuple(name) + ) + + for i_group in range(len(name)): + name_group = name[i_group] + + if name_group in DistributedManager._state["_group"]: + # skip those already created, e.g., from another call of this function + continue + + DistributedManager._state["_group"][name_group] = DistributedManager._state[name_mesh].get_group(name_group) + DistributedManager._state["_group_rank"][name_group] = torch.distributed.get_group_rank( + DistributedManager._state["_group"][name_group], DistributedManager._state["_rank"] + ) + DistributedManager._state["_group_ranks"][name_group] = torch.distributed.get_process_group_ranks( + DistributedManager._state["_group"][name_group] + ) + + @staticmethod + def create_grid_group(grid_group_sizes: _GridGroupSizesType) -> None: + """ + Creates a grid group for distributed communication. + + Parameters + ---------- + grid_group_sizes : OrderedDict[str, int | tuple[int, ...]] + Dictionary of group sizes used for distributed communication. The keys of the OrderedDict + are the group names and the values are the group sizes. The group sizes can be an integer + or a tuple of integers. If it is a tuple of integers, they partition the ranks of that group + into a subgrid of the corresponding shape. The layout of the ranks in the groups and subgroups + follows the LayoutRightMap convention, where the last group's (or its last subgroup) ranks are + contiguous global rank on the device grid. + + Returns + ------- + None + + Notes + ----- + This method updates the following dictionaries in the DistributedManager._state: + - _group: Maps group names to torch.distributed.ProcessGroup objects + - _group_rank: Maps group names to the rank of the current process in that group + - _group_ranks: Maps group names to lists of all ranks in that group + - _subgroups: Maps parent group names to lists of their subgroups' ProcessGroup objects + - _subgroups_rank: Maps parent group names to lists of ranks within each subgroup + - _subgroups_ranks: Maps parent group names to lists of all ranks for each subgroup + - _layout_subgroups: Maps parent group names to LayoutMap objects for subgroup coordinate mapping + - _device_mesh: The main PyTorch DeviceMesh object created from shape_groups + - _device_mesh_subgroups: The DeviceMesh object for subgroups created from shape_subgroups + + This method creates PyTorch DeviceMesh objects to facilitate distributed tensor operations + and communication. The DeviceMesh provides a logical view of the physical device layout, + enabling efficient collective operations across process groups and subgroups. + + Examples + -------- + >>> # Create a grid with data parallel (dp) dimension of 1 and + >>> # a 2x2 communication parallel (cp) grid + >>> from collections import OrderedDict + >>> grid_group_sizes = OrderedDict([("dp", 1), ("cp", (2, 2))]) + >>> DistributedManager.initialize(grid_group_sizes, device_type="cuda") + >>> manager = DistributedManager() + >>> # After initialization, the following dictionaries are populated: + >>> # manager.group contains ProcessGroup objects for 'world', 'dp', 'cp' + >>> # and also 'cp_axis_0' and 'cp_axis_1' for the cp subgroups + >>> # manager.group_rank contains the current process's rank in each group + >>> # manager.group_ranks contains all ranks for each group + >>> + >>> # Device mesh objects are created and accessible + >>> device_mesh = manager._device_mesh # DeviceMesh for parent groups ['dp', 'cp'] + >>> device_mesh_subgroups = manager._device_mesh_subgroups # DeviceMesh for ['dp', 'cp_axis_0', 'cp_axis_1'] + >>> print(f"Parent device mesh shape: {device_mesh.shape()}") # Expected: (1, 4) + >>> print(f"Subgroups device mesh shape: {device_mesh_subgroups.shape()}") # Expected: (1, 2, 2) + >>> + >>> # Verify that the subgroups are accessible both directly and via the parent group + >>> assert "cp_axis_0" in manager.group + >>> assert "cp_axis_1" in manager.group + >>> assert manager.subgroups["cp"][0] is manager.group["cp_axis_0"] + >>> assert manager.subgroups["cp"][1] is manager.group["cp_axis_1"] + >>> + >>> # The subgroups represent slices along specific axes of the mesh + >>> # For a rank with coordinates (i, j) in a 2x2 grid: + >>> # "cp_axis_0" contains all ranks (0:2, j) - varying along the first axis, fixed j + >>> # "cp_axis_1" contains all ranks (i, 0:2) - fixed i, varying along the second axis + >>> # This is consistent with torch.distributed.DeviceMesh's group definition + >>> # For rank 1 (coords: (0, 1)): + >>> # - It belongs to cp_axis_0 group containing ranks [1, 3] (all ranks with j=1) + >>> # - It belongs to cp_axis_1 group containing ranks [0, 1] (all ranks with i=0) + >>> + >>> # Examine the layout of layout_subgroups['cp'] + >>> layout = manager.layout_subgroups['cp'] + >>> print(f"Shape: {layout.shape}") # Expected: (2, 2) + >>> print(f"Strides: {layout.strides}") # Expected: (2, 1) for LayoutRightMap + >>> + >>> # Demonstration of how the layout maps coordinates to ranks + >>> for i in range(layout.shape[0]): + ... for j in range(layout.shape[1]): + ... rank = layout.ravel((i, j)) + ... coords = layout.unravel(rank) + ... print(f"Coordinates ({i}, {j}) map to rank {rank}, which maps back to {coords}") + ... + Coordinates (0, 0) map to rank 0, which maps back to (0, 0) + Coordinates (0, 1) map to rank 1, which maps back to (0, 1) + Coordinates (1, 0) map to rank 2, which maps back to (1, 0) + Coordinates (1, 1) map to rank 3, which maps back to (1, 1) + + Raises + ------ + RuntimeError + If DistributedManager is not initialized. + RuntimeError + If torch.distributed package is not available. + RuntimeError + If method_init is invalid or not available. + RuntimeError + If backend is invalid or not available. + RuntimeError + If device type is invalid or not available. + RuntimeError + If values in grid_group_sizes are not int or tuple[int, ...]. + """ + shape_groups = [] + name_groups = [] + shape_subgroups = [] + name_subgroups = [] + group2subgroup = {} + group2subgroup_axes = {} + assert hasattr(grid_group_sizes, "items") + for k, v in grid_group_sizes.items(): + if isinstance(v, tuple) and all(isinstance(v_i, int) for v_i in v): + # Create a new dimension of the DeviceMesh for each group + shape_groups.append(prod(v)) + name_groups.append(k) + # Create a new dimension of the DeviceMesh for each subgroup + # to allow torch DTensor placement on the subgroups' DeviceMesh, + # where each subgroup axis is treated as a separate dimension in the mesh + shape_subgroups.extend(v) + names_this_subgroup = [f"{k}_axis_{i}" for i in range(len(v))] + name_subgroups.extend(names_this_subgroup) + # map each group to its subgroups along each axis + group2subgroup[k] = names_this_subgroup + group2subgroup_axes[k] = list(range(len(name_subgroups) - len(v), len(name_subgroups))) + elif isinstance(v, int): + shape_groups.append(v) + name_groups.append(k) + shape_subgroups.append(v) + name_subgroups.append(k) + else: + raise RuntimeError(f"Values in grid_group_sizes must be either int or tuple[int, ...], got {type(v)}") + + # TODO: might not always need the device_mesh for parent groups + # but one could just create them via create_group() + DistributedManager._create_device_mesh_and_groups(name_groups, shape_groups) + if (name_groups == name_subgroups) != (shape_groups == shape_subgroups): + raise RuntimeError( + f"Inconsistent group ({name_groups}, {shape_groups}) and " + f"subgroup ({name_subgroups}, {shape_subgroups}) settings" + ) + + DistributedManager._state["_has_subgroups"] = name_groups != name_subgroups + if DistributedManager._state["_has_subgroups"]: + if len(group2subgroup) == 0: + raise RuntimeError("group2subgroup is empty while _has_subgroups is True") + DistributedManager._create_device_mesh_and_groups(name_subgroups, shape_subgroups, suffix_mesh="subgroups") + layout = DistributedManager._state["_layout_device_mesh_subgroups"] + coords = DistributedManager._state["_device_mesh_subgroups"].get_coordinate() + for name_group, name_subgroups in group2subgroup.items(): + # map the parent process group name to the subgroups + DistributedManager._state["_subgroups"][name_group] = [ + DistributedManager._state["_group"][name_subgroup] for name_subgroup in name_subgroups + ] + DistributedManager._state["_subgroups_ranks"][name_group] = [ + DistributedManager._state["_group_ranks"][name_subgroup] for name_subgroup in name_subgroups + ] + DistributedManager._state["_subgroups_rank"][name_group] = [ + DistributedManager._state["_group_rank"][name_subgroup] for name_subgroup in name_subgroups + ] + # create the subgroup layout for each parent group + # TODO: support LayoutMap.reshape to simplify this + axes_subgroup = group2subgroup_axes[name_group] + slices = deepcopy(coords) + for axis in axes_subgroup: + slices[axis] = slice(None) + layout_subgroup = layout[*slices] + # create a LayoutMap for the subgroups with offset 0 to be used for a bijective mapping + # between rank within the subgroups and the subgrid + DistributedManager._state["_layout_subgroups"][name_group] = LayoutMap( + layout_subgroup.strides, layout_subgroup.shape, offset=0 + ) + + @staticmethod + def create_group(name: str, ranks: list[int], **kwargs_dist_ng) -> None: + """ + Creates a new process group for distributed communication. + + Parameters + ---------- + name : str + Name of the group. + ranks : list[int] + Ranks of the processes in the group. + **kwargs_dist_ng + Keyword arguments to pass to torch.distributed.new_group. + + Returns + ------- + None + + Notes + ----- + This method creates a new process group with the given name and ranks, + and stores the group, ranks, and group rank in the DistributedManager state. + """ + DistributedManager._state["_group"][name] = torch.distributed.new_group(ranks=ranks, **kwargs_dist_ng) + DistributedManager._state["_group_ranks"][name] = ranks + DistributedManager._state["_group_rank"][name] = torch.distributed.get_group_rank( + DistributedManager._state["_group"][name], DistributedManager._state["_rank"] + ) + + @staticmethod + def _initialize_env(*args, **kwargs): + """ + Initializes the DistributedManager instance using environment variables. + + Parameters + ---------- + *args : list + Variable length argument list. + **kwargs : dict + Arbitrary keyword arguments. + + Returns + ------- + None + """ + if not ("RANK" in os.environ and "WORLD_SIZE" in os.environ): + raise RuntimeError( + "environment variable RANK and WORLD_SIZE must be set to initialize " + "torch.distributed using the env:// method" + ) + rank = os.environ.get("RANK") + world_size = os.environ.get("WORLD_SIZE") + local_rank = os.environ.get("LOCAL_RANK") + # From LightningEnvironment.node_rank() + group_rank = os.environ.get("GROUP_RANK", 0) + node_rank = int(os.environ.get("NODE_RANK", group_rank)) + try: + rank = int(rank) + world_size = int(world_size) + if local_rank is not None: + local_rank = int(local_rank) + except TypeError: + raise RuntimeError( + "environment variables RANK, LOCAL_RANK and WORLD_SIZE must be specified as integer " + f"but got rank={rank}, local_rank={local_rank}, world_size={world_size}" + ) + + DistributedManager._setup( + *args, + rank=rank, + node_rank=node_rank, + world_size=world_size, + local_rank=local_rank, + addr=os.environ.get("MASTER_ADDR"), + port=os.environ.get("MASTER_PORT"), + method_init="ENV", + **kwargs, + ) + + @staticmethod + def _initialize_slurm(*args, **kwargs): + """ + Initializes the DistributedManager instance using SLURM environment variables. + + Parameters + ---------- + *args : list + Variable length argument list. + **kwargs : dict + Arbitrary keyword arguments. + + Returns + ------- + None + """ + keys = ("SLURM_PROCID", "SLURM_NPROCS", "SLURM_LOCALID", "SLURM_LAUNCH_NODE_IPADDR") + if not all(k in os.environ for k in keys): + raise RuntimeError( + f"environment variables {keys} must be set to initialize torch.distributed using the slurm" + ) + rank = os.environ.get("SLURM_PROCID") + node_rank = int(os.environ.get("SLURM_NODEID", 0)) + world_size = os.environ.get("SLURM_NPROCS") + local_rank = os.environ.get("SLURM_LOCALID") + addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR") + try: + rank = int(rank) + world_size = int(world_size) + if local_rank is not None: + local_rank = int(local_rank) + except TypeError: + raise RuntimeError( + "environment variables SLURM_{PROCID,NPROCS,LOCALID} must be specified as integer " + f"but got PROCID={rank}, LOCALID={local_rank}, NPROCS={world_size}" + ) + + DistributedManager._setup( + *args, + rank=rank, + node_rank=node_rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + method_init="SLURM", + **kwargs, + ) + + @staticmethod + def initialize( + grid_group_sizes: Optional[OrderedDict[str, int | tuple[int, ...]]] = None, + device_type: str = "cuda", + backend: Optional[str] = None, + **kwargs_init_pg, + ): + """ + Initializes the DistributedManager instance. + + Parameters + ---------- + grid_group_sizes : OrderedDict[str, int | tuple[int, ...]] + Dictionary of group sizes used for distributed communication. The keys of the OrderedDict + are the group names and the values are the group sizes. The group sizes can be an integer + or a tuple of integers. If it is a tuple of integers, they partition the ranks of that group + into a subgrid of the corresponding shape. The layout of the ranks in the groups and subgroups + follows the LayoutRightMap convention, where the last group's (or its last subgroup) ranks are + contiguous global rank on the device grid. See create_grid_group() for details and examples + device_type : str, optional + Type of device used for distributed communication. + backend : str, optional + Backend used for distributed communication. + kwargs_init_pg: + kwargs to forward to torch.distributed.init_process_group call + + Returns + ------- + None + """ + if DistributedManager.is_initialized(): + warn("DistributedManager is already initialized. Skip initialize()") + return + if backend == "nccl": + # https://pytorch.org/docs/master/notes/cuda.html#id5 + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + method_init = os.getenv("BOLTZ_DISTRIBUTED_INIT_METHOD") + if method_init is not None and method_init not in DistributedManager.methods_init_available(): + raise ValueError( + f"Unknown value set for BOLTZ_DISTRIBUTED_INIT_METHOD={method_init}. " + f"Allowed options are one of {DistributedManager.methods_init_available()}" + ) + if method_init is None: + try: + DistributedManager._initialize_env( + grid_group_sizes, device_type=device_type, backend=backend, **kwargs_init_pg + ) + except RuntimeError as except_env: + try: + DistributedManager._initialize_slurm( + grid_group_sizes, device_type=device_type, backend=backend, **kwargs_init_pg + ) + except RuntimeError as except_slurm: + warn( + "Could not initialize DistributedManager with either the env:// method nor the slurm method.\n" + f"Error from env:// method: {except_env} \n" + f"Error from the slurm method: {except_slurm} \n" + "Will default initialize DistributedManager" + ) + DistributedManager._state["_initialized"] = True + elif method_init == "ENV": + DistributedManager._initialize_env( + grid_group_sizes, device_type=device_type, backend=backend, **kwargs_init_pg + ) + elif method_init == "SLURM": + DistributedManager._initialize_slurm( + grid_group_sizes, device_type=device_type, backend=backend, **kwargs_init_pg + ) + + @staticmethod + def cleanup(): + """ + Cleans up the DistributedManager instance. + + Parameters + ---------- + None + + Returns + ------- + None + """ + if DistributedManager._state.get("_group", {}) != {}: + if torch.distributed.is_initialized(): + # somewhere else has already called torch.distributed.destroy_process_group() + # need to skip to avoid double destruction + if DistributedManager._state["_device"].type == "cuda" and torch.cuda.is_available(): + torch.distributed.barrier(device_ids=[DistributedManager._state["_local_rank"]]) + else: + torch.distributed.barrier() + torch.distributed.destroy_process_group() + else: + # otherwise, just clean up the state + DistributedManager._state = {} diff --git a/src/boltz/distributed/model/__init__.py b/src/boltz/distributed/model/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/model/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/model/layers/__init__.py b/src/boltz/distributed/model/layers/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/model/layers/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/model/layers/atom_to_token.py b/src/boltz/distributed/model/layers/atom_to_token.py new file mode 100644 index 000000000..ae9a69d5a --- /dev/null +++ b/src/boltz/distributed/model/layers/atom_to_token.py @@ -0,0 +1,653 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +######################################################################################## +# On sharding strategy and DTensor placements of atom_to_token +######################################################################################## +# For performance optimization, we shard atom_to_token to minimize communication overhead for any operations on atom_to_token. These include: +# 1. translation of token- to atom-level single representation through torch.bmm +# 2. translation of atom- to token-level single representation through torch.bmm +# 3. pair representation token-to-atom through einsum +# +# Instead of a distributed matrix multiplication, we pad atom_to_token with padding atoms and +# tokens for context parallelism, such that only the block diagonal locations are non-zero. For +# example, say we have 3 tokens and 4 atoms on cp = (2, 2), atom_to_token without parallelism can be: +# +# [[1, 0, 0], +# [0, 1, 0], +# [0, 1, 0], +# [0, 0, 1]] +# +# In context parallelism, we first pad with virtual tokens such that n_tokens is divisible by cp_size, +# and then pad with virtual atoms such that n_atoms is divisible by cp_size. Let us visualize it on the device mesh: +# +# 1 0 | 0 0 +# 1 0 | 0 1 0 | 0 0 0 1 | 0 0 +# 0 1 | 0 pad tokens 0 1 | 0 0 pad atoms 0 1 | 0 0 +# --------- ---------> --------- --------> --------- +# 0 1 | 0 0 1 | 0 0 0 0 | 1 0 +# 0 0 | 1 0 0 | 1 0 0 0 | 0 0 +# 0 0 | 0 0 +# +# Note that all tokens inside of each shard now have every of their own atoms in the same shard. As +# such, any mapping between atom to token or vice versa can be done locally on the diagonal ranks. +# +# To enable this for the off-diagonal ranks, we broadcast the block diagonal matrix row-wise, and +# since all single representations are replicated by row, now all ranks can perform these mapping +# locally without communication. +# +# 1 0 | 0 0 1 0 | 1 0 +# 0 1 | 0 0 0 1 | 0 1 +# 0 1 | 0 0 row-wise broadcast 0 1 | 0 1 +# --------- -----------------> --------- +# 0 0 | 1 0 1 0 | 1 0 +# 0 0 | 0 0 0 0 | 0 0 +# 0 0 | 0 0 0 0 | 0 0 +# +# Mapping on pair representation is the exception where a transposition of atom_to_token on the +# device mesh is needed. +######################################################################################## + + +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor, Replicate, Shard + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.utils import update_exhaustive_strides + + +class SingleReprTokenToAtomFunction(torch.autograd.Function): + """Autograd function for transforming token-level single representation to atom-level single representation.""" + + @staticmethod + def forward( + ctx, + token_single_repr: DTensor, + atom_to_token: DTensor, + ) -> DTensor: + """ + Transform a token-level single representation to an atom-level single representation. + + Args: + token_single_repr: The token-level single representation. Shape: (B, n_tokens, D) and placement: (Shard(0), Shard(1), Replicate()) + atom_to_token: The atom to token one-hot mapping except for padding atoms/tokens. Shape: (B, n_atoms, n_tokens) and placement: (Shard(0), Shard(1), Replicate()) + + Returns: + The atom-level single representation. Shape: (B, n_atoms, D) + """ + single_repr_placements = (Shard(dim=0), Shard(dim=1), Replicate()) # same as atom_to_token placements + if atom_to_token.placements != single_repr_placements: + raise ValueError( + f"Expect atom_to_token to have placements {single_repr_placements}, but got {atom_to_token.placements}" + ) + if token_single_repr.placements != single_repr_placements: + raise ValueError( + f"Expect token_single_repr to have placements {single_repr_placements}, but got {token_single_repr.placements}" + ) + + # Perform local bmm and distribute + atom_to_token_local = atom_to_token.to_local().to( + dtype=token_single_repr.dtype + ) # NOTE in case atom_to_token is int + token_single_repr_local = token_single_repr.to_local() + o = torch.einsum("bij,bj...->bi...", atom_to_token_local, token_single_repr_local) + + # Compute output shape and stride for the global DTensor + # For einsum "bij,bj...->bi...", output shape is (B, n_atoms, D) + # where the last axis could be omitted if the input token_single_repr is 2D + # tensor + shape_output = atom_to_token.shape[:2] + o.shape[2:] + + # Use LayoutRightMap for the output shape + strides_output = update_exhaustive_strides(o.shape, o.stride(), shape_output) + + o = DTensor.from_local( + o, atom_to_token.device_mesh, single_repr_placements, shape=shape_output, stride=strides_output + ) + + if token_single_repr.requires_grad: + ctx.device_mesh = atom_to_token.device_mesh + ctx.single_repr_placements = single_repr_placements + ctx.token_single_repr_shape = token_single_repr.shape + ctx.token_single_repr_stride = token_single_repr.stride() + ctx.save_for_backward(atom_to_token_local) + + return o + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None, None]: + """ + Backward pass for single_repr_token_to_atom. + """ + if grad_output.placements != (Shard(dim=0), Shard(dim=1), Replicate()): + raise ValueError( + f"Expect grad_output to have placements {(Shard(dim=0), Shard(dim=1), Replicate())}, but got {grad_output.placements}" + ) + if grad_output.device_mesh != ctx.device_mesh: + raise ValueError( + f"Expect grad_output to have device mesh {ctx.device_mesh}, but got {grad_output.device_mesh}" + ) + + do = grad_output.to_local() + (atom_to_token_local,) = ctx.saved_tensors + + d_token_single_repr = torch.einsum("bji,bj...->bi...", atom_to_token_local, do) + d_token_single_repr = DTensor.from_local( + d_token_single_repr, + ctx.device_mesh, + ctx.single_repr_placements, + shape=ctx.token_single_repr_shape, + stride=ctx.token_single_repr_stride, + ) + + return d_token_single_repr, None, None + + +class SingleReprAtomToTokenFunction(torch.autograd.Function): + """Autograd function for transforming atom-level single representation to token-level single representation.""" + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, atom_single_repr: DTensor, atom_to_token: DTensor) -> DTensor: + """ + Transform an atom-level single representation to a token-level single representation. + + Args: + atom_single_repr: The atom-level single representation. + Shape: (B, n_atoms, D) and placement: (Shard(0), Shard(1), Replicate()) + atom_to_token: The atom to token one-hot mapping except for padding atoms/tokens. + Shape: (B, n_atoms, n_tokens_per_rank) and placement: (Shard(0), Shard(1), Replicate()) + + Returns: + The token-level single representation. Shape: (B, n_tokens, D) + """ + single_repr_placements = (Shard(dim=0), Shard(dim=1), Replicate()) # same as atom_to_token placements + if atom_to_token.placements != single_repr_placements: + raise ValueError( + f"Expect atom_to_token to have placements {single_repr_placements}, but got {atom_to_token.placements}" + ) + if atom_single_repr.placements != single_repr_placements: + raise ValueError( + f"Expect atom_single_repr to have placements {single_repr_placements}, but got {atom_single_repr.placements}" + ) + + # TODO potential performance optimization by moving division after bmm + + # Normalize atom_to_token + atom_to_token_local = atom_to_token.to_local().to( + dtype=atom_single_repr.dtype + ) # NOTE in case atom_to_token is int + atom_to_token_sum = atom_to_token_local.sum(dim=1, keepdim=True).clamp(min=1) + atom_to_token_mean = atom_to_token_local / atom_to_token_sum + + # Perform local bmm and distribute + atom_single_repr_local = atom_single_repr.to_local() + o = torch.einsum("bji,bj...->bi...", atom_to_token_mean, atom_single_repr_local) + + # Compute output shape and stride for the global DTensor + # Output should be (B, n_tokens, D). By definition, atom_to_token.shape[2] == n_tokens_per_rank + # which by definition guarantee atom -> token mapping is uniform across ranks + # so n_tokens == n_tokens_per_rank * size_cp + n_tokens = atom_to_token.shape[2] * atom_to_token.device_mesh.get_group(1).size() + # where the last axis could be omitted if the input atom_single_repr is 2D + shape_output = (atom_to_token.shape[0], n_tokens) + o.shape[2:] + + # Use LayoutRightMap for the output shape + strides_output = update_exhaustive_strides(o.shape, o.stride(), shape_output) + + o = DTensor.from_local( + o, atom_to_token.device_mesh, single_repr_placements, shape=shape_output, stride=strides_output + ) + + if atom_single_repr.requires_grad: + ctx.device_mesh = atom_to_token.device_mesh + ctx.single_repr_placements = single_repr_placements + ctx.atom_single_repr_shape = atom_single_repr.shape + ctx.atom_single_repr_stride = atom_single_repr.stride() + ctx.save_for_backward(atom_to_token_mean) + + return o + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None, None]: + """ + Backward pass for single_repr_atom_to_token. + """ + if grad_output.placements != (Shard(dim=0), Shard(dim=1), Replicate()): + raise ValueError( + f"Expect grad_output to have placements {(Shard(dim=0), Shard(dim=1), Replicate())}, but got {grad_output.placements}" + ) + if grad_output.device_mesh != ctx.device_mesh: + raise ValueError( + f"Expect grad_output to have device mesh {ctx.device_mesh}, but got {grad_output.device_mesh}" + ) + + do = grad_output.to_local() + (atom_to_token_mean,) = ctx.saved_tensors + + # Perform local bmm and distribute + d_atom_single_repr = torch.einsum("bij,bj...->bi...", atom_to_token_mean, do) + d_atom_single_repr = DTensor.from_local( + d_atom_single_repr, + ctx.device_mesh, + ctx.single_repr_placements, + shape=ctx.atom_single_repr_shape, + stride=ctx.atom_single_repr_stride, + ) + + return d_atom_single_repr, None, None + + +class PairReprTokenToAtomFunction(torch.autograd.Function): + """Autograd function for transforming token-level pair representation to atom-level pair representation.""" + + @staticmethod + def forward( + ctx, + token_repr: DTensor, + atom_to_token: DTensor, + transpose_comm: TransposeComm, + ) -> DTensor: + """ + Transform a token-level pair representation to an atom-level pair representation. + + Args: + token_repr: The token-level pair representation. Shape: (B, n_tokens, n_tokens, D) and placement: (Shard(0), Shard(1), Shard(2)) + atom_to_token: The atom to token one-hot mapping except for padding atoms/tokens. Shape: (B, n_atoms, n_tokens) and placement: (Shard(0), Shard(1), Replicate()) + transpose_comm: The transpose communication object. + + Returns: + The atom-level pair representation. Shape: (B, n_atoms, n_atoms, D) + """ + single_repr_placements = (Shard(dim=0), Shard(dim=1), Replicate()) # same as atom_to_token placements + if atom_to_token.placements != single_repr_placements: + raise ValueError( + f"Expect atom_to_token to have placements {single_repr_placements}, but got {atom_to_token.placements}" + ) + pair_repr_placements = (Shard(dim=0), Shard(dim=1), Shard(dim=2)) + if token_repr.placements != pair_repr_placements: + raise ValueError( + f"Expect token_repr to have placements {pair_repr_placements}, but got {token_repr.placements}" + ) + + if atom_to_token.requires_grad: + raise ValueError("atom_to_token should not require grad") + + # Perform transpose communication to get atom_to_token_local_j + atom_to_token_local = atom_to_token.to_local().to(dtype=token_repr.dtype) # NOTE in case atom_to_token is int + atom_to_token_local = atom_to_token_local.contiguous() # for both forward and backward + atom_to_token_local_j = transpose_comm.enqueue_to_dispatch(atom_to_token_local) + + # Perform overlapped einsum operation + token_repr_local = token_repr.to_local() + + # TODO potential performance optimization by op fusion versus communication overlap + o = torch.einsum("bijd,bmi->bmjd", token_repr_local, atom_to_token_local) + transpose_comm.wait_until_finished() + o = torch.einsum("bmjd,bnj->bmnd", o, atom_to_token_local_j) + + # Compute output shape and stride for the global DTensor + # Output should be (B, n_atoms, n_atoms, D) + shape_output = (atom_to_token.shape[0], atom_to_token.shape[1], atom_to_token.shape[1]) + o.shape[3:] + + # Use LayoutRightMap for the output shape + strides_output = update_exhaustive_strides(o.shape, o.stride(), shape_output) + + o = DTensor.from_local( + o, token_repr.device_mesh, pair_repr_placements, shape=shape_output, stride=strides_output + ) + + # Save tensors needed for backward pass + if token_repr.requires_grad: + ctx.transpose_comm = transpose_comm + ctx.device_mesh = token_repr.device_mesh + ctx.pair_repr_placements = pair_repr_placements + ctx.token_repr_shape = token_repr.shape + ctx.token_repr_stride = token_repr.stride() + ctx.save_for_backward(atom_to_token_local) + + return o + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None, None, None]: + """ + Backward pass for pair_repr_token_to_atom. + + Args: + grad_output: Gradient w.r.t. output with shape (B, n_atoms, n_atoms, D) + + Returns: + Tuple of gradients: (grad_token_repr, None, None) + - grad_token_repr: Gradient w.r.t. token_repr + - None: No gradient for atom_to_token (as specified) + - None: No gradient for transpose_comm (not differentiable) + """ + (atom_to_token_local,) = ctx.saved_tensors + transpose_comm = ctx.transpose_comm + atom_to_token_local_j = ctx.transpose_comm.enqueue_to_dispatch(atom_to_token_local) + + # Perform overlapped einsum operation + do_local = grad_output.to_local() + + d_token_repr = torch.einsum("bmnd,bmi->bind", do_local, atom_to_token_local) + transpose_comm.wait_until_finished() + d_token_repr = torch.einsum("bind,bnj->bijd", d_token_repr, atom_to_token_local_j) + d_token_repr = DTensor.from_local( + d_token_repr, + ctx.device_mesh, + ctx.pair_repr_placements, + shape=ctx.token_repr_shape, + stride=ctx.token_repr_stride, + ) + + return d_token_repr, None, None, None + + +class SingleReprRepAtomToTokenFunction(torch.autograd.Function): + """Autograd function for token representative-atom projection.""" + + @staticmethod + def forward(ctx, atom_single_repr: DTensor, token_to_rep_atom: DTensor) -> DTensor: + """Project atom-level single representation to token-level via representative atoms. + + Supports a multiplicity factor: atom_single_repr may have shape (B*mult, n_atoms, D) + while token_to_rep_atom has shape (B, n_tokens, n_atoms). The token_to_rep_atom map + is broadcast across the multiplicity dimension. + + Args: + atom_single_repr: Atom-level representation. Shape: (B*mult, n_atoms, D), + placements: (Shard(0), Shard(1), Replicate()). + token_to_rep_atom: Token->representative-atom one-hot map. + Shape: (B, n_tokens, n_atoms), placements: (Shard(0), Shard(1), Replicate()). + + Returns: + Token-level representation. Shape: (B*mult, n_tokens, D), + placements: (Shard(0), Shard(1), Replicate()). + """ + single_repr_placements = (Shard(0), Shard(1), Replicate()) + if atom_single_repr.placements != single_repr_placements: + raise ValueError( + f"Expect atom_single_repr to have placements {single_repr_placements}, but got {atom_single_repr.placements}" + ) + if token_to_rep_atom.placements != single_repr_placements: + raise ValueError( + f"Expect token_to_rep_atom to have placements {single_repr_placements}, but got {token_to_rep_atom.placements}" + ) + + token_to_rep_local = token_to_rep_atom.to_local().to(dtype=atom_single_repr.dtype) + atom_single_repr_local = atom_single_repr.to_local() + + # atom_single_repr may carry a multiplicity factor: (B*mult, N_atom, D) vs (B, N_token, N_atom). + # Reshape to (B, mult, N_atom, D) so token_to_rep_local broadcasts over the mult dimension. + B_local = token_to_rep_local.shape[0] + mult = atom_single_repr_local.shape[0] // B_local + atom_reshaped = atom_single_repr_local.reshape(B_local, mult, *atom_single_repr_local.shape[1:]) + # (B, N_token, N_atom) @ (B, mult, N_atom, D) -> (B, mult, N_token, D) + o = torch.einsum("btj,bmj...->bmt...", token_to_rep_local, atom_reshaped) + # Flatten back to (B*mult, N_token, D) + o = o.reshape(B_local * mult, *o.shape[2:]) + + shape_output = (token_to_rep_atom.shape[0] * mult,) + token_to_rep_atom.shape[1:2] + o.shape[2:] + strides_output = update_exhaustive_strides(o.shape, o.stride(), shape_output) + o = DTensor.from_local( + o, token_to_rep_atom.device_mesh, single_repr_placements, shape=shape_output, stride=strides_output + ) + + if atom_single_repr.requires_grad: + ctx.device_mesh = atom_single_repr.device_mesh + ctx.single_repr_placements = single_repr_placements + ctx.atom_single_repr_shape = atom_single_repr.shape + ctx.atom_single_repr_stride = atom_single_repr.stride() + ctx.mult = mult + ctx.save_for_backward(token_to_rep_local) + + return o + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None, None]: + """Backward pass for representative atom projection.""" + if grad_output.placements != (Shard(0), Shard(1), Replicate()): + raise ValueError( + f"Expect grad_output to have placements {(Shard(0), Shard(1), Replicate())}, but got {grad_output.placements}" + ) + if grad_output.device_mesh != ctx.device_mesh: + raise ValueError( + f"Expect grad_output to have device mesh {ctx.device_mesh}, but got {grad_output.device_mesh}" + ) + + do = grad_output.to_local() + (token_to_rep_local,) = ctx.saved_tensors + mult = ctx.mult + B_local = token_to_rep_local.shape[0] + # Reshape grad from (B*mult, N_token, D) to (B, mult, N_token, D) + do_reshaped = do.reshape(B_local, mult, *do.shape[1:]) + # (B, N_token, N_atom)^T @ (B, mult, N_token, D) -> (B, mult, N_atom, D) + d_atom_reshaped = torch.einsum("btj,bmt...->bmj...", token_to_rep_local, do_reshaped) + # Flatten back to (B*mult, N_atom, D) + d_atom_single_repr = d_atom_reshaped.reshape(B_local * mult, *d_atom_reshaped.shape[2:]) + d_atom_single_repr = DTensor.from_local( + d_atom_single_repr, + ctx.device_mesh, + ctx.single_repr_placements, + shape=ctx.atom_single_repr_shape, + stride=ctx.atom_single_repr_stride, + ) + return d_atom_single_repr, None, None + + +def single_repr_token_to_atom( + token_single_repr: DTensor, + atom_to_token: DTensor, +) -> DTensor: + """ + Transform a token-level single representation to an atom-level single representation. + + Args: + token_single_repr: The token-level single representation. Shape: (B, n_tokens, D) and placement: (Shard(0), Shard(1), Replicate()) + atom_to_token: The atom to token mapping. Shape: (B, n_tokens, n_atoms) and placement: (Shard(0), Shard(1), Replicate()) + device_mesh: The device mesh. + + Returns: + The atom-level single representation. Shape: (B, n_atoms, D) + """ + return SingleReprTokenToAtomFunction.apply(token_single_repr, atom_to_token) + + +def single_repr_atom_to_token( + atom_single_repr: DTensor, + atom_to_token: DTensor, +) -> DTensor: + """ + Transform an atom-level single representation to a token-level single representation. + + Args: + atom_single_repr: The atom-level single representation. Shape: (B, n_atoms, D) and placement: (Shard(0), Shard(1), Replicate()) + atom_to_token: The atom to token mapping. Shape: (B, n_atoms, n_tokens) and placement: (Shard(0), Shard(1), Replicate()) + device_mesh: The device mesh. + + Returns: + The token-level single representation. Shape: (B, n_tokens, D) + """ + return SingleReprAtomToTokenFunction.apply(atom_single_repr, atom_to_token) + + +def pair_repr_token_to_atom( + token_repr: DTensor, + atom_to_token: DTensor, + transpose_comm: TransposeComm, +) -> DTensor: + """ + Transform a token-level pair representation to an atom-level pair representation. + + Args: + token_repr: The token-level pair representation. Shape: (B, n_tokens, n_tokens, D) and placement: (Shard(0), Shard(1), Shard(2)) + atom_to_token: The atom to token mapping. Shape: (B, n_tokens, n_atoms) and placement: (Shard(0), Shard(1), Replicate()) + transpose_comm: The transpose communication object. + + Returns: + The atom-level pair representation. Shape: (B, n_atoms, n_atoms, D) + """ + return PairReprTokenToAtomFunction.apply(token_repr, atom_to_token, transpose_comm) + + +def single_repr_rep_atom_to_token( + atom_single_repr: DTensor, + token_to_rep_atom: DTensor, +) -> DTensor: + """Project atom-level single representation to token-level using representative atoms.""" + return SingleReprRepAtomToTokenFunction.apply(atom_single_repr, token_to_rep_atom) + + +def _reconstruct_onehot_diag_block_global( + dtensor: DTensor, +) -> Tensor: + """Reconstruct a diagonally-sharded one-hot DTensor into a full plain tensor. + + Both ``atom_to_token`` (atoms×tokens) and ``token_to_rep_atom`` (tokens×atoms) + use diagonal block sharding: shard *i* only contains the non-zero block + relating atoms of shard *i* to tokens of shard *i*. After all-reduce the + complete global matrix is recovered. + + Parameters + ---------- + dtensor : DTensor + Diagonally-sharded DTensor with placements ``(Shard(0), Shard(1), Replicate())``. + + Returns + ------- + Tensor + Reconstructed global tensor of shape ``(B, N_atoms_global, N_tokens_global)`` + or ``(B, N_tokens_global, N_atoms_global)``. + """ + device_mesh = dtensor.device_mesh + + expected_placements = (Shard(dim=0), Shard(dim=1), Replicate()) + if dtensor.placements != expected_placements: + raise ValueError(f"Expected placements {expected_placements}, got {dtensor.placements}") + + for i_dim_mesh, placement in enumerate(expected_placements): + if isinstance(placement, Shard) and dtensor.shape[placement.dim] % device_mesh.shape[i_dim_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {dtensor.shape[placement.dim]} " + f"along device mesh dimension {i_dim_mesh} of size {device_mesh.shape[i_dim_mesh]} is not supported" + ) + + local = dtensor.to_local() + B = local.shape[0] + assert B == 1, "Only batch size 1 is supported" + + n_per_shard_dim1 = local.shape[1] + n_per_shard_dim2 = local.shape[2] + + cp_axis_0_size = device_mesh.get_group("cp_axis_0").size() + cp_axis_0_rank, cp_axis_1_rank = device_mesh.get_coordinate()[1:] + + n_global_dim1 = n_per_shard_dim1 * cp_axis_0_size + n_global_dim2 = n_per_shard_dim2 * cp_axis_0_size + + result = torch.zeros( + B, + n_global_dim1, + n_global_dim2, + dtype=local.dtype, + device=local.device, + ) + + start_dim1 = cp_axis_0_rank * n_per_shard_dim1 + end_dim1 = start_dim1 + n_per_shard_dim1 + start_dim2 = cp_axis_0_rank * n_per_shard_dim2 + end_dim2 = start_dim2 + n_per_shard_dim2 + + if cp_axis_1_rank == 0: + n_non_zeros = local.sum(dim=2) + if not ((n_non_zeros == 0) | (n_non_zeros == 1)).all(): + raise ValueError( + f"Input DTensor shard is not one-hot for CP rank ({cp_axis_0_rank}, {cp_axis_1_rank}): " + f"found rows with sum not in {{0, 1}}" + ) + result[:, start_dim1:end_dim1, start_dim2:end_dim2] = local + + torch.distributed.all_reduce(result, op=torch.distributed.ReduceOp.SUM, group=device_mesh.get_group("cp_axis_0")) + torch.distributed.all_reduce(result, op=torch.distributed.ReduceOp.SUM, group=device_mesh.get_group("cp_axis_1")) + + return result + + +def reconstruct_token_to_rep_atom_global(token_to_rep_atom_dtensor: DTensor) -> Tensor: + """Reconstruct the full ``token_to_rep_atom`` matrix from a diagonally-sharded DTensor. + + The reconstruction mirrors :func:`reconstruct_atom_to_token_global` but for + the transposed mapping ``(B, N_tokens, N_atoms)``. + + Parameters + ---------- + token_to_rep_atom_dtensor : DTensor + Diagonally-sharded DTensor with placements ``(Shard(0), Shard(1), Replicate())``. + Local shape: ``(B, N_tokens_per_shard, max_atoms_per_shard)``. + + Returns + ------- + Tensor + ``(B, N_tokens_global, N_atoms_global)`` + """ + return _reconstruct_onehot_diag_block_global(token_to_rep_atom_dtensor) + + +def reconstruct_r_set_to_rep_atom_global(r_set_dtensor: DTensor) -> Tensor: + """Reconstruct the full ``r_set_to_rep_atom`` matrix from a diagonally-sharded DTensor. + + Parameters + ---------- + r_set_dtensor : DTensor + Diagonally-sharded DTensor with placements ``(Shard(0), Shard(1), Replicate())``. + Local shape: ``(B, max_r_set_per_shard, max_atoms_per_shard)``. + + Returns + ------- + Tensor + ``(B, N_R_global, N_atoms_global)`` + """ + return _reconstruct_onehot_diag_block_global(r_set_dtensor) + + +def reconstruct_atom_to_token_global(atom_to_token_dtensor: DTensor) -> Tensor: + """ + Reconstruct the original full atom_to_token tensor from a DTensor with (Shard, Shard, Replicate) placements. + + This function reverses the context parallel sharding strategy by: + 1. Gathering the local tensor from each rank + 2. Reconstructing the block diagonal structure + 3. Removing padding to get the original tensor + + Args: + atom_to_token_dtensor: DTensor with placements (Shard(0), Shard(1), Replicate()) + Shape: (global_batch_size, n_atoms_per_rank, n_tokens) + + Returns: + Tensor: The reconstructed global atom_to_token tensor + Shape: (local_batch_size, n_atoms_global, n_tokens_global) + """ + result = _reconstruct_onehot_diag_block_global(atom_to_token_dtensor) + assert torch.max(result) == 1 + return result diff --git a/src/boltz/distributed/model/layers/attention.py b/src/boltz/distributed/model/layers/attention.py new file mode 100644 index 000000000..ff73e76be --- /dev/null +++ b/src/boltz/distributed/model/layers/attention.py @@ -0,0 +1,658 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from collections import OrderedDict +from typing import Callable, Union + +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +from boltz.distributed.comm import AttentionPairBiasComm +from boltz.distributed.model.layers.attention_impl import ( + _AttentionPairBiasContextVecParams, + _AttentionPairBiasContexVecImpl, + _AttentionPairBiasShardwiseImpl, +) +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.sigmoid_gate import sigmoid_gate +from boltz.distributed.model.modules.utils import SDPAWithBiasBackend +from boltz.model.layers.attention import AttentionPairBias as AttentionPairBiasSerialV1 +from boltz.model.layers.attentionv2 import AttentionPairBias as AttentionPairBiasSerialV2 + + +class AttentionPairBias(nn.Module): + """Attention pair bias module based on DTensor with ring attention. + + This module implements global (non-window-batched) attention with pair bias + using ring communication patterns for context parallelism. + + The __init__() method follows the pattern of distribute_module(), and + so takes a device mesh as an argument. See the following link for details: + https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.distribute_module + + Configuration Flags + ------------------- + The module supports both V1 (Boltz-1x) and V2 (Boltz-2) API styles through + configuration flags: + + - apply_initial_norm: V1=True (has norm_s LayerNorm), V2=False (no initial norm) + - compute_pair_bias: V1=True (project z via LayerNorm+Linear), + V2=configurable (False for DiffusionTransformerLayer where z is pre-computed bias). + Mutually exclusive with use_model_cache when False. + - use_model_cache: V1=True (cache z projection), V2=False (no caching). + Only valid when compute_pair_bias=True. + + Use Cases + --------- + U1: PairFormerModule (global attention, no window batching) + - multiplicity=1, compute_pair_bias=True + - k_in=s (queries equal keys) + + U2: AtomDiffusion with multiplicity (non-window-batched, all-to-all) + - multiplicity >= 1 + - k_in=s or pre-computed + + Note: Window batching use cases should use AttentionPairBiasShardwise instead. + """ + + def __init__( + self, + attn_pair_bias: nn.Module, + device_mesh: DeviceMesh, + ring_comm: AttentionPairBiasComm, + sdpa_with_bias_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.REFERENCE, + # Configuration flags for V1/V2 API compatibility + apply_initial_norm: bool = False, # V1=True, V2=False (default to V2) + compute_pair_bias: bool = True, # V1=True, V2=configurable (False for DiffusionTransformerLayer) + use_model_cache: bool = False, # V1=True, V2=False (default to V2) + ) -> None: + """Initialize the attention pair bias layer. + + Parameters + ---------- + attn_pair_bias : nn.Module + The serial attention pair bias layer to convert to DTensor. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + ring_comm : AttentionPairBiasComm + The ring communication object for context parallelism. + sdpa_with_bias_backend : SDPAWithBiasBackend, optional + The attention backend to use. Default is REFERENCE. + apply_initial_norm : bool, optional + Whether to apply LayerNorm to input s. V1=True, V2=False. Default False. + compute_pair_bias : bool, optional + Whether to compute pair bias (LayerNorm + Linear on z). V1=True, + V2=configurable (False for DiffusionTransformerLayer where z is + pre-computed bias). Mutually exclusive with use_model_cache=True + when False. Default True. + use_model_cache : bool, optional + Whether to cache z projection for diffusion rollout. V1=True, V2=False. + Only valid when compute_pair_bias=True. Default False. + + Raises + ------ + TypeError + If device_mesh is not a DeviceMesh or ring_comm is not an AttentionPairBiasComm. + ValueError + If sdpa_with_bias_backend is not supported, or if use_model_cache=True + with compute_pair_bias=False. + """ + super().__init__() + + # (0) Type check on serial module + if not isinstance(attn_pair_bias, (AttentionPairBiasSerialV1, AttentionPairBiasSerialV2)): + raise TypeError( + ", ".join( + [ + f"Instance {attn_pair_bias} should have type " + f"{AttentionPairBiasSerialV1} or {AttentionPairBiasSerialV2}", + f"but instead has type {type(attn_pair_bias)}.", + ] + ) + ) + + # (1) Set non-module, non-parameter attributes from serial module + self.c_s = attn_pair_bias.c_s + self.num_heads = attn_pair_bias.num_heads + self.head_dim = attn_pair_bias.head_dim + self.inf = attn_pair_bias.inf + # Mutable backend selection for scaled dot-product attention. + # Default is REFERENCE. To switch backend for the entire model, use + # ``model.apply(SetAttnPairBiasBackend(backend))`` + # (see boltz.distributed.model.modules.utils.SetAttnPairBiasBackend). + self.sdpa_with_bias_backend = ( + sdpa_with_bias_backend + if isinstance(sdpa_with_bias_backend, SDPAWithBiasBackend) + else SDPAWithBiasBackend(sdpa_with_bias_backend) + ) + if self.sdpa_with_bias_backend not in [ + SDPAWithBiasBackend.TORCH_FLEX_ATTN, + SDPAWithBiasBackend.REFERENCE, + ]: + raise ValueError( + f"Unsupported sdpa_with_bias_backend: {self.sdpa_with_bias_backend}. " + f"Only TORCH_FLEX_ATTN and REFERENCE are supported." + ) + + # Configuration flags — use_model_cache caches the z projection output, + # which requires compute_pair_bias=True. + if use_model_cache and not compute_pair_bias: + raise ValueError( + "use_model_cache=True requires compute_pair_bias=True because the cache " + "stores the z projection output. Got compute_pair_bias=False." + ) + self.apply_initial_norm = apply_initial_norm + self.compute_pair_bias = compute_pair_bias + self.use_model_cache = use_model_cache + + # Ring attention does not support window batching + self.use_window_batching = False + + self.device_mesh = device_mesh + self.ring_comm = ring_comm + + # (2) Sanity checks on non-module, non-parameter attributes + if not isinstance(self.device_mesh, DeviceMesh): + raise TypeError(f"Input '{device_mesh}' must be of type {DeviceMesh}. Got type {type(self.device_mesh)}.") + if not isinstance(self.ring_comm, AttentionPairBiasComm): + raise TypeError( + f"Input '{ring_comm}' must be of type {AttentionPairBiasComm}. Got type {type(self.ring_comm)}." + ) + + # (3) Initialize child modules explicitly from serial module + if self.apply_initial_norm: + self.norm_s = LayerNormParamsReplicated(attn_pair_bias.norm_s, device_mesh=device_mesh) + + self.proj_q = LinearParamsReplicated(layer_local=attn_pair_bias.proj_q, device_mesh=device_mesh) + self.proj_k = LinearParamsReplicated(layer_local=attn_pair_bias.proj_k, device_mesh=device_mesh) + self.proj_v = LinearParamsReplicated(layer_local=attn_pair_bias.proj_v, device_mesh=device_mesh) + self.proj_g = LinearParamsReplicated(layer_local=attn_pair_bias.proj_g, device_mesh=device_mesh) + self.proj_o = LinearParamsReplicated(layer_local=attn_pair_bias.proj_o, device_mesh=device_mesh) + + # (4) proj_z: Strip the Rearrange to avoid changing placements of z. + # When compute_pair_bias=True, serial proj_z is Sequential(LayerNorm, Linear, Rearrange) + # -> keep only LayerNorm and Linear. The permute is done manually in the forward pass. + # When compute_pair_bias=False, serial proj_z is just a Rearrange (no projection + # needed; z is already the pre-computed bias). + if self.compute_pair_bias: + self.proj_z = nn.Sequential( + LayerNormParamsReplicated(attn_pair_bias.proj_z[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=attn_pair_bias.proj_z[1], device_mesh=device_mesh), + ) + + def forward( + self, + s: DTensor, + z: DTensor, + mask: DTensor, + pair_mask: Union[DTensor, None] = None, + multiplicity: int = 1, + k_in: Union[DTensor, None] = None, + model_cache: Union[OrderedDict, None] = None, + ) -> DTensor: + """Forward pass for ring attention with pair bias. + + Parameters + ---------- + s : DTensor + The input sequence tensor (queries), with shape (B, N, c_s) or (B*M, N, c_s) + where M is multiplicity. + z : DTensor + The input pairwise tensor, with shape (B, N, N, c_z). + mask : DTensor + The token mask tensor with shape (B, N) or (B*M, N). + pair_mask : DTensor or None, optional + The pairwise mask tensor with shape (B, N, N). If None, only uses 1D mask. + multiplicity : int, optional + The diffusion batch size, by default 1. + k_in : DTensor or None, optional + Pre-computed key input tensor. If None, uses s as key input (k_in=s). + For V2 API, caller should pass k_in explicitly. + model_cache : OrderedDict or None, optional + Cache for storing projected z during diffusion rollout. Only used if + use_model_cache=True was set at init. + + Returns + ------- + DTensor + The output tensor, with shape (B*M, N, c_s). + + Raises + ------ + ValueError + If mask shape is incompatible with k_in. + """ + # ------------------------------------------------- + # Begin DTensor ops + # DTensor metadata checks done for each operation + # ------------------------------------------------- + if self.apply_initial_norm: + s: DTensor = self.norm_s(s) # Layer norm + + # V2 API: k_in is passed explicitly; V1 API: k_in defaults to s + if k_in is None: + k_in = s + + # Sanity check: mask should have same sequence length as k_in + if mask.shape[-1] != k_in.shape[-2]: + raise ValueError( + f"mask sequence length ({mask.shape[-1]}) must match k_in sequence length ({k_in.shape[-2]}). " + f"For V2 API with to_keys transformation, transform mask before passing." + ) + + # Compute projections + q_proj_out: DTensor = self.proj_q(s) # (B, N, c_s) + k_proj_out: DTensor = self.proj_k(k_in) # (B, N, c_s) or (B, H, c_s) if transformed + v_proj_out: DTensor = self.proj_v(k_in) # (B, N, c_s) or (B, H, c_s) if transformed + g_proj_out: DTensor = self.proj_g(s) # (B, N, c_s) + + # ------------------------------------------------------------ + # Project z to num_heads dimensions (V1: compute_pair_bias=True) + # or use z as-is (V2: compute_pair_bias=False, z is pre-computed bias) + # ------------------------------------------------------------ + if self.compute_pair_bias: + # input z: (B, N, N, c_z) + # output z: (B, N, N, num_heads) after proj_z (without Rearrange) + if self.use_model_cache and model_cache is not None: + if "z" not in model_cache: + z: DTensor = self.proj_z(z) # (B, N, N, num_heads) + model_cache["z"] = z + else: + z = model_cache["z"] + else: + z: DTensor = self.proj_z(z) # (B, N, N, num_heads) + # else: z is already the pre-computed bias with shape (B, N, N, num_heads) + + # ------------------------------------------------------------ + # Compute context vectors + # ------------------------------------------------------------ + apb_context_vec_params = _AttentionPairBiasContextVecParams( + ring_comm=self.ring_comm, + multiplicity=multiplicity, + num_heads=self.num_heads, + head_dim=self.head_dim, + inf=self.inf, + use_window_batching=self.use_window_batching, + sdpa_with_bias_backend=self.sdpa_with_bias_backend, + ) + o_contex_vec = _AttentionPairBiasContexVecImpl.apply( + q_proj_out, + k_proj_out, + v_proj_out, + z, # (B, N, N, H) + mask, + pair_mask, + apb_context_vec_params, + ) + # ------------------------------------------------------------ + # Gate and project context vectors + # ------------------------------------------------------------ + gated_context_vec: DTensor = sigmoid_gate(x=o_contex_vec, g=g_proj_out) + o: DTensor = self.proj_o(gated_context_vec) + + return o + + +class AttentionPairBiasShardwise(nn.Module): + """Shardwise attention with pair bias for window-batched context parallelism. + + This module implements multi-head attention with pair bias specifically designed + for window batching scenarios in context parallelism (CP). Unlike the standard + `AttentionPairBias` which uses ring communication patterns, this implementation + operates on sharded windows where each shard can be processed independently. + + The key difference from `AttentionPairBias`: + - Used for window batching scenarios + - Uses `to_keys` function OR pre-computed `k_in` to transform queries to key space + - Operates on 4D single representations (B, K, W, D) and 5D pair representations + (B, K, W, H, num_heads) + - Does not apply multiplicity to z/mask, instead broadcasts them + + Configuration Flags + ------------------- + The module supports both V1 (Boltz-1x) and V2 (Boltz-2) API styles: + + - apply_initial_norm: V1=True (has norm_s LayerNorm), V2=False (no initial norm) + - compute_pair_bias: V1=True (always compute via LayerNorm+Linear), + V2=False (z is pre-computed bias, no projection needed). + Mutually exclusive with use_model_cache when False. + - use_model_cache: V1=True (cache z projection), V2=False (no caching). + Only valid when compute_pair_bias=True. + + Use Cases + --------- + V1 API (to_keys inside forward): + AtomTransformer.forward(to_keys=to_keys) + AttentionPairBiasShardwise.forward(to_keys=to_keys) + + V2 API (k_in pre-computed): + DiffusionTransformerLayer.forward(to_keys=to_keys) + k_in = to_keys(s) + mask = to_keys(mask) + AttentionPairBiasShardwise.forward(k_in=k_in, mask=mask) + + Attributes + ---------- + c_s : int + Hidden dimension of single representation (num_heads * head_dim). + num_heads : int + Number of attention heads. + head_dim : int + Dimension per attention head. + inf : float + Large value used for masking invalid positions. + apply_initial_norm : bool + Whether to apply layer normalization to input (V1=True, V2=False). + compute_pair_bias : bool + Whether to compute pair bias via LayerNorm+Linear (V1=True, V2=False). + use_model_cache : bool + Whether to cache z projection (V1=True, V2=False). + device_mesh : DeviceMesh + The device mesh for distributed computation. + sdpa_with_bias_backend : SDPAWithBiasBackend + Backend for scaled dot-product attention computation. + """ + + def __init__( + self, + attn_pair_bias: nn.Module, + device_mesh: DeviceMesh, + sdpa_with_bias_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.REFERENCE, + # Configuration flags for V1/V2 API compatibility + apply_initial_norm: bool = False, # V1=True, V2=False (default to V2) + compute_pair_bias: bool = True, # V1=True, V2=False + use_model_cache: bool = False, # V1=True, V2=False (default to V2) + ) -> None: + """Initialize the shardwise attention pair bias layer. + + Parameters + ---------- + attn_pair_bias : nn.Module + The serial attention pair bias layer to convert to DTensor. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + sdpa_with_bias_backend : SDPAWithBiasBackend, optional + Backend for computing scaled dot-product attention with bias. + Default is REFERENCE. + apply_initial_norm : bool, optional + Whether to apply LayerNorm to input s. V1=True, V2=False. Default False. + compute_pair_bias : bool, optional + Whether to compute pair bias (LayerNorm + Linear on z). V1=True, V2=False. + Mutually exclusive with use_model_cache=True when compute_pair_bias=False. + Default True. + use_model_cache : bool, optional + Whether to cache z projection for diffusion rollout. V1=True, V2=False. + Only valid when compute_pair_bias=True. Default False. + + Raises + ------ + TypeError + If device_mesh is not a DeviceMesh instance, or if attn_pair_bias is not + a recognized serial AttentionPairBias type. + ValueError + If use_model_cache=True with compute_pair_bias=False. + """ + super().__init__() + + # (0) Type check on serial module + if not isinstance(attn_pair_bias, (AttentionPairBiasSerialV1, AttentionPairBiasSerialV2)): + raise TypeError( + ", ".join( + [ + f"Instance {attn_pair_bias} should have type " + f"{AttentionPairBiasSerialV1} or {AttentionPairBiasSerialV2}", + f"but instead has type {type(attn_pair_bias)}.", + ] + ) + ) + + # (1) Set non-module, non-parameter attributes + self.c_s = attn_pair_bias.c_s + self.num_heads = attn_pair_bias.num_heads + self.head_dim = attn_pair_bias.head_dim + self.inf = attn_pair_bias.inf + + # Configuration flags — compute_pair_bias and use_model_cache are mutually exclusive: + # use_model_cache caches the z projection, which requires compute_pair_bias=True. + if use_model_cache and not compute_pair_bias: + raise ValueError( + "use_model_cache=True requires compute_pair_bias=True because the cache " + "stores the z projection output. Got compute_pair_bias=False." + ) + self.apply_initial_norm = apply_initial_norm + self.compute_pair_bias = compute_pair_bias + self.use_model_cache = use_model_cache + + self.device_mesh = device_mesh + # Mutable backend selection for scaled dot-product attention. + # Default is REFERENCE. To switch backend for the entire model, use + # ``model.apply(SetAttnPairBiasShardwiseBackend(backend))`` + # (see boltz.distributed.model.modules.utils.SetAttnPairBiasShardwiseBackend). + self.sdpa_with_bias_backend = sdpa_with_bias_backend + + # (2) Sanity checks on non-module, non-parameter attributes + if not isinstance(self.device_mesh, DeviceMesh): + raise TypeError(f"Input '{device_mesh}' must be of type {DeviceMesh}. Got type {type(self.device_mesh)}.") + + # (3) Initialize child modules explicitly from serial module + if self.apply_initial_norm: + self.norm_s = LayerNormParamsReplicated(attn_pair_bias.norm_s, device_mesh=device_mesh) + + self.proj_q = LinearParamsReplicated(layer_local=attn_pair_bias.proj_q, device_mesh=device_mesh) + self.proj_k = LinearParamsReplicated(layer_local=attn_pair_bias.proj_k, device_mesh=device_mesh) + self.proj_v = LinearParamsReplicated(layer_local=attn_pair_bias.proj_v, device_mesh=device_mesh) + self.proj_g = LinearParamsReplicated(layer_local=attn_pair_bias.proj_g, device_mesh=device_mesh) + self.proj_o = LinearParamsReplicated(layer_local=attn_pair_bias.proj_o, device_mesh=device_mesh) + + # (4) proj_z: Strip the Rearrange to avoid changing placements of z. + # When compute_pair_bias=True, serial proj_z is Sequential(LayerNorm, Linear, Rearrange) + # -> keep only LayerNorm and Linear. + # When compute_pair_bias=False, serial proj_z is just a Rearrange (no projection + # needed; z is already the pre-computed bias). + if self.compute_pair_bias: + self.proj_z = nn.Sequential( + LayerNormParamsReplicated(attn_pair_bias.proj_z[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=attn_pair_bias.proj_z[1], device_mesh=device_mesh), + ) + + def forward( + self, + s: DTensor, + z: DTensor, + mask: DTensor, + to_keys: Union[Callable[[DTensor], DTensor], None] = None, + k_in: Union[DTensor, None] = None, + model_cache: Union[OrderedDict, None] = None, + ) -> DTensor: + """Forward pass for shardwise attention with pair bias. + + Computes multi-head attention with pair bias on window-batched inputs. + The attention is computed within each window shard independently. + + Two API modes are supported: + + V1 API (to_keys provided): + - to_keys transforms s to k_in internally + - mask is transformed by to_keys internally + - mask shape: (B, K, W) - query-aligned + + V2 API (k_in provided): + - k_in is pre-computed by caller + - mask is pre-transformed by caller to key-aligned shape + - mask shape: (B, K, H) - key-aligned + + Parameters + ---------- + s : DTensor + Input single representation tensor with shape (B * M, K, W, c_s) where: + - B is batch size + - M is multiplicity (diffusion samples) + - K is number of windows + - W is window size (typically 32) + - c_s is hidden dimension + z : DTensor + Input pair representation tensor with shape (B, K, W, H, c_z) where: + - H is the attention key dimension (typically 128) + - c_z is pair hidden dimension + Note: z is NOT multiplied by M; it broadcasts along the multiplicity axis. + mask : DTensor + Mask tensor indicating valid positions. + - V1 API (to_keys): shape (B, K, W) - will be transformed to (B, K, H) + - V2 API (k_in): shape (B, K, H) - already key-aligned + to_keys : Callable or None, optional + Function to transform tensors from query space (B, K, W, ...) to + key space (B, K, H, ...). Mutually exclusive with k_in. + k_in : DTensor or None, optional + Pre-computed key input tensor with shape (B*M, K, H, c_s). + Mutually exclusive with to_keys. + model_cache : OrderedDict or None, optional + Cache for storing projected z during diffusion rollout. Only used if + use_model_cache=True was set at init. + + Returns + ------- + DTensor + Output tensor with shape (B * M, K, W, c_s). + + Raises + ------ + ValueError + If both to_keys and k_in are provided (mutually exclusive). + If neither to_keys nor k_in is provided. + If s does not have 4 dimensions. + If z does not have 5 dimensions. + If mask dimensions don't match expected shapes. + If s.shape[0] is not divisible by z.shape[0]. + + Notes + ----- + This module avoids the multiplicity memory overhead by broadcasting z and mask + along the multiplicity dimension rather than replicating them. + """ + # Check mutual exclusivity of to_keys and k_in + if to_keys is not None and k_in is not None: + raise ValueError("to_keys and k_in are mutually exclusive. Provide only one.") + if to_keys is None and k_in is None: + raise ValueError("Either to_keys or k_in must be provided.") + + # Shape validations + if s.ndim != 4: + raise ValueError(f"s must have 4 dimensions (B*M, K, W, D), but got s.ndim={s.ndim}") + if z.ndim != 5: + raise ValueError(f"z must have 5 dimensions (B, K, W, H, c_z), but got z.ndim={z.ndim}") + + if s.shape[1:3] != z.shape[1:3]: + raise ValueError( + f"s.shape[1:3] must be equal to z.shape[1:3], but got s.shape[1:3]={s.shape[1:3]} " + f"and z.shape[1:3]={z.shape[1:3]}" + ) + + if s.shape[0] % z.shape[0] != 0: + # NOTE: this module doesn't apply multiplicity to z because it broadcasts z (and mask) + # to the attention score by design. This avoids multiplying the memory storage of + # the pair representation throughout the entire AtomTransformer and its submodules. + raise ValueError( + f"s.shape[0] must be divisible by z.shape[0], but got s.shape[0]={s.shape[0]} " + f"and z.shape[0]={z.shape[0]}" + ) + + # Validate mask shape based on API mode + if mask is not None: + if to_keys is not None: + # V1 API: mask should be query-aligned (B, K, W) + if mask.ndim != 3: + raise ValueError(f"V1 API: mask must have 3 dimensions (B, K, W), but got mask.ndim={mask.ndim}") + if mask.shape != z.shape[:3]: + raise ValueError( + f"V1 API: mask.shape must equal z.shape[:3], but got mask.shape={mask.shape} " + f"and z.shape[:3]={z.shape[:3]}" + ) + else: + # V2 API: mask should be key-aligned (B, K, H) + if mask.ndim != 3: + raise ValueError(f"V2 API: mask must have 3 dimensions (B, K, H), but got mask.ndim={mask.ndim}") + # For V2 API, mask.shape[2] should be H (key dimension), not W (query dimension) + if mask.shape[:2] != z.shape[:2]: + raise ValueError( + f"V2 API: mask.shape[:2] must equal z.shape[:2], but got mask.shape[:2]={mask.shape[:2]} " + f"and z.shape[:2]={z.shape[:2]}" + ) + + # ------------------------------------------------- + # Begin DTensor ops + # ------------------------------------------------- + if self.apply_initial_norm: + s: DTensor = self.norm_s(s) # Layer norm + + # Compute k_in and mask_key based on API mode + if to_keys is not None: + # V1 API: transform s and mask using to_keys + k_in_computed = to_keys(s) + mask_key = to_keys(mask) + else: + # V2 API: use provided k_in and mask (already key-aligned) + k_in_computed = k_in + mask_key = mask + + # Project z to num_heads dimensions (V1: compute_pair_bias=True) + # or use z as-is (V2: compute_pair_bias=False, z is pre-computed bias) + if self.compute_pair_bias: + # V1: project z through LayerNorm+Linear, optionally cache the result + if self.use_model_cache and model_cache is not None: + if "z" not in model_cache: + z: DTensor = self.proj_z(z) # (B, K, W, H, num_heads) + model_cache["z"] = z + else: + z = model_cache["z"] + else: + z: DTensor = self.proj_z(z) # (B, K, W, H, num_heads) + # else: V2 — z is already the pre-computed bias with shape + # (B, K, W, H, num_heads), no projection needed. + + # Compute projections + q_proj_out: DTensor = self.proj_q(s) # (B, K, W, c_s) + k_proj_out: DTensor = self.proj_k(k_in_computed) # (B, K, H, c_s) + v_proj_out: DTensor = self.proj_v(k_in_computed) # (B, K, H, c_s) + g_proj_out: DTensor = self.proj_g(s) # (B, K, W, c_s) + + o = _AttentionPairBiasShardwiseImpl.apply( + q_proj_out, + k_proj_out, + v_proj_out, + z, + mask_key, + self.sdpa_with_bias_backend, + self.num_heads, + self.head_dim, + self.inf, + ) + + # ------------------------------------------------------------ + # Gate and project context vectors + # ------------------------------------------------------------ + o_gated: DTensor = sigmoid_gate(x=o, g=g_proj_out) + # (B, K, W, c_s) + o_final: DTensor = self.proj_o(o_gated) + + return o_final diff --git a/src/boltz/distributed/model/layers/attention_impl.py b/src/boltz/distributed/model/layers/attention_impl.py new file mode 100644 index 000000000..919212959 --- /dev/null +++ b/src/boltz/distributed/model/layers/attention_impl.py @@ -0,0 +1,1510 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import warnings +from typing import NamedTuple, Union + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.nn.attention import SDPBackend, sdpa_kernel + +from boltz.distributed.comm import AttentionPairBiasComm +from boltz.distributed.model.layers.dtensor_metadata_tools import ( + raise_if_incorrect_dtensor_metadata_args, +) +from boltz.distributed.model.modules.utils import Precision, SDPAWithBiasBackend, setup_tf32_env +from boltz.distributed.utils import tiled_softmax_attention_update, update_exhaustive_strides + +try: + from torch.nn.attention.flex_attention import flex_attention + + flex_attention_compiled = torch.compile(flex_attention) + HAS_FLEX_ATTN = True +except ImportError: + flex_attention_compiled = None + HAS_FLEX_ATTN = False + + +class _AttentionPairBiasContextVecParams(NamedTuple): + """NamedTuple for attention pair bias context vector parameters.""" + + ring_comm: AttentionPairBiasComm + multiplicity: int + num_heads: int + head_dim: int + inf: float + use_window_batching: bool + sdpa_with_bias_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.TORCH_FLEX_ATTN + + +class _AttentionPairBiasShardwiseImpl(torch.autograd.Function): + """Autograd function for shardwise attention with pair bias. + + This implements the forward and backward passes for window-batched attention + with pair bias in a DTensor-compatible manner. The computation is performed + locally on each shard without cross-rank communication (except for the implicit + DTensor distribution). + + The attention computation follows: + attn = softmax(q @ k.T / sqrt(head_dim) + z + mask_bias, dim=-2) + o = attn @ v + + Where dimensions are: + - q: (B * M, K, W, num_heads, head_dim) - queries per window + - k: (B * M, K, H, num_heads, head_dim) - keys (H = full attention span) + - v: (B * M, K, H, num_heads, head_dim) - values + - z: (B, K, W, H, num_heads) - pair bias (broadcasts over M) + - mask_key: (B, K, H) - key mask (broadcasts over M) + + The backward pass uses PyTorch autograd on the local computation graph, + avoiding the need for manual gradient derivation. + + FIXME: bf16 is currently broken in _AttentionPairBiasShardwiseImpl when using activation checkpointing with torch flex attention + + See Also + -------- + AttentionPairBiasShardwise : The nn.Module wrapper that calls this function. + SDPAWithBiasBackend : Backend options for the attention computation. + """ + + @staticmethod + def forward( + ctx: FunctionCtx, + q: DTensor, + k: DTensor, + v: DTensor, + z: DTensor, + mask_key: DTensor | None, + sdpa_with_bias_backend: SDPAWithBiasBackend, + num_heads: int, + head_dim: int, + inf: float, + ) -> DTensor: + """Forward pass for shardwise attention with pair bias. + + Computes multi-head attention with pair bias on window-batched inputs. + Supports multiple backends for the core SDPA computation. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context object for saving tensors for backward. + q : DTensor + Query tensor with shape (B * M, K, W, D) where: + - B is batch size + - M is multiplicity (diffusion samples) + - K is number of windows + - W is window size (typically 32) + - D is hidden dimension (num_heads * head_dim) + k : DTensor + Key tensor with shape (B * M, K, H, D) where: + - H is the full attention key dimension (typically 128) + v : DTensor + Value tensor with shape (B * M, K, H, D), same shape as k. + z : DTensor + Pair bias tensor with shape (B, K, W, H, num_heads). + Note: Does not include multiplicity dimension; broadcasts over M. + mask_key : DTensor or None + Key mask tensor with shape (B, K, H) indicating valid key positions. + None if no masking is needed. + sdpa_with_bias_backend : SDPAWithBiasBackend + Backend for computing scaled dot-product attention: + - REFERENCE: Manual einsum implementation (most compatible) + - TORCH_SDPA_EFFICIENT_ATTENTION: PyTorch's scaled_dot_product_attention kernel with EFFICIENT_ATTENTION backend + - TORCH_FLEX_ATTN: PyTorch's FlexAttention with compiled score_mod + num_heads : int + Number of attention heads. + head_dim : int + Dimension per attention head. + inf : float + Large value used for masking invalid positions in attention. + + Returns + ------- + DTensor + Output tensor with shape (B * M, K, W, D). + + Raises + ------ + ValueError + If tensor dimensions don't match expected shapes. + If q, k, v, z, mask_key have inconsistent placements or device meshes. + If q.shape[0] is not divisible by z.shape[0] (multiplicity check). + + Notes + ----- + - The softmax is computed over the key dimension (dim=-2 in attention matrix) + - All inputs must have the same DTensor placements and device mesh + - Computation is promoted to at least FP32 for numerical stability + - The local computation graph is preserved for backward pass via autograd + """ + + if q.ndim != 4: # (B * M, K, W, D) + raise ValueError(f"Input q must have 4 dimensions. Got {q.ndim}.") + + if k.ndim != 4: # (B * M, K, H, D) + raise ValueError(f"Input k must have 4 dimensions. Got {k.ndim}.") + + if v.ndim != 4: # (B * M, K, H, D) + raise ValueError(f"Input v must have 4 dimensions. Got {v.ndim}.") + + if z.ndim != 5: # (B, K, W, H, num_heads) with potentially no multiplicity + raise ValueError(f"Input z must have 5 dimensions. Got {z.ndim}.") + + if mask_key is not None: + if mask_key.ndim != 3: # (B, K, H) with potentially no multiplicity + raise ValueError(f"Input mask_key must have 3 dimensions. Got {mask_key.ndim}.") + + if mask_key.shape != z.shape[:2] + (z.shape[3],): + raise ValueError( + f"Input mask_key must have the same shape as z.shape[:3]. Got {mask_key.shape} and {z.shape[:3]}." + ) + + # Shape checks on the input + if q.shape[:2] != k.shape[:2] or q.shape[:2] != v.shape[:2]: # B, K + raise ValueError( + f"Input q, k, v must have the same leading two B and K dimensions. Got {q.shape[:2]} and {k.shape[:2]} and {v.shape[:2]}" + ) + + if q.shape[-1] != k.shape[-1] or q.shape[-1] != v.shape[-1]: # D + raise ValueError( + f"Input q, k, v must have the same last dimension. Got {q.shape[-1]} and {k.shape[-1]} and {v.shape[-1]}." + ) + + if q.shape[-1] != num_heads * head_dim: + raise ValueError( + f"Input q.shape[-1] and num_heads * head_dim must have the same shape. Got {q.shape[-1]} and {num_heads * head_dim}." + ) + + if k.shape != v.shape: + raise ValueError(f"Input k and v must have the same shape. Got {k.shape} and {v.shape}.") + + if q.shape[0] % z.shape[0] != 0: # B * M % B == 0 + raise ValueError( + f"Input q.shape[0] must be a multiple of z.shape[0]. Got q.shape[0]={q.shape[0]} and z.shape[0]={z.shape[0]}." + ) + + if z.shape[1:3] != q.shape[1:3]: # K, W + raise ValueError( + f"Input q.shape[1:3] and z.shape[1:3] must have the same shape. Got {q.shape[1:3]} and {z.shape[1:3]}." + ) + + if z.shape[3] != k.shape[2]: # H + raise ValueError( + f"Input z.shape[3] and k.shape[2] must have the same shape. Got {z.shape[3]} and {k.shape[2]}." + ) + + if z.shape[-1] != num_heads: + raise ValueError( + f"Input z.shape[-1] and num_heads must have the same shape. Got {z.shape[-1]} and {num_heads}." + ) + + if sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION and ( + q.shape[-1] % 4 != 0 or k.shape[-1] % 4 != 0 or v.shape[-1] % 4 != 0 + ): + # torch SDPA errors are shown as warnings instead of errors so we raise for it instead + raise ValueError( + f"Torch SDPA Efficient Attention kernel requires q, k, v must have a last dimension that is divisible by 4. " + f"Got {q.shape[-1]} and {k.shape[-1]} and {v.shape[-1]}." + ) + + # placements and device mesh checks + if ( + q.placements != k.placements + or q.placements != v.placements + or q.placements != z.placements + or q.placements != mask_key.placements + ): + raise ValueError( + f"Input q, k, v, z, and mask must have the same placements. Got {q.placements} and {k.placements} and {v.placements} and {z.placements} and {mask_key.placements}." + ) + if ( + q.device_mesh != k.device_mesh + or q.device_mesh != v.device_mesh + or q.device_mesh != z.device_mesh + or q.device_mesh != mask_key.device_mesh + ): + raise ValueError( + f"Input q, k, v, z, and mask must be on the same device mesh. Got {q.device_mesh} and {k.device_mesh} and {v.device_mesh} and {z.device_mesh} and {mask_key.device_mesh}." + ) + + multiplicity = q.shape[0] // z.shape[0] + + q_local_orig = q.to_local().detach().requires_grad_(q.requires_grad) + k_local_orig = k.to_local().detach().requires_grad_(k.requires_grad) + v_local_orig = v.to_local().detach().requires_grad_(v.requires_grad) + z_local_orig = z.to_local().detach().requires_grad_(z.requires_grad) # (B, K, W, H, num_heads) + if mask_key is not None: + mask_key_bias_local_orig = mask_key.to_local().detach().requires_grad_(False) # (B, K, H) + else: + mask_key_bias_local_orig = None + + with torch.enable_grad(): + # enable grad to build a local graph for the shardwise operations + # We detach inputs to create 'leaf' nodes for our local graph. + # NOTE: mask_key_bias_local is not differentiable but nonetheless we need to detach it + q_local = q_local_orig.unflatten(-1, (num_heads, head_dim)) # (B * M, K, W, num_heads, head_dim) + k_local = k_local_orig.unflatten(-1, (num_heads, head_dim)) # (B * M, K, H, num_heads, head_dim) + v_local = v_local_orig.unflatten(-1, (num_heads, head_dim)) # (B * M, K, H, num_heads, head_dim) + z_local = z_local_orig + + if mask_key_bias_local_orig is not None: + mask_key_bias_local = mask_key_bias_local_orig[:, :, None, :, None] # (B, K, 1, H, 1) + mask_key_bias_local = (1 - mask_key_bias_local.to(dtype=q_local.dtype)) * -inf + else: + mask_key_bias_local = None + + if multiplicity > 1: + # unflatten the multiplicity axis so that mask and z are broadcasted along it + # This has to be (B * multiplicity, ...) -> (B, multiplicity, ...) + # but never (B * multiplicity, ...) -> (multiplicity, B, ...) due to the upstream + # order of multiplicity application + q_local = q_local.unflatten(0, (-1, multiplicity)) + k_local = k_local.unflatten(0, (-1, multiplicity)) + v_local = v_local.unflatten(0, (-1, multiplicity)) + # add singleton axis to mask and z for broadcasting + z_local = z_local.unsqueeze(1) + if mask_key_bias_local is not None: + mask_key_bias_local = mask_key_bias_local.unsqueeze(1) + + # use at least FP32 for AttnPairBias + dtype_compute = torch.promote_types(q_local.dtype, torch.float32) + q_local = q_local.to(dtype_compute) + k_local = k_local.to(dtype_compute) + v_local = v_local.to(dtype_compute) + z_local = z_local.to(dtype_compute) + if mask_key_bias_local is not None: + mask_key_bias_local = mask_key_bias_local.to(dtype_compute) + + if ( + sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION + or sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_FLEX_ATTN + ): + # save shape for later reshaping the kernel output back + if multiplicity > 1: + B_local, M_local, K_local, W_local = q_local.shape[:4] + else: + B_local, K_local, W_local = q_local.shape[:3] + # torch sdpa kernel only supports 4-axes input tensors so we need to + # move up the head axis then flatten + # (..., W or H, num_heads, head_dim) -> (..., num_heads, W or H, head_dim) + q_local = q_local.transpose(-3, -2) + k_local = k_local.transpose(-3, -2) + v_local = v_local.transpose(-3, -2) + # sdpa kernel only accepts bias so we need to sum mask and z into bias + b_local = z_local + if mask_key_bias_local is not None: + b_local = b_local + mask_key_bias_local + # (..., W, H, num_heads) -> (..., num_heads, W, H) + b_local = b_local.moveaxis(-1, -3) + if multiplicity > 1: + # (B, M, K, H, ...) -> (M, B, K, H, ...) + q_local = q_local.moveaxis(1, 0).flatten(1, 3) + k_local = k_local.moveaxis(1, 0).flatten(1, 3) + v_local = v_local.moveaxis(1, 0).flatten(1, 3) + b_local = b_local.moveaxis(1, 0).flatten(1, 3) + else: + # (B, K, H, ...) -> (B*K, H, ...) + q_local = q_local.flatten(0, 1) + k_local = k_local.flatten(0, 1) + v_local = v_local.flatten(0, 1) + b_local = b_local.flatten(0, 1) + # run the kernel + # NOTE: except for SDPBackend.MATH, other kernels can't guarantee consistent backward pass + # results for the invalid atoms' gradients, which doesn't matter for applications but do matter + # for testing requirements. + if sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION: + # NOTE: dtype_compute is at least FP32 so technically CUDNN_ATTENTION and FLASH_ATTENTION + # will not work. + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): + # NOTE: the scale factor is already applied by the kernel, which is default to 1/sqrt(head_dim) + # Kernel requires all input to have stride-1 along the last axis + o_local = torch.nn.functional.scaled_dot_product_attention( + q_local, k_local, v_local, attn_mask=b_local.contiguous() + ) + elif sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_FLEX_ATTN: + # flex_attention (compiled Triton/Inductor) requires power-of-2 head_dim and head_dim >= 16. + if not ( + HAS_FLEX_ATTN + and q_local.is_cuda + and q_local.dtype != torch.float64 + and is_power_of_2(head_dim) + and head_dim >= 16 + ): + raise RuntimeError( + f"flex_attention requirements not met: " + f"HAS_FLEX_ATTN={HAS_FLEX_ATTN}, is_cuda={q_local.is_cuda}, " + f"dtype={q_local.dtype}, head_dim={head_dim}, " + f"q_seq_len={q_local.size(-2)}, k_seq_len={k_local.size(-2)}" + ) + + if multiplicity > 1: + # Squeeze out the M=1 dimension so b_local is (B*K*H, Sq, Sk) to avoid + # data-dependent indexing of b[0, ...] + # (B * K * num_heads, W, H) + b_local = b_local.squeeze(0) + + def add_bias_to_attn_score( + score: torch.Tensor, + batch: torch.Tensor, + head: torch.Tensor, + q_idx: torch.Tensor, + k_idx: torch.Tensor, + ) -> torch.Tensor: + return score + b_local[head, q_idx, k_idx] + + else: + # b_local is (B * K, num_heads, W, H) + # with same batch size as q/k/v + def add_bias_to_attn_score( + score: torch.Tensor, + batch: torch.Tensor, + head: torch.Tensor, + q_idx: torch.Tensor, + k_idx: torch.Tensor, + ) -> torch.Tensor: + return score + b_local[batch, head, q_idx, k_idx] + + with setup_tf32_env(Precision.FP32): + o_local = flex_attention_compiled(q_local, k_local, v_local, score_mod=add_bias_to_attn_score) + # reshape the tensor back: + if multiplicity > 1: + # (M, B * K * num_heads, ...) -> (B, M, K, ...) + o_local = o_local.unflatten(1, (B_local, K_local, num_heads)).moveaxis(0, 1) + else: + # (B * K, ...) -> (B, K, ...) + o_local = o_local.unflatten(0, (B_local, K_local)) + # (..., num_heads, W, head_dim) -> (..., W, num_heads, head_dim) + o_local = o_local.transpose(-3, -2) + elif sdpa_with_bias_backend == SDPAWithBiasBackend.REFERENCE: + with setup_tf32_env(Precision.FP32), torch.amp.autocast("cuda", enabled=False): + attn = torch.einsum("...wnd,...hnd->...whn", q_local, k_local) + attn = attn / head_dim**0.5 + attn = attn + z_local + if mask_key_bias_local is not None: + attn = attn + mask_key_bias_local + attn = attn.softmax(dim=-2) # axis = -2 is the key dimension + o_local = torch.einsum("...whn,...hnd->...wnd", attn, v_local) # (B, K, W, num_heads, head_dim) + # (..., W, num_heads, head_dim) -> (..., W, num_heads * head_dim) + o_local = o_local.flatten(-2, -1).to(q.dtype) # (B, K, W, c_s) + + if multiplicity > 1: + # (B, multiplicity, ...) -> (B * multiplicity, ...) + o_local = o_local.flatten(0, 1) + + # save the detached tensors for backward pass -- they hold the graph structure + ctx.save_for_backward(q_local_orig, k_local_orig, v_local_orig, z_local_orig, mask_key_bias_local_orig, o_local) + ctx.device_mesh = q.device_mesh + ctx.placements = q.placements + ctx.q_shape = q.shape + ctx.q_stride = q.stride() + ctx.k_shape = k.shape + ctx.k_stride = k.stride() + ctx.v_shape = v.shape + ctx.v_stride = v.stride() + ctx.z_shape = z.shape + ctx.z_stride = z.stride() + + o_dtensor = DTensor.from_local( + o_local.detach(), + device_mesh=q.device_mesh, + placements=q.placements, + shape=q.shape, + stride=q.stride(), + ) + + return o_dtensor + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor | None, DTensor | None, DTensor | None, DTensor | None, None, None, None, None, None]: + """Backward pass for shardwise attention with pair bias. + + Computes gradients by backpropagating through the local computation graph + that was built during the forward pass. This leverages PyTorch's autograd + rather than manual gradient computation. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context containing saved tensors from forward: + - q_local_orig, k_local_orig, v_local_orig, z_local_orig: Input tensors + - mask_key_bias_local_orig: Mask bias (non-differentiable) + - o_local: Output tensor that holds the computation graph + - device_mesh, placements: DTensor metadata + - q_shape, q_stride, etc.: Shape/stride info for DTensor reconstruction + grad_output : DTensor + Gradient of loss with respect to output, shape (B * M, K, W, D). + + Returns + ------- + tuple[DTensor | None, ...] + Gradients for each forward input in order: + - dq: DTensor or None, gradient for q with shape (B * M, K, W, D) + - dk: DTensor or None, gradient for k with shape (B * M, K, H, D) + - dv: DTensor or None, gradient for v with shape (B * M, K, H, D) + - dz: DTensor or None, gradient for z with shape (B, K, W, H, num_heads) + - None: mask_key (non-differentiable) + - None: sdpa_with_bias_backend (non-differentiable) + - None: num_heads (non-differentiable) + - None: head_dim (non-differentiable) + - None: inf (non-differentiable) + + Notes + ----- + The gradient computation follows the chain rule for attention: + + Forward (with einsum notation, ignoring multiplicity for clarity): + q_local: (B, K, W, num_heads, head_dim) - "bkwid" + k_local: (B, K, H, num_heads, head_dim) - "bkhid" + v_local: (B, K, H, num_heads, head_dim) - "bkhid" + attn = softmax(einsum("bkwid,bkhid->bkwhi", q, k) / sqrt(d) + z + mask_bias, dim=h) + o = einsum("bkwhi,bkhid->bkwid", attn, v) + + Backward: + dv = einsum("bkwhi,bkwid->bkhid", attn, grad_output) + d_attn = einsum("bkwid,bkhid->bkwhi", grad_output, v) + d_pre_softmax = attn * (d_attn - sum(attn * d_attn, dim=h, keepdim=True)) + dz = d_pre_softmax + dq = einsum("bkwhi,bkhid->bkwid", d_pre_softmax, k) / sqrt(d) + dk = einsum("bkwhi,bkwid->bkhid", d_pre_softmax, q) / sqrt(d) + + The actual implementation uses torch.autograd.grad on the saved local + computation graph for correctness and maintainability. + """ + + # retrieve the leaf nodes for the local graph + q_local, k_local, v_local, z_local, _, o_local = ctx.saved_tensors + inputs_needing_grad = [t for t in (q_local, k_local, v_local, z_local) if t.requires_grad] + if not inputs_needing_grad: + # Short-circuit if nothing needed gradients (rare but possible) + return None, None, None, None, None, None, None, None, None + grad_output_local = grad_output.to_local() + # backprop via the local graph -- grads_local only contains the grads for those in inputs_needing_grad + with setup_tf32_env(Precision.FP32), torch.amp.autocast("cuda", enabled=False): + grads_local = torch.autograd.grad( + outputs=[o_local], + inputs=inputs_needing_grad, + grad_outputs=[grad_output_local], + retain_graph=False, # Frees the local graph immediately + ) + + iter_grads_local = iter(grads_local) + dq_local = next(iter_grads_local) if q_local.requires_grad else None + dk_local = next(iter_grads_local) if k_local.requires_grad else None + dv_local = next(iter_grads_local) if v_local.requires_grad else None + dz_local = next(iter_grads_local) if z_local.requires_grad else None + + if dq_local is not None: + dq = DTensor.from_local( + dq_local, device_mesh=ctx.device_mesh, placements=ctx.placements, shape=ctx.q_shape, stride=ctx.q_stride + ) + else: + dq = None + + if dk_local is not None: + dk = DTensor.from_local( + dk_local, device_mesh=ctx.device_mesh, placements=ctx.placements, shape=ctx.k_shape, stride=ctx.k_stride + ) + else: + dk = None + + if dv_local is not None: + dv = DTensor.from_local( + dv_local, device_mesh=ctx.device_mesh, placements=ctx.placements, shape=ctx.v_shape, stride=ctx.v_stride + ) + else: + dv = None + + if dz_local is not None: + dz = DTensor.from_local( + dz_local, device_mesh=ctx.device_mesh, placements=ctx.placements, shape=ctx.z_shape, stride=ctx.z_stride + ) + else: + dz = None + + return dq, dk, dv, dz, None, None, None, None, None + + +class _AttentionPairBiasContexVecImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + q: DTensor, + k: DTensor, + v: DTensor, + z: DTensor, + mask: DTensor, + pair_mask: Union[DTensor, None], + apb_context_vec_params: _AttentionPairBiasContextVecParams, + ) -> DTensor: + """ + + c_s = num_heads * head_dim checked in vanilla AttentionPairBias.__init__ + + Below, N is the global number of tokens, H is the number of heads. + + Parameters + ---------- + ctx: FunctionCtx + The context object. + q : DTensor + query vectors computed by projection, (B, N, c_s) + k : DTensor + key vectors computed by projection, (B, N, c_s) + v : DTensor + value vectors computed by projection, (B, N, c_s) + z : DTensor + z, (B, H, N, N) + mask : torch.Tensor + The pairwise mask tensor (B, N) + multiplicity : int, optional + The diffusion batch size, by default 1 + pair_mask: Union[DTensor, None] + The pairwise mask tensor. + apb_context_vec_params: tuple + The parameters for the attention pair bias context vector. + + Key features: + - Distributed computation across device meshes with various sharding strategies + - Memory-efficient implementation that operates on local tensor chunks + - Supports gradient computation through custom backward pass + - Validates tensor compatibility (type, device mesh, placements, shapes)` + + Raises + ------ + TypeError + If dtensor_instance is not a DTensor. + ValueError + If the DTensor metadata is incorrect. + """ + # Check input metadata + _AttentionPairBiasContexVecImpl.check_forward_input_metadata_and_store( + ctx, q, k, v, z, mask, pair_mask, apb_context_vec_params + ) + # Check implementation scope + _AttentionPairBiasContexVecImpl.check_forward_input_for_impl_state(apb_context_vec_params) + + # ---------------------------------------------- + # Setup inputs to RingAttention.forward() + # ------------------------------------------------------- + ( + ring_comm, + multiplicity, + _, + _, + inf, + use_window_batching, + sdpa_with_bias_backend, + ) = apb_context_vec_params + + requires_grad = any(p.requires_grad for p in (q, k, v, z)) + + ctx.mark_non_differentiable(mask) + mask_local: Tensor = mask.to_local() + + # overlay mask comm with qkv projection, pair_mask_ij <- pair_mask_ij + mask_j + mask_recv: Tensor = ring_comm.comm_transpose_mask.enqueue_to_dispatch(mask_local.contiguous()) + + dtype_input = q.dtype + dtype_compute = torch.promote_types(dtype_input, torch.float32) + ctx.dtype_compute = dtype_compute + ctx.dtype_input = dtype_input + + q_local: Tensor = q.to_local().to(dtype=dtype_compute) + k_local: Tensor = k.to_local().to(dtype=dtype_compute) + v_local: Tensor = v.to_local().to(dtype=dtype_compute) + z_local: Tensor = z.to_local().to(dtype=dtype_compute) + + ctx.B_each_chunk = q_local.shape[0] + ctx.N_each_chunk = q_local.shape[1] + + single_rep_view_shape = ctx.B_each_chunk, ctx.N_each_chunk, ctx.H, ctx.head_dim + q_local = q_local.view(single_rep_view_shape).requires_grad_(q.requires_grad) + k_local = k_local.view(single_rep_view_shape).requires_grad_(k.requires_grad) + v_local = v_local.view(single_rep_view_shape).requires_grad_(v.requires_grad) + z_local = z_local.permute(0, 3, 1, 2).requires_grad_(z.requires_grad) + + if requires_grad: + ctx.multiplicity = multiplicity + + ring_comm.comm_transpose_mask.wait_until_finished() + + if use_window_batching or pair_mask is None: # original behavior + pair_mask_local = mask_recv[:, None, None, :] + else: # only atom-level has pair_mask + ctx.mark_non_differentiable(pair_mask) + pair_mask_local: Tensor = pair_mask.to_local() # shape = (B, I, J) + pair_mask_local = pair_mask_local[:, None, :, :] * mask_recv[:, None, None, :] + + pair_mask_local = pair_mask_local.to(dtype=dtype_compute) + + with torch.autocast("cuda", enabled=False): + o_local, ring_attention_simple_data_for_bw = ring_attention_simple_forward( + q_local, + k_local, + v_local, + z_local, + pair_mask_local, + ring_comm, + inf, + sdpa_with_bias_backend, + ) + if requires_grad: + # Unpack tensors for save_for_backward to enable automatic memory management and hook support + ctx.save_for_backward( + ring_attention_simple_data_for_bw.q_store, + ring_attention_simple_data_for_bw.k_t_store, + ring_attention_simple_data_for_bw.v_t_store, + ring_attention_simple_data_for_bw.z_store, + ring_attention_simple_data_for_bw.lse_m, + ring_attention_simple_data_for_bw.o_store, + ) + ctx.ring_comm = ring_attention_simple_data_for_bw.ring_comm + ctx.sdpa_with_bias_backend = ring_attention_simple_data_for_bw.sdpa_with_bias_backend + + # --------------------------------------------------------- + # end custom communication + # --------------------------------------------------------- + o_local = o_local.reshape(ctx.B_each_chunk, ctx.N_each_chunk, ctx.c_s) # o_local_b + + # Compute output shape and stride + shape_output = v.shape[:-1] + (o_local.shape[-1],) + + strides_output = update_exhaustive_strides(o_local.shape, o_local.stride(), shape_output) + + o = DTensor.from_local( + o_local.to(dtype=dtype_input), + device_mesh=ctx.device_mesh, + placements=ctx.single_rep_placements, + shape=shape_output, + stride=strides_output, + ) + return o + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, DTensor, DTensor, DTensor, None, None, None]: + """Backward pass implementation. + + Parameters + ---------- + ctx: FunctionCtx + The context object. + grad_output: DTensor + The gradient of the output tensor. + + Raises + ------ + TypeError + If dtensor_instance is not a DTensor. + ValueError + If the DTensor metadata is incorrect. + """ + _AttentionPairBiasContexVecImpl.check_backward_input_metadata(ctx, grad_output) + + # Get ref to local tensor, and reshape to (B, I, H, D) + grad_output_local: Tensor = grad_output.to_local().reshape( + ctx.B_each_chunk, ctx.N_each_chunk, ctx.H, ctx.head_dim + ) + grad_output_local = grad_output_local.to(dtype=ctx.dtype_compute) + + # Call backward separately via refactored-out function + q_store, k_t_store, v_t_store, z_store, lse_m, o_store = ctx.saved_tensors + data_for_backward = RingAttentionSimpleDataForBackward( + q_store=q_store, + k_t_store=k_t_store, + v_t_store=v_t_store, + z_store=z_store, + lse_m=lse_m, + ring_comm=ctx.ring_comm, + sdpa_with_bias_backend=ctx.sdpa_with_bias_backend, + o_store=o_store, + multiplicity=ctx.multiplicity, + ) + del ctx.ring_comm + del ctx.sdpa_with_bias_backend + + grad_q, grad_k, grad_v, grad_z = ring_attention_simple_backward( + data_for_backward=data_for_backward, + do=grad_output_local, + ) + del data_for_backward # free up q_store, k_t_store, v_t_store, z_store, lse_m immediately + grad_q: Tensor = grad_q.to(dtype=ctx.dtype_input) # (B_each_chunk, N_each_chunk, H, D) + grad_k: Tensor = grad_k.to(dtype=ctx.dtype_input) # (B_each_chunk, N_each_chunk, H, D) + grad_v: Tensor = grad_v.to(dtype=ctx.dtype_input) # (B_each_chunk, N_each_chunk, H, D) + grad_z: Tensor = grad_z.to(dtype=ctx.dtype_input) # (B_each_chunk, N_each_chunk, H, D) + + # Reshape, allocate new memory + single_rep_target_shape = (ctx.B_each_chunk, ctx.N_each_chunk, ctx.c_s) + grad_q_flat: Tensor = grad_q.reshape(single_rep_target_shape) + grad_k_flat: Tensor = grad_k.reshape(single_rep_target_shape) + grad_v_flat: Tensor = grad_v.reshape(single_rep_target_shape) + + grad_z_flat: Tensor = grad_z.permute((0, 2, 3, 1)) + + grad_q_dtensor = DTensor.from_local( + grad_q_flat, + device_mesh=ctx.device_mesh, + placements=ctx.single_rep_placements, + shape=ctx.shape_q, + stride=ctx.stride_q, + ) + grad_k_dtensor = DTensor.from_local( + grad_k_flat, + device_mesh=ctx.device_mesh, + placements=ctx.single_rep_placements, + shape=ctx.shape_k, + stride=ctx.stride_k, + ) + grad_v_dtensor = DTensor.from_local( + grad_v_flat, + device_mesh=ctx.device_mesh, + placements=ctx.single_rep_placements, + shape=ctx.shape_v, + stride=ctx.stride_v, + ) + grad_z_dtensor = DTensor.from_local( + grad_z_flat, + device_mesh=ctx.device_mesh, + placements=ctx.pair_rep_placements, + shape=ctx.shape_z, + stride=ctx.stride_z, + ) + _AttentionPairBiasContexVecImpl.check_backward_output_metadata( + ctx, + grad_q_dtensor, + grad_k_dtensor, + grad_v_dtensor, + ) + return grad_q_dtensor, grad_k_dtensor, grad_v_dtensor, grad_z_dtensor, None, None, None + + @staticmethod + def check_forward_input_metadata_and_store( + ctx: FunctionCtx, + q: DTensor, + k: DTensor, + v: DTensor, + z: DTensor, + mask: DTensor, + pair_mask: DTensor | None, + apb_context_vec_params: _AttentionPairBiasContextVecParams, + ) -> None: + ( + _, + _, + num_heads, + head_dim, + _, + _, + _, + ) = apb_context_vec_params + + ctx.H = num_heads + ctx.head_dim = head_dim + + ctx.B = q.shape[0] + ctx.N = q.shape[1] + ctx.c_s = q.shape[-1] + + placements_single_expected = (Shard(0), Shard(1), Replicate()) + placements_pair_expected = (Shard(0), Shard(1), Shard(2)) + ctx.single_rep_placements = q.placements + ctx.pair_rep_placements = z.placements + ctx.device_mesh = q.device_mesh + ctx.shape_q = q.shape + ctx.stride_q = q.stride() + ctx.shape_k = k.shape + ctx.stride_k = k.stride() + ctx.shape_v = v.shape + ctx.stride_v = v.stride() + ctx.shape_z = z.shape + ctx.stride_z = z.stride() + + check_metadata = raise_if_incorrect_dtensor_metadata_args + + check_metadata( + q, + "q", + check_for_partial_placements=True, + expected_placements=placements_single_expected, + ) + check_metadata( + k, + "k", + (ctx.B, ctx.N, ctx.c_s), + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.single_rep_placements, + ) + check_metadata( + v, + "v", + (ctx.B, ctx.N, ctx.c_s), + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.single_rep_placements, + ) + check_metadata( + z, + "z", + None, # shape can be different from single representation(s) due to multiplicity + expected_device_mesh=ctx.device_mesh, + check_for_partial_placements=True, + expected_placements=placements_pair_expected, + ) + check_metadata( + mask, + "mask", + None, + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.single_rep_placements, + ) + if pair_mask is not None: + check_metadata( + pair_mask, + "pair_mask", + None, # shape can be different from single representation(s) due to multiplicity + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.pair_rep_placements, + ) + + @staticmethod + def check_backward_input_metadata(ctx: FunctionCtx, grad_output: DTensor) -> None: + raise_if_incorrect_dtensor_metadata_args( + grad_output, + "grad_output", + expected_shape=(ctx.B, ctx.N, ctx.c_s), + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.single_rep_placements, + ) + + @staticmethod + def check_backward_output_metadata( + ctx: FunctionCtx, + grad_q_dtensor: DTensor, + grad_k_dtensor: DTensor, + grad_v_dtensor: DTensor, + ) -> None: + """DTensor.from_local(..) requires the specification of stride if + shape is specified, so the specification of stride is side-stepped + in this usage by checking the shape determined by DTensor library.""" + metadata_tuple = ( + (grad_q_dtensor, "grad_q", (ctx.B, ctx.N, ctx.c_s)), + (grad_k_dtensor, "grad_k", (ctx.B, ctx.N, ctx.c_s)), + (grad_v_dtensor, "grad_v", (ctx.B, ctx.N, ctx.c_s)), + ) + for dtensor_instance, dtensor_name, expected_shape in metadata_tuple: + if not dtensor_instance.shape == expected_shape: + raise ValueError( + ", ".join( + [ + f"dtensor '{dtensor_name}' should have shape {expected_shape}", + f"but instead has shape {dtensor_instance.shape}.", + ] + ) + ) + + @staticmethod + def check_forward_input_for_impl_state(apb_context_vec_params: _AttentionPairBiasContextVecParams) -> None: + """ + Check the implementation scope of the forward pass. + + Parameters + ---------- + apb_context_vec_params: AttentionPairBiasContextVecParams + The parameters for the attention pair bias context vector. + + Returns + ------- + None + + Raises + ------ + NotImplementedError + """ + ( + _, + _, + _, + _, + _, + use_window_batching, + _, + ) = apb_context_vec_params + + if use_window_batching: + raise NotImplementedError(f"use_window_batching={use_window_batching} is not implemented") + + +def ring_attention( + q: Tensor, + k: Tensor, + v: Tensor, + z: Tensor, + pair_mask: Tensor, + mask: Tensor, + ring_comm: AttentionPairBiasComm, + inf: float = 1e6, +) -> Tensor: + """Functional interface to RingAttention autograd function. + + Based on vanilla torch tensors. + """ + return RingAttention.apply(q, k, v, z, pair_mask, mask, ring_comm, inf) + + +class RingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: Tensor, + k: Tensor, + v: Tensor, + z: Tensor, + pair_mask: Tensor, + mask: Tensor, + ring_comm: AttentionPairBiasComm, + inf: float = 1e6, + ) -> Tensor: + pair_mask = pair_mask[:, None, :, :] * mask[:, None, None, :] + o, data_for_backward = ring_attention_simple_forward( + q=q, + k=k, + v=v, + z=z, + pair_mask=pair_mask, + ring_comm=ring_comm, + inf=inf, + ) + # Unpack tensors for save_for_backward to enable automatic memory management and hook support + ctx.save_for_backward( + data_for_backward.q_store, + data_for_backward.k_t_store, + data_for_backward.v_t_store, + data_for_backward.z_store, + data_for_backward.lse_m, + data_for_backward.o_store, + ) + ctx.ring_comm = data_for_backward.ring_comm + ctx.sdpa_with_bias_backend = data_for_backward.sdpa_with_bias_backend + return o + + @staticmethod + def backward(ctx, do: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, None, None, None, None]: + q_store, k_t_store, v_t_store, z_store, lse_m, o_store = ctx.saved_tensors + data_for_backward = RingAttentionSimpleDataForBackward( + q_store=q_store, + k_t_store=k_t_store, + v_t_store=v_t_store, + z_store=z_store, + lse_m=lse_m, + ring_comm=ctx.ring_comm, + sdpa_with_bias_backend=ctx.sdpa_with_bias_backend, + o_store=o_store, + ) + del ctx.ring_comm + del ctx.sdpa_with_bias_backend + dq, dk, dv, dz = ring_attention_simple_backward(data_for_backward=data_for_backward, do=do) + return dq, dk, dv, dz, None, None, None, None + + +class RingAttentionSimpleDataForBackward(NamedTuple): + """Data for backward pass of ring attention.""" + + q_store: Tensor + k_t_store: Tensor + v_t_store: Tensor + z_store: Tensor + lse_m: Tensor + ring_comm: AttentionPairBiasComm + sdpa_with_bias_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.REFERENCE + o_store: Tensor | None = None + multiplicity: int = 1 + + +def is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return (n > 0) and (n & (n - 1) == 0) + + +def ring_attention_simple_forward( + q: Tensor, + k: Tensor, + v: Tensor, + z: Tensor, + pair_mask: Tensor, + ring_comm: AttentionPairBiasComm, + inf: float = 1e6, + sdpa_with_bias_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.TORCH_FLEX_ATTN, +) -> tuple[Tensor, RingAttentionSimpleDataForBackward]: + """Forward pass of ring attention. + + example sharding strategy on N_tokens=2, world_size=4 + device mesh = [[ 0, 1], + [ 2, 3]] + q/k/v, z, m = [[ q0, q0], [[ k0, k0], [[z00, z01], [[m00, m01], + [ q1, q1]] [ k1, k1]] z10, z11]] m10, m11]] + + step = 0 + k_t, z = [[ k0, k1], [[z00, z01], + [ k0, k1]] z10, z11]] + + step = 1 (roll to the left) + k_t, z = [[ k1, k0], [[z01, z00], + [ k1, k0]] z11, z10]] + + now we roll left k, v, z as the ring attention outer loop. + + o = [[ (q0k0+z00+m00)v0 + (q0k1+z01+m01)v1, (q0k1+z01+m01)v1 + (q0k0+z00+m00)v0 ], + [ (q1k0+z10+m10)v0 + (q1k1+z11+m11)v1, (q1k1+z11+m11)v1 + (q1k0+z10+m10)v0 ]] + + Parameters + ---------- + q : Tensor + Query tensor (B, I, H, D) + k : Tensor + Key tensor (B, J, H, D) + v : Tensor + Value tensor (B, J, H, D) + z : Tensor + Pair bias projection (B, H, I, J) + pair_mask : Tensor + Pair bias mask (B, 1, I, J) + ring_comm : AttentionPairBiasComm + Ring communication for async operation + inf : float + Infinity value for masking, by default 1e6 + + Returns + ------- + Tensor + Output tensor (B, I, H, D) + RingAttentionSimpleDataForBackward + Data for backward pass + """ + if sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION: + warnings.warn("torch_sdpa backend is not implemented and will fall back to flex_attention backend") + sdpa_with_bias_backend = SDPAWithBiasBackend.TORCH_FLEX_ATTN + + # flex_attention (compiled Triton/Inductor) requires power-of-2 head_dim and head_dim >= 16. + use_flex_attn = ( + sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_FLEX_ATTN + and HAS_FLEX_ATTN + and q.is_cuda + and q.dtype != torch.float64 + and is_power_of_2(q.size(-1)) + and q.size(-1) >= 16 # head_dim >= 16 + ) + + B_q, S, H, D = q.shape + B_z_orig = z.shape[0] + multiplicity = B_q // B_z_orig + + embed_dim = q.size(-1) + + # save input for backward + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad or z.requires_grad + if requires_grad: + q_store = q.detach() + else: + q_store = None + + # Overlap k, v comm with mask addition + k_recv = ring_comm.comm_transpose_k.enqueue_to_dispatch(k.contiguous()) + v_recv = ring_comm.comm_transpose_v.enqueue_to_dispatch(v.contiguous()) + z = z.contiguous() + (1 - pair_mask) * -inf + ring_comm.comm_transpose_k.wait_until_finished() + ring_comm.comm_transpose_v.wait_until_finished() + k = k_recv + v = v_recv + + if requires_grad: + k_t_store = k.detach() + v_t_store = v.detach() + else: + k_t_store, v_t_store = None, None + + # save input for backward + if requires_grad: + z_store = z.detach() + else: + z_store = None + + # Ring attention + o: Union[Tensor, None] = None + lse_m: Union[Tensor, None] = None + amax: Union[Tensor, None] = None + + size_device_grid0 = ring_comm.group_layout.shape[0] + for step in range(size_device_grid0): + # Overlap k, v, z+m roll-left comm + if step + 1 != size_device_grid0: + next_k = ring_comm.comm_k.enqueue_to_dispatch(k) + next_v = ring_comm.comm_v.enqueue_to_dispatch(v) + next_z = ring_comm.comm_z.enqueue_to_dispatch(z) + + # Attention by chunk + if use_flex_attn: + # Flex attention expects (B, H, S, D) + q_f = q.transpose(1, 2) + k_f = k.transpose(1, 2) + v_f = v.transpose(1, 2) + + if multiplicity > 1: + # flex_attention doesn't support broadcasting batch dim for bias yet + # so we use a trick: we reshape B_q * H into the head dimension + # and use (h // H) // multiplicity to index into z + q_f = q_f.reshape(1, B_q * H, S, D) + k_f = k_f.reshape(1, B_q * H, S, D) + v_f = v_f.reshape(1, B_q * H, S, D) + + def score_mod(score, b, h, q_idx, kv_idx): + return score + z[(h // H) // multiplicity, h % H, q_idx, kv_idx] + + else: + # B_q == B_z + def score_mod(score, b, h, q_idx, kv_idx): + return score + z[b, h, q_idx, kv_idx] + + # flex_attention_compiled with float32 requires full precision + with setup_tf32_env(Precision.FP32), torch.amp.autocast("cuda", enabled=False): + block_o, aux_data = flex_attention_compiled( + q_f, + k_f, + v_f, + score_mod=score_mod, + return_lse=True, + ) + + # block_o: (B_q, H, S, D) or (1, B*H, S, D) + # aux_data: (B_q, H, S) or (1, B*H, S) + block_o = block_o.reshape(B_q, H, S, D).transpose(1, 2) + block_lse_m = aux_data.reshape(B_q, H, S).transpose(1, 2).unsqueeze(-1) + block_amax = None # flex_attention doesn't return amax, but we can use lse + else: + attn = torch.einsum("bihd,bjhd->bhij", q, k) + attn = attn / (embed_dim**0.5) + + B_z = z.shape[0] + B_q = q.shape[0] + if B_q != B_z: + attn = (attn.view(B_z, -1, *attn.shape[1:]) + z.unsqueeze(1)).view_as(attn) + else: + attn = attn + z + + block_o = torch.einsum("bhij,bjhd->bihd", torch.softmax(attn, dim=-1), v) + block_amax = attn.amax(dim=-1, keepdim=True) + block_lse_m = torch.logsumexp(attn - block_amax, dim=-1, keepdim=True) + + block_lse_m = block_lse_m.transpose(-2, -3) + if block_amax is not None: + block_amax = block_amax.transpose(-2, -3) + + o, lse_m, amax = tiled_softmax_attention_update(block_o, block_lse_m, block_amax, o, lse_m, amax) + + # Get input for next round + if step + 1 != size_device_grid0: + ring_comm.comm_k.wait_until_finished() + ring_comm.comm_v.wait_until_finished() + ring_comm.comm_z.wait_until_finished() + k, v, z = next_k, next_v, next_z + + if requires_grad and not use_flex_attn: + if multiplicity > 1: + # Reduce amax across multiplicity + amax_t = amax.transpose(-2, -3) + amax_t_view = amax_t.view(B_z_orig, multiplicity, *amax_t.shape[1:]) + amax_reduced = amax_t_view.amax(dim=1) + z_store -= amax_reduced + + # Adjust lse_m to be relative to amax_reduced + # We use broadcasting via views to avoid memory-heavy repeat_interleave + amax_reduced_t = amax_reduced.transpose(-2, -3) # (B, I, H, 1) + lse_m_view = lse_m.view(B_z_orig, multiplicity, *lse_m.shape[1:]) + amax_view = amax.view(B_z_orig, multiplicity, *amax.shape[1:]) + + lse_m_view += amax_view - amax_reduced_t.unsqueeze(1) # modify lse_m inplace via view + else: + z_store -= amax.transpose(-2, -3) + + data_for_backward = RingAttentionSimpleDataForBackward( + q_store=q_store, + k_t_store=k_t_store, + v_t_store=v_t_store, + z_store=z_store, + lse_m=lse_m, + ring_comm=ring_comm, + sdpa_with_bias_backend=sdpa_with_bias_backend, + o_store=o.detach() if requires_grad else None, + multiplicity=multiplicity, + ) + + return o, data_for_backward + + +def ring_attention_simple_backward( + data_for_backward: RingAttentionSimpleDataForBackward, + do: Tensor, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Backward pass of ring attention. + + example sharding strategy on N_tokens=2, world_size=4 + device mesh = [[ 0, 1], + [ 2, 3]] + + do = [[ do0, do0], + [ do1, do1]] + + ctx should have saved: + q , z = [[ q0, q0], [[z00, z01], + [ q1, q1]] [z10, z11]] + k_t, v_t = [[ k0, k1], [[v0, v1], + [ k0, k1]] [v0, v1]] + lse = [[ lse0, lse0], + [ lse1, lse1]] + + where z has already accounted for the pair mask + + The output gradient should be sharded as follows: + dq = [[ dq0, dq0], + [ dq1, dq1]] + dk = [[ dk0, dk0], + [ dk1, dk1]] + dv = [[ dv0, dv0], + [ dv1, dv1]] + dz = [[ dz00, dz01], + [ dz10, dz11]] + + Parameters + ---------- + data_for_backward : RingAttentionSimpleDataForBackward + Data for backward pass + do : Tensor + Gradient of output tensor + + Returns + ------- + tuple[Tensor, Tensor, Tensor, Tensor] + Gradients for q, k, v, z + """ + q, k_t, v_t, z, lse_m, ring_comm, sdpa_with_bias_backend, out_global, multiplicity = data_for_backward + if sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION: + raise NotImplementedError("torch_sdpa backend is not implemented") + + embed_dim = q.size(-1) + + # Pre-calculate out_global * do for softmax gradient + # do: (B, S, H, D) + # out_global: (B, S, H, D) + # do_o: (B, S, H, 1) -> will be reshaped to (B, H, S, 1) for the reference path + do_o = torch.sum(do * out_global, dim=-1, keepdim=True) + + lse_m_t = lse_m.transpose(1, 2) # (B, S, H, 1) -> (B, H, S, 1) + + B_q = q.shape[0] + + # flex_attention (compiled Triton/Inductor) requires power-of-2 head_dim and head_dim >= 16. + use_flex_attn = ( + sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_FLEX_ATTN + and HAS_FLEX_ATTN + and q.is_cuda + and q.dtype != torch.float64 + and is_power_of_2(q.size(-1)) + and q.size(-1) >= 16 # head_dim >= 16 + ) + + if use_flex_attn: + # Re-run forward to get aux data + # In a real implementation we would save this + # but for Ring Attention we might need to re-run or save it per chunk + B_q, S, H, D = q.shape + + with torch.enable_grad(): + q_l = q.detach().requires_grad_(True) + k_l = k_t.detach().requires_grad_(True) + v_l = v_t.detach().requires_grad_(True) + z_l = z.detach().requires_grad_(True) + + q_f = q_l.transpose(1, 2) + k_f = k_l.transpose(1, 2) + v_f = v_l.transpose(1, 2) + + if multiplicity > 1: + q_f = q_f.reshape(1, B_q * H, S, D) + k_f = k_f.reshape(1, B_q * H, S, D) + v_f = v_f.reshape(1, B_q * H, S, D) + + def score_mod(score, b, h, q_idx, kv_idx): + return score + z_l[(h // H) // multiplicity, h % H, q_idx, kv_idx] + + else: + # B_q == B_z + def score_mod(score, b, h, q_idx, kv_idx): + return score + z_l[b, h, q_idx, kv_idx] + + # flex_attention_compiled with float32 requires full precision + # Request LSE to perform the global scaling trick + # This ensures the local gradients match the global context + with setup_tf32_env(Precision.FP32), torch.amp.autocast("cuda", enabled=False): + out_l, aux_l = flex_attention_compiled( + q_f, + k_f, + v_f, + score_mod=score_mod, + return_lse=True, + ) + + out_l = out_l.reshape(B_q, H, S, D).transpose(1, 2) + lse_l = aux_l.reshape(B_q, H, S, 1).transpose(1, 2) + + size_device_grid0 = ring_comm.group_layout.shape[0] + if size_device_grid0 == 1: + grad_inputs = torch.autograd.grad(out_l, (q_l, k_l, v_l, z_l), do, allow_unused=True) + else: + # Scaling trick: Global contribution of this chunk + # Contribution_global = (out_local - out_global) * exp(lse_local - lse_global) + # d(Contribution_global)/dx gives the correct softmax gradient including denominator + out_scaled = (out_l - out_global) * torch.exp(lse_l - lse_m) + grad_inputs = torch.autograd.grad(out_scaled, (q_l, k_l, v_l, z_l), do, allow_unused=True) + + dq, dk, dv, dz = grad_inputs + dq = dq if dq is not None else torch.zeros_like(q_l) + dk = dk if dk is not None else torch.zeros_like(k_l) + dv = dv if dv is not None else torch.zeros_like(v_l) + dz = dz if dz is not None else torch.zeros_like(z_l) + + if multiplicity > 1 and dz.shape[0] > (q.shape[0] // multiplicity): + B_z_orig = q.shape[0] // multiplicity + dz = dz.view(B_z_orig, multiplicity, *dz.shape[1:]).sum(dim=1) + else: + # Compute S_ij and A_ij + s = torch.einsum("bihd,bjhd->bhij", q, k_t) + s /= embed_dim**0.5 + + # Memory efficient in-place softmax reconstruction + s.sub_(lse_m_t) + if multiplicity > 1: + B_z_orig = q.shape[0] // multiplicity + if z.shape[0] == q.shape[0]: + s.view(B_z_orig, multiplicity, *s.shape[1:]).add_(z.view(B_z_orig, multiplicity, *z.shape[1:])) + else: + s.view(B_z_orig, multiplicity, *s.shape[1:]).add_(z.unsqueeze(1)) + else: + s.add_(z) + + a = s.exp_() + + # Compute gradient of v and c + # dV_j = \Sum_{i} A^T_{ij} dO_i + # c_i = \Sum_{k} v^T_k A_{ik} + dv = torch.einsum("bihd,bhij->bjhd", do, a) + + # Compute gradient of S (dS = dz) + # dS_{ij} = A_{ij} * (dO_i v^T_j - (dO_i * out_global_i)) + # do_o has shape (B, S, H, 1), we need (B, H, S, 1) to match a + do_o_step = do_o.transpose(1, 2) + + # In-place compute dS to save memory + tmp = torch.einsum("bihd,bjhd->bhij", do, v_t) + tmp.sub_(do_o_step) + a.mul_(tmp) + del tmp + + dS = a + dz = a + if multiplicity > 1 and dz.shape[0] > (q.shape[0] // multiplicity): + B_z_orig = q.shape[0] // multiplicity + dz = dz.view(B_z_orig, multiplicity, *dz.shape[1:]).sum(dim=1) + + # Compute gradient of q, k + # dq_i = \Sum_{j} dS_{ij} k^T_j + # dk_j = \Sum_{i} dS^T_{ij} q_i + dq = torch.einsum("bhij,bjhd->bihd", dS, k_t) / ( + embed_dim**0.5 + ) # _t here refers to transposition across device mesh + dk = torch.einsum("bhij,bihd->bjhd", dS, q) / (embed_dim**0.5) + + dv_recv = ring_comm.comm_transpose_v.enqueue_to_dispatch(dv.contiguous()) + + # collect and complete dv reduction + ring_comm.comm_transpose_v.wait_until_finished() + dv_work = dist.all_reduce(dv_recv, op=dist.ReduceOp.SUM, group=ring_comm.cp_axis_1_group, async_op=True) + + dq = dq.contiguous() + dq_work = dist.all_reduce(dq, op=dist.ReduceOp.SUM, group=ring_comm.cp_axis_1_group, async_op=True) + + dk_recv = ring_comm.comm_transpose_k.enqueue_to_dispatch(dk.contiguous()) + ring_comm.comm_transpose_k.wait_until_finished() + dk_work = dist.all_reduce(dk_recv, op=dist.ReduceOp.SUM, group=ring_comm.cp_axis_1_group, async_op=True) + + # Collect all async works + dq_work.wait() + dk_work.wait() + dv_work.wait() + + return dq, dk_recv, dv_recv, dz + + +class RingAttentionSimple: + """Ring attention pair bias with context parallelism. + + This class serves as a namespace for the ring attention forward and + backward functions. The definitions of these functions outside an + autograd function subclass, are useful to encapsulate the communication + and math logic that is used in _AttentionPairBiasContexVecImpl, but + also may be used elsewhere in the codebase via the simple function + ring_attention above. + """ + + @staticmethod + def forward( + q: Tensor, + k: Tensor, + v: Tensor, + z: Tensor, + pair_mask: Tensor, + mask: Tensor, + ring_comm: AttentionPairBiasComm, + inf: float = 1e6, + ) -> tuple[Tensor, RingAttentionSimpleDataForBackward]: + assert isinstance(q, Tensor), f"q must be a Tensor, got {type(q)}" + assert isinstance(k, Tensor), f"k must be a Tensor, got {type(k)}" + assert isinstance(v, Tensor), f"v must be a Tensor, got {type(v)}" + assert isinstance(z, Tensor), f"z must be a Tensor, got {type(z)}" + assert isinstance(pair_mask, Tensor), f"pair_mask must be a Tensor, got {type(pair_mask)}" + assert isinstance(mask, Tensor), f"mask must be a Tensor, got {type(mask)}" + assert isinstance( + ring_comm, AttentionPairBiasComm + ), f"ring_comm must be a AttentionPairBiasComm, got {type(ring_comm)}" + return ring_attention_simple_forward(q, k, v, z, pair_mask, mask, ring_comm, inf) + + @staticmethod + def backward( + data_for_backward: RingAttentionSimpleDataForBackward, + do: Tensor, + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + return ring_attention_simple_backward(data_for_backward, do) diff --git a/src/boltz/distributed/model/layers/cat_and_chunk.py b/src/boltz/distributed/model/layers/cat_and_chunk.py new file mode 100644 index 000000000..2bbda8347 --- /dev/null +++ b/src/boltz/distributed/model/layers/cat_and_chunk.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +from torch import Tensor +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Shard + +from boltz.distributed.model.layers.dtensor_metadata_tools import ( + raise_if_incorrect_dtensor_metadata_args, +) +from boltz.distributed.utils import update_exhaustive_strides + + +def _shardwise_chunk( + x: DTensor, + chunks: int, + dim: int, +) -> tuple[DTensor, ...]: + """Generalized shardwise chunking operation with validation. + + This function performs input validation and chunking operation. + + Parameters + ---------- + x : DTensor + Input DTensor to chunk. + chunks : int + Number of chunks to split the dimension into. + dim : int + Dimension to chunk along. + + Returns + ------- + tuple[DTensor, ...] + Tuple of DTensor instances after chunking. + + Raises + ------ + TypeError + See raise_if_incorrect_dtensor_metadata_args for more details. + ValueError + Checks on the input x and parameters. + """ + dim_normalized = dim if dim >= 0 else dim + x.ndim + # do not allow chunking on a dimension with a sharded placement + for i_dim_device_mesh, placement in enumerate(x.placements): + if isinstance(placement, Shard): + if placement.dim == dim_normalized: + raise NotImplementedError( + f"Chunking along dimension {dim} shared by device_mesh axis {i_dim_device_mesh} is not supported" + ) + if x.shape[placement.dim] % x.device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {x.device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + elif isinstance(placement, Partial): + raise ValueError(f"Placements of type {Partial} are not supported") + + x_local = x.to_local() + + # Perform operation on local tensors + x_chunks_local: tuple[Tensor, ...] = x_local.chunk(chunks=chunks, dim=dim) + + shapes_output = [] + strides_output = [] + + for chunk in x_chunks_local: + shape_output = list(x.shape) + shape_output[dim_normalized] = chunk.shape[dim_normalized] + # we try to be as consistent with the original stride as possible + # but in principle there is no way to keep the resulting chunked DTensor.full_tensor + # as views of the original DTensor.full_tensor because upon calling this function, + # the latter is not materialized in memory yet + stride_output = update_exhaustive_strides(x.shape, x.stride(), shape_output) + shapes_output.append(tuple(shape_output)) + strides_output.append(tuple(stride_output)) + + # Create output tuple of DTensor using input tensor's device mesh and placements + # leave the dim check to the torch.Tensor.chunk + x_in_chunks: tuple[DTensor, ...] = tuple( + [ + DTensor.from_local( + chunk, + device_mesh=x.device_mesh, + placements=x.placements, + shape=shapes_output[i], + stride=strides_output[i], + ) + for i, chunk in enumerate(x_chunks_local) + ] + ) + return x_in_chunks + + +def _shardwise_cat( + *inputs: DTensor, dim: int, shape: tuple[int, ...] | None = None, stride: tuple[int, ...] | None = None +) -> DTensor: + """Generalized shardwise concatenation operation with validation. + + This function performs input validation and concatenation operation. + + Parameters + ---------- + *inputs : DTensor + Variable number of input DTensors to concatenate with dim as the last argument + dim : int + Dimension to concatenate along + shape : tuple[int, ...], optional + Shape of the output DTensor. If not provided, will infer from the inputs. + stride : tuple[int, ...], optional + Stride of the output DTensor. If not provided, will infer from the inputs. + + Returns + ------- + DTensor + Concatenated DTensor. + + Raises + ------ + TypeError, ValueError + Checks on the input tensors and parameters. + """ + # Validate inputs + if len(inputs) == 0: + raise ValueError("Cannot concatenate empty list of tensors.") + + if (shape is None) != (stride is None): + raise ValueError("Either both shape and stride must be provided or neither") + + first_input = inputs[0] + + placements = first_input.placements + device_mesh = first_input.device_mesh + + # DTensor Shard(dim) is always normalized to be non-negative + dim_normalized = dim if dim >= 0 else dim + first_input.ndim + # Check that concatenation dimension is not sharded + # and there are no unevenly sharded tensor axes + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Shard): + i_dim_tensor = placement.dim + if i_dim_tensor == dim_normalized: + # unevenly sharded inputs are not supported + raise NotImplementedError( + f"Concatenation along dimension {dim} shared by device_mesh axis {i_dim_device_mesh} is not supported" + ) + if first_input.shape[i_dim_tensor] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {i_dim_tensor} of size {first_input.shape[i_dim_tensor]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + elif isinstance(placement, Partial): + raise ValueError(f"Placements of type {Partial} are not supported") + + # Check that all inputs have compatible metadata + for i, input_tensor in enumerate(inputs[1:], 1): + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_tensor, + dtensor_name=f"_shardwise_cat inputs[{i}]", + expected_device_mesh=device_mesh, + expected_placements=placements, + ) + + if shape is None: + shape = list(first_input.shape) + shape[dim_normalized] = sum(input_dtensor.shape[dim_normalized] for input_dtensor in inputs) + shape = tuple(shape) + stride = update_exhaustive_strides(first_input.shape, first_input.stride(), shape) + + # passing the previous checks implies: all inputs have same device_mesh and same + # placements. The following torch.cat will check for the local shards in the list + # for the same shape. If the torch.cat is successful, it means that the inputs + # will have same shape along all the sharded axes (given that dim can't be sharded) + # and hence with evenly sharded axes (because first.input is evenly sharded) + + # Perform operation on local tensors + # leave the dim and shape check to the torch.Tensor.cat + output_local: Tensor = torch.cat([x.to_local() for x in inputs], dim=dim) + + # Create output DTensor using first input's device mesh and placements + output: DTensor = DTensor.from_local( + output_local, + device_mesh=device_mesh, + placements=placements, + shape=shape, + stride=stride, + ) + + return output + + +class _ShardwiseChunkImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + chunks: int, + dim: int = -1, + ) -> tuple[DTensor, ...]: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor. + chunks : int + Number of chunks to split the dimension into. + dim : int, optional + Dimension to chunk along. Default is -1 (last dimension). + + Returns + ------- + tuple[DTensor, ...] + Tuple of DTensor instances after chunking. + """ + # Perform chunking with built-in validation + result = _shardwise_chunk(x, chunks, dim) + + ctx.dim = dim + ctx.device_mesh_input = x.device_mesh + ctx.placements_input = x.placements + ctx.shape_input = x.shape + ctx.stride_input = x.stride() + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + *grad_outputs: tuple[DTensor, ...], + ) -> tuple[DTensor, None, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_outputs : tuple[DTensor, ...] + Gradient of the loss with respect to each component of the output. + + Returns + ------- + tuple[DTensor, None, None] + Gradient with respect to input, None for chunks and dim parameters. + """ + # Use cat operation (inverse of chunk) for backward pass with built-in validation + # no need for shape consistency check here because: + # 1. torch.autograd.backward will check output's shape against grad_outputs shape + # 2. the underlying torch.cat will check for shape along other axes than "dim" + # 3. mismatching shape in the cat grad will be caught upon attaching to the input tensor + grad_x = _shardwise_cat(*grad_outputs, dim=ctx.dim, shape=ctx.shape_input, stride=ctx.stride_input) + + return grad_x, None, None + + +class _ShardWiseCatImpl(torch.autograd.Function): + @staticmethod + def forward(ctx: FunctionCtx, *inputs) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + *inputs : DTensor + Variable number of input DTensors to concatenate with dim as the last argument + + Returns + ------- + DTensor + Concatenated DTensor. + """ + # Perform concatenation with built-in validation + tensors_to_cat = inputs[:-1] + dim = inputs[-1] + result = _shardwise_cat(*tensors_to_cat, dim=dim) + + # Save metadata for backward pass + # shardwise_cat guarantee same device_mesh and placements as in the inputs + ctx.device_mesh = result.device_mesh + ctx.placements = result.placements + ctx.dim = dim + ctx.n_chunks = len(tensors_to_cat) + ctx.shapes_and_strides_input = [ + (input_dtensor.shape, input_dtensor.stride()) for input_dtensor in tensors_to_cat + ] + + # Save the sizes of each input tensor in the concatenation dimension for backward pass + ctx.split_sizes = [input_tensor.shape[dim] for input_tensor in tensors_to_cat] + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, ...]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, ...] + Gradients with respect to each input tensor, None for dim parameter. + """ + # Verify grad_output has the same device_mesh and placements as expected + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="_ShardWiseCatImpl.backward grad_output", + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.placements, + ) + + # Use local torch.split for backward pass + grad_output_local = grad_output.to_local() + grad_inputs_local = torch.split(grad_output_local, ctx.split_sizes, dim=ctx.dim) + + # Wrap each split back into DTensor using the saved metadata + grad_inputs = tuple( + DTensor.from_local( + grad_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements, + shape=ctx.shapes_and_strides_input[i][0], + stride=ctx.shapes_and_strides_input[i][1], + ) + for i, grad_local in enumerate(grad_inputs_local) + ) + + return (*grad_inputs, None) + + +def shardwise_chunk(x: DTensor, chunks: int, dim: int = -1) -> tuple[DTensor, ...]: + """Chunk a DTensor along a specified dimension. + + This function splits a DTensor into chunks along the specified dimension. + The dimension must not be sharded on the device mesh. + + Parameters + ---------- + x : DTensor + Input DTensor to chunk. + chunks : int + Number of chunks to split the dimension into. + dim : int, optional + Dimension to chunk along. Default is -1 (last dimension). + + Returns + ------- + tuple[DTensor, ...] + Tuple of DTensor instances after chunking. + + Raises + ------ + ValueError + If the specified dimension is sharded or other validation errors. + """ + return _ShardwiseChunkImpl.apply(x, chunks, dim) + + +def shardwise_cat(inputs: list[DTensor], dim: int = -1) -> DTensor: + """Concatenate DTensors along a specified dimension. + + This function concatenates DTensors along the specified dimension. + The dimension must not be sharded on the device mesh, and all tensors + must have compatible shapes except in the concatenation dimension. + + Parameters + ---------- + inputs : list[DTensor] + List of DTensors to concatenate. + dim : int, optional + Dimension to concatenate along. Default is -1 (last dimension). + + Returns + ------- + DTensor + Concatenated DTensor. + + Raises + ------ + ValueError + If the specified dimension is sharded or other validation errors. + """ + if not inputs: + raise ValueError("Cannot concatenate empty list of tensors.") + + return _ShardWiseCatImpl.apply(*inputs, dim) diff --git a/src/boltz/distributed/model/layers/clip.py b/src/boltz/distributed/model/layers/clip.py new file mode 100644 index 000000000..0ce5f6e31 --- /dev/null +++ b/src/boltz/distributed/model/layers/clip.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Optional + +import torch +from torch.distributed.tensor import DTensor, Partial, Shard + + +class _ClipImpl(torch.autograd.Function): + """Distributed implementation of clipping operation using DTensors. + + This autograd function implements distributed clipping operations that constrain + tensor values to be within specified bounds. The operation is performed element-wise + across distributed tensors while maintaining proper gradient computation. + + Supported operations: + - CLIP: output = torch.clip(tensor, min=min_val, max=max_val) + + Key features: + - Distributed computation across device meshes with various sharding strategies + - Memory-efficient implementation that operates on local tensor chunks + - Supports gradient computation through custom backward pass + - Supports both min and max clipping bounds (either can be None) + """ + + @staticmethod + def forward(ctx, tensor: DTensor, min_val: Optional[float] = None, max_val: Optional[float] = None) -> DTensor: + """Forward pass of distributed clipping operation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + tensor : DTensor + Input tensor. Can have any shape and sharding strategy. + min_val : Optional[float], default None + Minimum value for clipping. If None, no lower bound is applied. + max_val : Optional[float], default None + Maximum value for clipping. If None, no upper bound is applied. + + Returns + ------- + DTensor + Output tensor with shape identical to input tensor. + Contains the result of the clipping operation. + + Raises + ------ + TypeError + If inputs are not of expected types. + ValueError + If Partial placements are used (not supported), or if both min_val and max_val are None. + """ + if not isinstance(tensor, DTensor): + raise TypeError(f"Input 'tensor' must be of type DTensor. Got type {type(tensor)}.") + + if min_val is None and max_val is None: + raise ValueError("At least one of min_val or max_val must be specified for clipping.") + + device_mesh_input = tensor.device_mesh + placements_input = tensor.placements + + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if tensor.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {tensor.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + tensor_local = tensor.to_local() + + # Perform the clipping operation + output_local = torch.clip(tensor_local, min=min_val, max=max_val) + + if tensor.requires_grad: + # Pre-allocate mask in bool to save memory from saving a float tensor_local copy + mask_local = torch.ones_like(tensor_local, dtype=torch.bool) + if min_val is not None: + mask_local = mask_local & (tensor_local >= min_val) # inclusive in torch.clip + if max_val is not None: + mask_local = mask_local & (tensor_local <= max_val) # inclusive in torch.clip + ctx.save_for_backward(mask_local) + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.input_shape = tensor.shape + ctx.min_val = min_val + ctx.max_val = max_val + + out = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=tensor.shape, + stride=tensor.stride(), + ) + return out + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor | None, None, None]: + """Backward pass of distributed clipping operation. + + Computes gradients with respect to the input tensor. + + The gradient is: + - For CLIP: d_tensor = grad_output * mask + where mask = 1 for elements within bounds, 0 for clipped elements + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and metadata from forward pass. + grad_outputs : tuple + Gradients of the loss with respect to the output tensors. + + Returns + ------- + tuple[DTensor | None, None, None] + Gradients with respect to tensor, min_val, and max_val parameters. + Only tensor gradient is computed; min_val and max_val gradients are None. + """ + if not ctx.needs_input_grad[0]: + return None, None, None + + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensor. " + f"Got placements {grad_output.placements} and {ctx.placements_input}." + ) + + grad_output_local = grad_output.to_local() + (mask_local,) = ctx.saved_tensors + + # Compute gradient mask: 1 for elements within bounds, 0 for clipped elements + d_tensor_local = grad_output_local * mask_local + d_tensor = DTensor.from_local( + d_tensor_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + + return d_tensor, None, None + + +def clip(tensor: DTensor, min_val: Optional[float] = None, max_val: Optional[float] = None) -> DTensor: + """Apply clipping operation to a distributed tensor. + + This function constrains tensor values to be within specified bounds. + Elements below min_val are set to min_val, and elements above max_val are set to max_val. + The operation is performed efficiently using local tensor operations while maintaining + gradient computation capabilities. + + Parameters + ---------- + tensor : DTensor + Input tensor. Can have any shape and sharding strategy. + min_val : Optional[float], default None + Minimum value for clipping. If None, no lower bound is applied. + max_val : Optional[float], default None + Maximum value for clipping. If None, no upper bound is applied. + + Returns + ------- + DTensor + Output tensor with shape identical to input tensor. + Contains the result of the clipping operation. + + Examples + -------- + >>> # Assume we have distributed tensor x with shape (B, N, D) + >>> clipped_positive = clip(x, min_val=0.0) + >>> # clipped_positive = torch.clip(x, min=0.0), computed in distributed fashion + >>> + >>> clipped_range = clip(x, min_val=-1.0, max_val=1.0) + >>> # clipped_range = torch.clip(x, min=-1.0, max=1.0), computed in distributed fashion + >>> + >>> clipped_max = clip(x, max_val=10.0) + >>> # clipped_max = torch.clip(x, max=10.0), computed in distributed fashion + + Notes + ----- + - Input tensor must be a DTensor with any placement strategy + - Partial placements are not currently supported + - The function is differentiable and supports gradient computation + - The operation is performed on local tensor chunks for efficiency + - At least one of min_val or max_val must be specified + + Raises + ------ + TypeError + If input tensor is not a DTensor. + ValueError + If Partial placements are used (not supported), or if both min_val and max_val are None. + """ + return _ClipImpl.apply(tensor, min_val, max_val) # type: ignore diff --git a/src/boltz/distributed/model/layers/distribute_module_tools.py b/src/boltz/distributed/model/layers/distribute_module_tools.py new file mode 100644 index 000000000..0b35dada4 --- /dev/null +++ b/src/boltz/distributed/model/layers/distribute_module_tools.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import Union + +from einops.layers.torch import Rearrange +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.nn import Module + +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated + + +def _convert_each_child_module_to_dtensor_compatible_version( + module: Module, + input_module: Module, + device_mesh: DeviceMesh, + reduction: str = "contain", + module_name: Union[str, None] = None, +): + """ + This function creates an attribute for module that is a DTensor + compatible version of each of the child modules of input_module. + + For example, in the module.__init__(), this function can be called to + create attributes of module. The user may still have some additional + steps to take in module.forward() to use the DTensor API. + + Parameters + ---------- + module : Module + The module that will be modified in place to have a DTensor API. + input_module : Module + The non-dtensor module with child modules that do not use DTensor. + device_mesh : DeviceMesh + The device mesh. + reduction : str, optional + - "contain": An attribute is created for module that is a DTensor + compatible version of a non-dtensor child of input_module. + - "sequential": An attribute is created for module that is a + module_name : str, optional + The name of the module to be replaced. + This is only used when reduction is "sequential". + If not provided, the module name is inferred from the input module. + + Returns + ------- + None + + Raises + ------ + NotImplementedError + If the dtensor version of the child module is not implemented. + """ + names_to_be_replaced: list[str] = [] + child_replacements: list[Module] = [] + for name, child in input_module.named_children(): + if isinstance(child, nn.Linear): + names_to_be_replaced.append(name) + child_replacements.append(LinearParamsReplicated(child, device_mesh=device_mesh)) + + elif isinstance(child, nn.LayerNorm): + names_to_be_replaced.append(name) + child_replacements.append(LayerNormParamsReplicated(child, device_mesh=device_mesh)) + + elif isinstance(child, nn.Sequential): + _convert_each_child_module_to_dtensor_compatible_version( + module, child, device_mesh=device_mesh, reduction="sequential", module_name=name + ) + + elif isinstance(child, Rearrange): + # Not yet implemented + pass + + else: + raise NotImplementedError + + if reduction == "contain": + for name, child_replacement in zip(names_to_be_replaced, child_replacements): + setattr(module, name, child_replacement) + elif reduction == "sequential" and module_name is not None: + setattr(module, module_name, nn.Sequential(*child_replacements)) + else: + raise NotImplementedError diff --git a/src/boltz/distributed/model/layers/dropout.py b/src/boltz/distributed/model/layers/dropout.py new file mode 100644 index 000000000..d21c4c172 --- /dev/null +++ b/src/boltz/distributed/model/layers/dropout.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard + + +class _ApplyDropoutMaskMsaOrPairImpl(torch.autograd.Function): + """Distributed implementation of apply_dropout_mask_msa_or_pair using DTensor.""" + + @staticmethod + def forward( + ctx, src: DTensor, dropout: float, training: bool, columnwise: bool | None, samples_dropout: DTensor | None + ) -> DTensor: + """Forward pass of distributed dropout samples_dropout application. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + src : DTensor + The source tensor to apply dropout to. + dropout : float + The dropout rate between 0.0 and 1.0. + training : bool + Whether the model is in training mode. + columnwise : bool, optional + If True, applies the same samples_dropout to all elements in each column, by default False. + samples_dropout : DTensor, optional + These are the uniform random numbers drawn from [0, 1) and samples_dropout >= dropout will be + used as the dropout mask. If None, a new samples_dropout is created. Note that currently + there is no effective way to generate consistent random number sequences between the + serial and distributed versions so this argument can be used to passed in pre-generated + samples from the serial version sliced according to the input "src" placements as a mock + to reproduce the random number sequence in the distributed version. We use this method + in the tests to verify the distributed version is consistent with the serial version. + + + Returns + ------- + DTensor + The source tensor with dropout applied during training, or unchanged during inference. + """ + # Check if inputs are of type DTensor + if not isinstance(src, DTensor): + raise TypeError(f"Input 'src' must be of type DTensor. Got type {type(src)}.") + + # Verify that src is 4-dimensional as required for indexing patterns + if src.ndim != 4: + raise ValueError(f"Input tensor 'src' must be 4-dimensional. Got {src.ndim} dimensions.") + + if samples_dropout is not None: + if not isinstance(samples_dropout, DTensor): + raise TypeError(f"Input 'samples_dropout' must be of type DTensor. Got type {type(samples_dropout)}.") + + if samples_dropout.ndim != 4: + raise ValueError( + f"Input tensor 'samples_dropout' must be 4-dimensional. Got {samples_dropout.ndim} dimensions." + ) + + if samples_dropout.device_mesh != src.device_mesh: + raise ValueError( + f"Input tensor 'samples_dropout' must have the same device mesh as the input tensor. " + f"Got device meshes {samples_dropout.device_mesh} and {src.device_mesh}." + ) + + if samples_dropout.requires_grad: + raise ValueError( + "Input tensor 'samples_dropout' must not require gradients and its gradient computation is not supported" + ) + + if columnwise: + if samples_dropout.shape[1] != 1 or samples_dropout.shape[3] != 1: + raise ValueError( + f"Input tensor 'samples_dropout' must have shape [*, 1, *, 1] for columnwise dropout. Got {samples_dropout.shape}." + ) + if samples_dropout.placements != (Shard(0), Replicate(), Shard(2)): + raise ValueError( + f"Input tensor 'samples_dropout' must have placements (Shard(0), Replicate(), Shard(2)) for columnwise dropout. Got {samples_dropout.placements}." + ) + else: + if samples_dropout.shape[2] != 1 or samples_dropout.shape[3] != 1: + raise ValueError( + f"Input tensor 'samples_dropout' must have shape [*, *, 1, 1] for rowwise dropout. Got {samples_dropout.shape}." + ) + if samples_dropout.placements != (Shard(0), Shard(1), Replicate()): + raise ValueError( + f"Input tensor 'samples_dropout' must have placements (Shard(0), Shard(1), Replicate()) for rowwise dropout. Got {samples_dropout.placements}." + ) + + ctx.mark_non_differentiable(samples_dropout) + + # Verify that src.placements is exactly (Shard(0), Shard(1), Shard(2)) + expected_placements = (Shard(0), Shard(1), Shard(2)) + if src.placements != expected_placements: + raise ValueError( + f"Input tensor 'src' must have placements {expected_placements}. Got placements {src.placements}." + ) + + # Save context for backward pass + ctx.device_mesh_input = src.device_mesh + ctx.placements_input = src.placements + ctx.training = training + ctx.columnwise = columnwise + ctx.dropout = dropout + + # Extract local tensors + src_local = src.to_local() + + if training: + if samples_dropout is None: + # Create dropout samples_dropout using the same logic as the serial version + shape = list(src_local.shape) + if columnwise: + # equivalent to torch.rand_like(src_local[:, 0, :, 0] + shape[1] = 1 + shape[3] = 1 + else: + # equivalent to torch.rand_like(src_local[:, :, 0, 0]) + shape[2] = 1 + shape[3] = 1 + if torch.is_autocast_enabled("cuda"): + mask_dtype = torch.promote_types(src_local.dtype, torch.float32) + else: + mask_dtype = src_local.dtype + samples_dropout_local = torch.rand(shape, device=src_local.device, dtype=mask_dtype) + else: + samples_dropout_local = samples_dropout.to_local() + d = samples_dropout_local >= dropout + if torch.is_autocast_enabled("cuda"): + scale_dtype = torch.promote_types(src_local.dtype, torch.float32) + else: + scale_dtype = src_local.dtype + d = (d * 1.0 / (1.0 - dropout)).to(dtype=scale_dtype) + + # Apply dropout mask + result_local = src_local * d + + # Save dropout mask for backward pass + ctx.save_for_backward(d) + # Convert result back to DTensor + result = DTensor.from_local( + result_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=src.shape, + stride=src.stride(), + ) + + return result + else: + return src + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None, None, None, None]: + """Backward pass of distributed dropout mask application. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved information from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None, None, None, None] + Gradients with respect to inputs. Only src gets a gradient. + """ + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensors. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensors. " + f"Got placements {grad_output.placements} and {ctx.placements_input}." + ) + + # Extract local gradient + grad_output_local = grad_output.to_local() + + if ctx.training: + # Extract saved dropout mask + (dropout_mask,) = ctx.saved_tensors + + # Apply the same dropout mask to the gradient + grad_src_local = grad_output_local * dropout_mask + else: + # During inference, gradient passes through unchanged + grad_src_local = grad_output_local + + # Convert gradient back to DTensor + grad_src = DTensor.from_local( + grad_src_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + + # Return gradients: (src, dropout, training, columnwise, mask) + # Only src needs gradient, others are None + return grad_src, None, None, None, None + + +def apply_dropout_mask_msa_or_pair( + src: DTensor, + dropout: float, + training: bool, + columnwise: bool = False, + samples_dropout: DTensor | None = None, +) -> DTensor: + """Apply dropout directly to the source DTensor for MSA or pair representations. + + This function applies dropout to the source DTensor using the shape of z as a reference. + It behaves like standard dropout during training, and is a no-op during inference. + + When columnwise=True, the same dropout mask is applied to all elements in the same column, + meaning that entire columns are either kept or dropped together. + + IMPORTANT: This function makes strong assumptions about tensor shape and indexing. + The reference tensor z must be indexable by [:, 0:1, :, 0:1] (columnwise=True) or + [:, :, 0:1, 0:1] (columnwise=False). This is specifically designed for MSA and pair + representation tensors with expected 4D structure. + + Parameters + ---------- + src : DTensor + The source DTensor to apply dropout to. Must have placements (Shard(0), Shard(1), Shard(2)). + dropout : float + The dropout rate between 0.0 and 1.0 + training : bool + Whether the model is in training mode + columnwise : bool, optional + If True, applies the same mask to all elements in each column, by default False + samples_dropout : DTensor, optional + These are the uniform random numbers drawn from [0, 1) and samples_dropout >= dropout will be + used as the dropout mask. If None, a new samples_dropout is created. Note that currently + there is no effective way to generate consistent random number sequences between the + serial and distributed versions so this argument can be used to passed in pre-generated + samples from the serial version sliced according to the input "src" placements as a mock + to reproduce the random number sequence in the distributed version. We use this method + in the tests to verify the distributed version is consistent with the serial version. + + Returns + ------- + DTensor + The source DTensor with dropout applied during training, or unchanged during inference + + Notes + ----- + During training, the values that are kept are scaled by 1/(1-dropout) to maintain + the expected value of the tensor. During inference (training=False), the input tensor + is returned unchanged. + + The implementation uses a custom autograd function to handle DTensor operations + by working with local tensors and properly managing distributed tensor metadata. + + This function enforces specific placement requirements and tensor shape assumptions, + making it suitable only for MSA and pair representation tensors in the expected format. + """ + if not training: + return src + return _ApplyDropoutMaskMsaOrPairImpl.apply(src, dropout, training, columnwise, samples_dropout) diff --git a/src/boltz/distributed/model/layers/dtensor_metadata_tools.py b/src/boltz/distributed/model/layers/dtensor_metadata_tools.py new file mode 100644 index 000000000..5e8ca517c --- /dev/null +++ b/src/boltz/distributed/model/layers/dtensor_metadata_tools.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import Union + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Partial, Placement, Shard + +placement_type_3 = tuple[Placement, Placement, Placement] + + +def raise_if_incorrect_dtensor_metadata_args( + dtensor_instance: DTensor, + dtensor_name: str, + expected_shape: Union[tuple[int, ...], None] = None, + expected_device_mesh: Union[DeviceMesh, None] = None, + expected_placements: Union[tuple[Placement], None] = None, + check_for_partial_placements: bool = False, +): + """Check if the DTensor metadata is correct. + + This function will check in order: + - that the DTensor instance is a DTensor + - that the DTensor instance has the expected shape + - that the DTensor instance has the expected device mesh + - that the DTensor instance has the expected placements + - that the DTensor instance does not have a Partial placement + + If any of the checks fail, an Exception is raised. + + + Parameters + ---------- + dtensor_instance: DTensor + The DTensor instance to check. + dtensor_name: str + The name of the DTensor instance to check. + expected_shape: Union[tuple, None] + The expected shape of the DTensor. If None, the check is skipped. + expected_device_mesh: Union[DeviceMesh, None] + The expected device mesh. If None, the check is skipped. + expected_placements: Union[tuple[Placement], None] + The expected placements. If None, the check is skipped. + check_for_partial_placements: bool + Whether or not to check for partial placements in the input dtensor. + + Notes + ----- + (1) A DTensor with a Partial placement is considered invalid, in this library + as a temporary measure to avoid its complexity. + (2) A valid usage of this function is to check the equality of the + placements between two DTensor instances, and to check that the placements + of the input DTensor do not contain Partial placements. + + Raises + ------ + TypeError + If dtensor_instance is not a DTensor. + ValueError + If the DTensor metadata is incorrect. + """ + + if not isinstance(dtensor_instance, DTensor): + raise TypeError( + ", ".join( + [ + f"DTensor instance '{dtensor_instance}' should have type {DTensor}", + f"but instead has type {type(dtensor_instance)}.", + ] + ) + ) + + # Consolidate if statements, each is >~ 10ns + if expected_shape is not None and not dtensor_instance.shape == expected_shape: + raise ValueError( + ", ".join( + [ + f"DTensor instance '{dtensor_name}' should have shape {expected_shape}", + f"but instead has shape {dtensor_instance.shape}.", + ] + ) + ) + + if expected_device_mesh is not None and not dtensor_instance.device_mesh == expected_device_mesh: + raise ValueError( + ", ".join( + [ + f"DTensor instance '{dtensor_name}' should have device mesh {expected_device_mesh}", + f"but instead has device mesh {dtensor_instance.device_mesh}.", + ] + ) + ) + + if expected_placements is not None and not tuple(dtensor_instance.placements) == tuple(expected_placements): + raise ValueError( + ", ".join( + [ + f"DTensor instance '{dtensor_name}' should have placements {expected_placements}", + f"but instead has placements {dtensor_instance.placements}.", + ] + ) + ) + + if check_for_partial_placements: + for placement in dtensor_instance.placements: + if isinstance(placement, Partial): + raise ValueError( + ", ".join( + [ + f"DTensor instance '{dtensor_name}' should not have have placement of type {Partial}", + f"but instead has placements {dtensor_instance.placements}.", + ] + ) + ) + + +def raise_if_shapes_incompatible_with_placements( + dtensor_instances: tuple[DTensor, ...], + dtensor_names: tuple[str, ...], +): + """ + Checks each DTensor independently, that any dimension that is sharded, should have + a dimension length that is a multiple of the device mesh length. + + NOTE: this function uses triple nesting loop for checking the input dtensor placements, which + could have performance implications + + Parameters + ---------- + dtensor_instances: tuple of DTensor instances + dtensor_names: tuple of names of the DTensor instances + + Raises + ------- + TypeError: If the DTensor instances are not of type DTensor. + ValueError: If the shapes of the DTensor instances are incompatible with the placements. + """ + for x, name in zip(dtensor_instances, dtensor_names): + if not isinstance(x, DTensor): + return TypeError(f"Object '{name}' must be of type DTensor. Got type {type(x)}.") + + for plac in [p for p in x.placements if isinstance(p, Shard)]: + tensor_dim_length = x.shape[plac.dim] + mesh_dim_length = x.device_mesh.shape[plac.dim] + + if tensor_dim_length % mesh_dim_length != 0: + raise ValueError( + ", ".join( + [ + f"Dtensor with name={name} and id={id(x)}", + f"has dimension={plac.dim} with with placement=Shard({plac.dim})", + f"dimension length {tensor_dim_length}", + f"which is not a multiple of the device mesh length {mesh_dim_length}", + ] + ) + ) diff --git a/src/boltz/distributed/model/layers/elementwise_op.py b/src/boltz/distributed/model/layers/elementwise_op.py new file mode 100644 index 000000000..1e26493b2 --- /dev/null +++ b/src/boltz/distributed/model/layers/elementwise_op.py @@ -0,0 +1,1024 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from enum import Enum, auto + +import torch +from torch.distributed.tensor import DTensor, Partial, Shard + + +class ElementwiseOp(Enum): + """Enumeration of supported elementwise operations.""" + + # n-ary ops + SUM = auto() + SUB = auto() + PROD = auto() + DIV = auto() + EQUAL = auto() + BITAND = auto() + + # unary ops + COS = auto() + RELU = auto() + ROUND = auto() + EXP = auto() + ABS = auto() + SIGMOID = auto() + + # comparison ops + GT = auto() + LT = auto() + LOG = auto() + + # scalar-tensor ops + POW = auto() + MAX = auto() + + +class _SingleTensorOpImpl(torch.autograd.Function): + """Distributed implementation of single-tensor operations using DTensors. + + This autograd function implements distributed single-tensor operations + like cosine, ReLU, and logarithm. The operations are performed element-wise across distributed + tensors while maintaining proper gradient computation. + + Supported operations: + - COS: output = cos(x) + - RELU: output = max(0, x) + - ROUND: output = round(x) + - LOG: output = log(x) + - EXP: output = exp(x) + - ABS: output = |x| + - SIGMOID: output = 1 / (1 + exp(-x)) + + Key features: + - Distributed computation across device meshes with various sharding strategies + - Memory-efficient implementation that operates on local tensor chunks + - Supports gradient computation through custom backward pass + """ + + @staticmethod + def forward(ctx, x: DTensor, op: ElementwiseOp) -> DTensor: + """Forward pass of distributed single-tensor operation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + x : DTensor + Input tensor. Can have any shape and sharding strategy. + op : ElementwiseOp + The operation to perform (COS, RELU, ROUND, LOG, EXP, ABS, or SIGMOID). + Returns + ------- + DTensor + Output tensor with shape identical to input tensor. + + Raises + ------ + TypeError + If input is not a DTensor. + ValueError + If Partial placements are used (not supported), or if op is invalid. + """ + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + if not isinstance(op, ElementwiseOp): + raise TypeError(f"Input 'op' must be of type ElementwiseOp. Got type {type(op)}.") + + device_mesh_input = x.device_mesh + placements_input = x.placements + + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if x.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + x_local = x.to_local() + + # Perform the operation + if op == ElementwiseOp.COS: + output_local = torch.cos(x_local) + if x.requires_grad: + x_local_copy = x_local.detach().clone() + ctx.save_for_backward(x_local_copy) + elif op == ElementwiseOp.RELU: + output_local = torch.relu(x_local) + if x.requires_grad: + x_local_copy = x_local.detach().clone() + ctx.save_for_backward(x_local_copy) + elif op == ElementwiseOp.ROUND: + output_local = torch.round(x_local) + elif op == ElementwiseOp.LOG: + output_local = torch.log(x_local) + if x.requires_grad: + x_local_copy = x_local.detach().clone() + ctx.save_for_backward(x_local_copy) + elif op == ElementwiseOp.EXP: + output_local = torch.exp(x_local) + if x.requires_grad: + x_local_copy = x_local.detach().clone() + ctx.save_for_backward(x_local_copy) + elif op == ElementwiseOp.ABS: + output_local = torch.abs(x_local) + if x.requires_grad: + x_local_copy = x_local.detach().clone() + ctx.save_for_backward(x_local_copy) + elif op == ElementwiseOp.SIGMOID: + output_local = torch.sigmoid(x_local) + if x.requires_grad: + ctx.save_for_backward(output_local.clone()) + else: + raise ValueError(f"Unsupported single-tensor operation: {op}") + + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.input_shape = x.shape + ctx.op = op + + out = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=x.shape, + stride=x.stride(), + ) + if op == ElementwiseOp.ROUND: + ctx.mark_non_differentiable(out) + return out + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor | None, None]: + """Backward pass of distributed single-tensor operation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradients of the loss with respect to the output tensor. + + Returns + ------- + tuple[DTensor | None, None] + Gradients with respect to x and op. + """ + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensor. " + f"Got placements {grad_output.placements} and {ctx.placements_input}." + ) + + grad_output_local = grad_output.to_local() + dx = None + x_local = ctx.saved_tensors[0] + + # Compute gradients based on operation + if ctx.op == ElementwiseOp.COS: + if ctx.needs_input_grad[0]: + # Derivative of cos(x) is -sin(x) + dx_local = grad_output_local * (-torch.sin(x_local)) + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.RELU: + if ctx.needs_input_grad[0]: + # Derivative of relu(x) is 1 where x > 0, 0 elsewhere + dx_local = grad_output_local.clone() + dx_local[x_local <= 0] = 0 + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.ROUND: + pass # no gradient through this op + elif ctx.op == ElementwiseOp.LOG: + if ctx.needs_input_grad[0]: + # Derivative of log(x) is 1/x + dx_local = grad_output_local / x_local + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.EXP: + if ctx.needs_input_grad[0]: + # Derivative of exp(x) is exp(x) + dx_local = grad_output_local * torch.exp(x_local) + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.ABS: + if ctx.needs_input_grad[0]: + # Derivative of abs(x) is sign(x) = x/abs(x) for x != 0, undefined at x = 0 + # We use torch.sign(x) which handles the case x = 0 by returning 0 + dx_local = grad_output_local * torch.sign(x_local) + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.SIGMOID: + if ctx.needs_input_grad[0]: + # Derivative of sigmoid(x) is sigmoid(x) * (1 - sigmoid(x)) + sigmoid_output = x_local + dx_local = grad_output_local * sigmoid_output * (1 - sigmoid_output) + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + else: + raise ValueError(f"Unsupported single-tensor operation: {ctx.op}") + + return dx, None + + +class _ElementwiseOpImpl(torch.autograd.Function): + """Distributed implementation of elementwise operations using DTensors. + + This autograd function implements distributed elementwise operations that can perform + summation, multiplication, division, and logical operations between two input tensors. The operations are + performed element-wise across distributed tensors while maintaining proper gradient + computation. + + Supported operations: + - SUM: output = a + b + - SUB: output = a - b + - PROD: output = a * b + - DIV: output = a / b + - EQUAL: output = a & b + - BITAND: output = a & b + + Key features: + - Distributed computation across device meshes with various sharding strategies + - Memory-efficient implementation that operates on local tensor chunks + - Supports gradient computation through custom backward pass + - Validates tensor compatibility (device mesh, placements, shapes) + + Notes + ----- + Input tensors must be DTensors with: + - Identical device mesh and placements + - Compatible shapes (a and b must have the same shape) + - No Partial placements (not currently supported) + """ + + @staticmethod + def forward(ctx, a: DTensor, b: DTensor, op: ElementwiseOp) -> DTensor: + """Forward pass of distributed elementwise operation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + a : DTensor + First input tensor. Can have any shape and sharding strategy. + b : DTensor + Second input tensor. Must have identical shape, device mesh, + and placements as a. + op : ElementwiseOp + The operation to perform (SUM, PROD, DIV, EQUAL, or BITAND). + + Returns + ------- + DTensor + Output tensor with shape identical to input tensors. + Contains the result of the specified operation. + + Raises + ------ + TypeError + If inputs are not DTensors. + ValueError + If tensors have incompatible device meshes, placements, or if + Partial placements are used (not supported), or if op is invalid. + """ + if not isinstance(a, DTensor): + raise TypeError(f"Input 'a' must be of type DTensor. Got type {type(a)}.") + if not isinstance(b, DTensor): + raise TypeError(f"Input 'b' must be of type DTensor. Got type {type(b)}.") + if not isinstance(op, ElementwiseOp): + raise TypeError(f"Input 'op' must be of type ElementwiseOp. Got type {type(op)}.") + + device_mesh_input = a.device_mesh + if b.device_mesh != device_mesh_input: + raise ValueError( + f"Input tensors 'a' and 'b' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {b.device_mesh}." + ) + + placements_input = a.placements + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if a.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {a.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + if b.placements != placements_input: + raise ValueError( + f"Input tensors 'a' and 'b' must have identical placements. " + f"Got placements {placements_input} and {b.placements}." + ) + + input_shape = a.shape + if input_shape != b.shape: + raise ValueError( + f"Input tensors 'a' and 'b' must have identical shapes. Got shapes {input_shape} and {b.shape}." + ) + + a_local = a.to_local() + b_local = b.to_local() + + # Perform the operation + if op == ElementwiseOp.SUM: + output_local = a_local + b_local + elif op == ElementwiseOp.SUB: + output_local = a_local - b_local + elif op == ElementwiseOp.PROD: + output_local = a_local * b_local + # TODO: check if we can afford save_for_backward(a_local, b_local) without explicitly copying + # pytorch's c++ backend has this code here that can determine the necessity of the copy: + # https://github.com/pytorch/pytorch/blob/7caf6c801ddfaf556a3ca191173b50002c4261f4/torch/csrc/autograd/saved_variable.cpp#L67-L79 + # so we might not need to explicitly copy the tensors here + a_local_copy = a_local.detach().clone() if b.requires_grad else None + b_local_copy = b_local.detach().clone() if a.requires_grad else None + ctx.save_for_backward(a_local_copy, b_local_copy) + elif op == ElementwiseOp.DIV: + output_local = a_local / b_local + # Save tensors for backward pass: + # - if a.requires_grad, we need b for: da = grad_output / b + # - if b.requires_grad, we need b and output for: db = -grad_output * output / b + b_local_copy = b_local.detach().clone() if a.requires_grad else None + output_over_b_local = output_local / b_local if b.requires_grad else None + ctx.save_for_backward(b_local_copy, output_over_b_local) + elif op == ElementwiseOp.EQUAL: + if a.requires_grad or b.requires_grad: + raise ValueError("EQUAL operation is not differentiable but requires_grad is True") + output_local = a_local & b_local + elif op == ElementwiseOp.BITAND: + if a.requires_grad or b.requires_grad: + raise ValueError("BITAND operation is not differentiable but requires_grad is True") + output_local = a_local & b_local + else: + raise ValueError(f"Unsupported operation: {op}") + + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.input_shape = input_shape + ctx.op = op + + out = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=a.shape, + stride=a.stride(), + ) + return out + + @staticmethod + def backward(ctx, grad_output) -> tuple[DTensor | None, DTensor | None, None]: + """Backward pass of distributed elementwise operation. + + Computes gradients with respect to both input tensors a and b. + + The gradients are: + - For SUM: da = grad_output, db = grad_output + - For SUB: da = grad_output, db = -grad_output + - For PROD: da = grad_output * b, db = grad_output * a + - For DIV: da = grad_output / b, db = -grad_output * output / b + - For EQUAL: da = 0, db = 0 (not differentiable) + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output tensor. + Must have identical device mesh and placements as the input tensors. + + Returns + ------- + tuple[DTensor, DTensor, None] + Gradients with respect to a, b, and None for the op parameter. + Both gradients have the same shape and distribution as their corresponding inputs. + + Raises + ------ + TypeError + If grad_output is not a DTensor. + ValueError + If grad_output has incompatible device mesh or placements compared + to the input tensors from the forward pass. + """ + + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensor. " + f"Got placements {grad_output.placements} and {ctx.placements_input}." + ) + + if grad_output.shape != ctx.input_shape: + raise ValueError( + f"Input 'grad_output' must have the same shape as the input tensor. " + f"Got shapes {grad_output.shape} and {ctx.input_shape}." + ) + + grad_output_local = grad_output.to_local() + + # Compute gradients based on operation + da = None + db = None + + if ctx.op == ElementwiseOp.SUM: + if ctx.needs_input_grad[0]: + da_local = grad_output_local + da = DTensor.from_local( + da_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + if ctx.needs_input_grad[1]: + db_local = grad_output_local + db = DTensor.from_local( + db_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.SUB: + if ctx.needs_input_grad[0]: + da_local = grad_output_local + da = DTensor.from_local( + da_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + if ctx.needs_input_grad[1]: + db_local = -grad_output_local + db = DTensor.from_local( + db_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.PROD: + # Unpack saved_tensors once for checkpoint compatibility - must only be done once + a_local, b_local = ctx.saved_tensors + + if ctx.needs_input_grad[0]: + da_local = grad_output_local * b_local + da = DTensor.from_local( + da_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + if ctx.needs_input_grad[1]: + db_local = grad_output_local * a_local + db = DTensor.from_local( + db_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.DIV: + # Unpack saved_tensors once for checkpoint compatibility - must only be done once + b_local, output_over_b_local = ctx.saved_tensors + + if ctx.needs_input_grad[0]: + # Gradient w.r.t. numerator: da = grad_output / b + da_local = grad_output_local / b_local + da = DTensor.from_local( + da_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + if ctx.needs_input_grad[1]: + # Gradient w.r.t. denominator: db = -grad_output * output / b + db_local = -grad_output_local * output_over_b_local + db = DTensor.from_local( + db_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.EQUAL: # dummy op, not differentiable + pass + elif ctx.op == ElementwiseOp.BITAND: # dummy op, not differentiable + pass + else: + raise ValueError(f"Unsupported operation: {ctx.op}") + + return da, db, None + + +def elementwise_op(a: DTensor, b: DTensor, op: ElementwiseOp) -> DTensor: + """Apply elementwise operation to two distributed tensors. + + This function performs element-wise operations between two distributed tensors. + Supported operations are summation (a + b), subtraction (a - b), multiplication (a * b), + division (a / b), logical AND (a & b), and bitwise AND (a & b). The operation is performed + efficiently using local tensor operations while maintaining gradient + computation capabilities. + + Parameters + ---------- + a : DTensor + First input tensor. Can have any shape and sharding strategy. + b : DTensor + Second input tensor. Must have identical shape, device mesh, + and placements as a. + op : ElementwiseOp + The operation to perform (ElementwiseOp.SUM, ElementwiseOp.SUB, ElementwiseOp.PROD, ElementwiseOp.DIV, ElementwiseOp.EQUAL, or ElementwiseOp.BITAND). + + Returns + ------- + DTensor + Output tensor with shape identical to input tensors. + Contains the result of the specified operation. + + Examples + -------- + >>> # Assume we have distributed tensors a and b with shape (B, N, D) + >>> sum_output = elementwise_op(a, b, ElementwiseOp.SUM) + >>> # sum_output = a + b, computed in distributed fashion + >>> + >>> sub_output = elementwise_op(a, b, ElementwiseOp.SUB) + >>> # sub_output = a - b, computed in distributed fashion + >>> + >>> prod_output = elementwise_op(a, b, ElementwiseOp.PROD) + >>> # prod_output = a * b, computed in distributed fashion + >>> + >>> div_output = elementwise_op(a, b, ElementwiseOp.DIV) + >>> # div_output = a / b, computed in distributed fashion + >>> + >>> equal_output = elementwise_op(a, b, ElementwiseOp.EQUAL) + >>> # equal_output = a == b, computed in distributed fashion + >>> + >>> bitand_output = elementwise_op(a, b, ElementwiseOp.BITAND) + >>> # bitand_output = a & b, computed in distributed fashion + + Notes + ----- + - Both input tensors must be DTensors with compatible device meshes and placements + - Partial placements are not currently supported + - The function is differentiable and supports gradient computation + - The operation is performed on local tensor chunks for efficiency + """ + return _ElementwiseOpImpl.apply(a, b, op) # type: ignore + + +def single_tensor_op(x: DTensor, op: ElementwiseOp) -> DTensor: + """Apply single-tensor operation to a distributed tensor. + + This function performs element-wise operations on a single distributed tensor. + Supports cosine, ReLU, round, logarithm, and exponential operations. + + Parameters + ---------- + x : DTensor + Input tensor. Can have any shape and sharding strategy. + op : ElementwiseOp + The operation to perform (ElementwiseOp.COS, ElementwiseOp.RELU, ElementwiseOp.ROUND, ElementwiseOp.LOG, ElementwiseOp.EXP, ElementwiseOp.ABS, or ElementwiseOp.SIGMOID). + + Returns + ------- + DTensor + Output tensor with shape identical to input tensor. + + Examples + -------- + >>> # Assume we have distributed tensor x with shape (B, N, D) + >>> cos_output = single_tensor_op(x, ElementwiseOp.COS) + >>> # cos_output = cos(x), computed in distributed fashion + >>> + >>> relu_output = single_tensor_op(x, ElementwiseOp.RELU) + >>> # relu_output = max(0, x), computed in distributed fashion + >>> + >>> round_output = single_tensor_op(x, ElementwiseOp.ROUND) + >>> # round_output = round(x), computed in distributed fashion + >>> + >>> log_output = single_tensor_op(x, ElementwiseOp.LOG) + >>> # log_output = log(x), computed in distributed fashion + >>> + >>> exp_output = single_tensor_op(x, ElementwiseOp.EXP) + >>> # exp_output = exp(x), computed in distributed fashion + >>> + >>> abs_output = single_tensor_op(x, ElementwiseOp.ABS) + >>> # abs_output = |x|, computed in distributed fashion + >>> + >>> sigmoid_output = single_tensor_op(x, ElementwiseOp.SIGMOID) + >>> # sigmoid_output = 1 / (1 + exp(-x)), computed in distributed fashion + + Notes + ----- + - Input tensor must be a DTensor with any placement strategy + - Partial placements are not currently supported + - The function is differentiable and supports gradient computation + - The operation is performed on local tensor chunks for efficiency + """ + return _SingleTensorOpImpl.apply(x, op) # type: ignore + + +class _ScalarTensorOpImpl(torch.autograd.Function): + """Distributed implementation of scalar-tensor operations using DTensors. + + This autograd function implements distributed operations between a scalar and a DTensor. + The operations are performed element-wise across distributed tensors while maintaining + proper gradient computation. + + Supported operations: + - SUM: output = scalar + tensor + - SUB: output = scalar - tensor + - PROD: output = scalar * tensor + - DIV: output = scalar / tensor + - GT: output = scalar > tensor + - LT: output = scalar < tensor + - EQUAL: output = scalar == tensor + - POW: output = tensor ** scalar + - MAX: output = max(scalar, tensor) (element-wise clamp from below) + + Key features: + - Distributed computation across device meshes with various sharding strategies + - Memory-efficient implementation that operates on local tensor chunks + - Supports gradient computation through custom backward pass + - Validates tensor compatibility (no Partial placements) + """ + + @staticmethod + def forward(ctx, scalar: float | int, tensor: DTensor, op: ElementwiseOp) -> DTensor: + """Forward pass of distributed scalar-tensor operation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + scalar : float + Scalar value to operate with. + tensor : DTensor + Input tensor. Can have any shape and sharding strategy. + op : ElementwiseOp + The operation to perform (SUM, SUB, PROD, DIV, GT, or LT). + + Returns + ------- + DTensor + Output tensor with shape identical to input tensor. + + Raises + ------ + TypeError + If inputs are not of expected types. + ValueError + If Partial placements are used (not supported), or if op is invalid. + """ + if not isinstance(scalar, (int, float)): + raise TypeError(f"Input 'scalar' must be of type int or float. Got type {type(scalar)}.") + if not isinstance(tensor, DTensor): + raise TypeError(f"Input 'tensor' must be of type DTensor. Got type {type(tensor)}.") + if not isinstance(op, ElementwiseOp): + raise TypeError(f"Input 'op' must be of type ElementwiseOp. Got type {type(op)}.") + + device_mesh_input = tensor.device_mesh + placements_input = tensor.placements + + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if tensor.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {tensor.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + tensor_local = tensor.to_local() + + # Perform the operation + if op == ElementwiseOp.SUM: + output_local = scalar + tensor_local + elif op == ElementwiseOp.SUB: + output_local = scalar - tensor_local + elif op == ElementwiseOp.PROD: + output_local = scalar * tensor_local + elif op == ElementwiseOp.DIV: + output_local = scalar / tensor_local + if tensor.requires_grad: + # Save tensor for backward pass + tensor_local_copy = tensor_local.detach().clone() + ctx.save_for_backward(tensor_local_copy) + elif op == ElementwiseOp.GT: + output_local = scalar > tensor_local + elif op == ElementwiseOp.LT: + output_local = scalar < tensor_local + elif op == ElementwiseOp.EQUAL: + output_local = torch.eq( + torch.tensor(scalar, device=tensor_local.device, dtype=tensor_local.dtype), tensor_local + ) + elif op == ElementwiseOp.POW: + if tensor_local.min() < 0 and not scalar.is_integer(): + raise ValueError( + "Negative tensor values are not supported for DTensor POW operation but got scalar: {scalar}" + ) + output_local = torch.pow(tensor_local, scalar) + if tensor.requires_grad: + # Save tensor for backward pass: d_tensor = grad_output * scalar * tensor^(scalar-1) + if scalar != 0: # 0 pow has gradient of zero + tensor_local_copy = tensor_local.detach().clone() + ctx.save_for_backward(tensor_local_copy) + elif op == ElementwiseOp.MAX: + # max(scalar, tensor) element-wise, equivalent to tensor.clamp(min=scalar) + output_local = torch.clamp(tensor_local, min=scalar) + if tensor.requires_grad: + # Save mask for backward: gradient passes through where tensor >= scalar + mask_local = (tensor_local >= scalar).detach() + ctx.save_for_backward(mask_local) + else: + raise ValueError(f"Unsupported scalar-tensor operation: {op}") + + if tensor.requires_grad: + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.input_shape = tensor.shape + ctx.op = op + ctx.scalar = scalar + + out = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=tensor.shape, + stride=tensor.stride(), + ) + if op == ElementwiseOp.GT or op == ElementwiseOp.LT or op == ElementwiseOp.EQUAL: + ctx.mark_non_differentiable(out) + return out + + @staticmethod + def backward(ctx, grad_output) -> tuple[float | None, DTensor | None, None]: + """Backward pass of distributed scalar-tensor operation. + + Computes gradients with respect to both scalar and tensor inputs. + + The gradients are: + - For SUM: d_scalar = grad_output.sum(), d_tensor = grad_output + - For SUB: d_scalar = grad_output.sum(), d_tensor = -grad_output + - For PROD: d_scalar = (grad_output * tensor).sum(), d_tensor = grad_output * scalar + - For DIV: d_scalar = -(grad_output * tensor / scalar^2).sum(), d_tensor = -grad_output * scalar / tensor^2 + - For GT: d_scalar = None, d_tensor = None (not differentiable) + - For LT: d_scalar = None, d_tensor = None (not differentiable) + - For EQUAL: d_scalar = None, d_tensor = None (not differentiable) + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradients of the loss with respect to the output tensors. + + Returns + ------- + tuple[float | None, DTensor | None, None] + Gradients with respect to scalar, tensor, and None for the op parameter. + """ + if not ctx.needs_input_grad[1]: + return None, None, None + + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensor. " + f"Got placements {grad_output.placements} and {ctx.placements_input}." + ) + + grad_output_local = grad_output.to_local() + d_tensor = None + + # Compute gradients based on operation + if ctx.op == ElementwiseOp.SUM: + # Gradient w.r.t. tensor: grad_output + d_tensor = DTensor.from_local( + grad_output_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.SUB: + # Gradient w.r.t. tensor: -grad_output + d_tensor_local = -grad_output_local + d_tensor = DTensor.from_local( + d_tensor_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.PROD: + # Gradient w.r.t. tensor: grad_output * scalar + d_tensor_local = grad_output_local * ctx.scalar + d_tensor = DTensor.from_local( + d_tensor_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.DIV: + (tensor_local,) = ctx.saved_tensors + # Gradient w.r.t. tensor: -grad_output * scalar / tensor^2 + d_tensor_local = -grad_output_local * ctx.scalar / (tensor_local**2) + d_tensor = DTensor.from_local( + d_tensor_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.GT: # dummy op, not differentiable + pass # no gradient through this op + elif ctx.op == ElementwiseOp.LT: # dummy op, not differentiable + pass # no gradient through this op + elif ctx.op == ElementwiseOp.EQUAL: # dummy op, not differentiable + pass # no gradient through this op + elif ctx.op == ElementwiseOp.POW: + if ctx.scalar == 0: + d_tensor_local = torch.zeros_like(grad_output_local) + else: + (tensor_local,) = ctx.saved_tensors + d_tensor_local = grad_output_local * ctx.scalar * tensor_local ** (ctx.scalar - 1) + d_tensor = DTensor.from_local( + d_tensor_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + elif ctx.op == ElementwiseOp.MAX: + # Gradient of max(scalar, tensor): passes through where tensor >= scalar, else 0 + (mask_local,) = ctx.saved_tensors + d_tensor_local = grad_output_local * mask_local + d_tensor = DTensor.from_local( + d_tensor_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + else: + raise ValueError(f"Unsupported scalar-tensor operation: {ctx.op}") + + return None, d_tensor, None + + +# TODO reduce human error by switching the order (scalar, tensor) to (tensor, scalar) +def scalar_tensor_op(scalar: float | int, tensor: DTensor, op: ElementwiseOp) -> DTensor: + """Apply scalar-tensor operation to a scalar and distributed tensor. + + This function performs element-wise operations between a scalar and a distributed tensor. + Supported operations are summation (scalar + tensor), subtraction (scalar - tensor), multiplication (scalar * tensor), + division (scalar / tensor), greater than comparison (scalar > tensor), less than comparison (scalar < tensor), equality comparison (scalar == tensor), and power (tensor ** scalar). The operation is performed efficiently using local + tensor operations while maintaining gradient computation capabilities. + + Parameters + ---------- + scalar : float | int + Scalar value to operate with. + tensor : DTensor + Input tensor. Can have any shape and sharding strategy. + op : ElementwiseOp + The operation to perform (ElementwiseOp.SUM, ElementwiseOp.SUB, ElementwiseOp.PROD, ElementwiseOp.DIV, ElementwiseOp.GT, ElementwiseOp.LT, ElementwiseOp.EQUAL, ElementwiseOp.POW, or ElementwiseOp.MAX). + + Returns + ------- + DTensor + Output tensor with shape identical to input tensor. + Contains the result of the specified operation. + + Examples + -------- + >>> # Assume we have distributed tensor x with shape (B, N, D) + >>> sum_output = scalar_tensor_op(2.0, x, ElementwiseOp.SUM) + >>> # sum_output = 2.0 + x, computed in distributed fashion + >>> + >>> sub_output = scalar_tensor_op(2.0, x, ElementwiseOp.SUB) + >>> # sub_output = 2.0 - x, computed in distributed fashion + >>> + >>> prod_output = scalar_tensor_op(0.5, x, ElementwiseOp.PROD) + >>> # prod_output = 0.5 * x, computed in distributed fashion + >>> + >>> div_output = scalar_tensor_op(1.0, x, ElementwiseOp.DIV) + >>> # div_output = 1.0 / x, computed in distributed fashion + >>> + >>> gt_output = scalar_tensor_op(0.5, x, ElementwiseOp.GT) + >>> # gt_output = 0.5 > x, computed in distributed fashion (boolean tensor) + >>> + >>> lt_output = scalar_tensor_op(0.5, x, ElementwiseOp.LT) + >>> # lt_output = 0.5 < x, computed in distributed fashion (boolean tensor) + >>> + >>> equal_output = scalar_tensor_op(0.5, x, ElementwiseOp.EQUAL) + >>> # equal_output = 0.5 == x, computed in distributed fashion (boolean tensor) + >>> + >>> pow_output = scalar_tensor_op(2.0, x, ElementwiseOp.POW) + >>> # pow_output = x ** 2.0, computed in distributed fashion + + Notes + ----- + - Input tensor must be a DTensor with any placement strategy + - Partial placements are not currently supported + - The function is differentiable and supports gradient computation for both scalar and tensor (except GT, LT, and EQUAL which are not differentiable) + - The operation is performed on local tensor chunks for efficiency + + Raises + ------ + TypeError + If inputs are not of expected types. + ValueError + If Partial placements are used (not supported), or if op is invalid. + """ + return _ScalarTensorOpImpl.apply(scalar, tensor, op) # type: ignore diff --git a/src/boltz/distributed/model/layers/embedding.py b/src/boltz/distributed/model/layers/embedding.py new file mode 100644 index 000000000..002926162 --- /dev/null +++ b/src/boltz/distributed/model/layers/embedding.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Optional, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor + +from boltz.distributed.utils import update_exhaustive_strides + + +class _EmbeddingParamsReplicatedImpl(torch.autograd.Function): + """ + Custom autograd implementation for embedding with replicated parameters. + + The embedding weight is replicated across all device mesh dimensions, while the input + indices can be sharded or replicated. The output DTensor follows the same placements + as the input indices. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + x: DTensor, + weight: DTensor, + padding_idx: Optional[int], + ) -> DTensor: + """ + Forward pass for the distributed embedding operation. + + Assumptions and requirements: + 1. Parameters (weight) must be replicated on all device mesh dimensions + 2. Input tensor and parameters must be on the same device mesh + 3. Partial reduction along any input dimension is not supported + 4. Input indices must be integer dtype and must not require gradients + + Args: + ctx: Context object to store information for backward pass + x: Input index DTensor with arbitrary placement strategy, except Partial placement + weight: Weight DTensor with all-replicate placements + padding_idx: Optional padding index to be passed into F.embedding + + Returns: + Output DTensor with same placement strategy as input + + Raises: + ValueError: If any of the placement requirements are violated + """ + if not isinstance(x, DTensor): + raise TypeError(f"Expected x to be a DTensor but got {type(x)}.") + if x.dtype.is_floating_point or x.dtype.is_complex or x.dtype == torch.bool: + raise ValueError(f"Expected x to be an integer DTensor but got {x.dtype}.") + if x.requires_grad: + raise ValueError("x must not require grad in the forward pass") + if not isinstance(weight, DTensor): + raise TypeError(f"Expected weight to be a DTensor but got {type(weight)}.") + + device_mesh = x.device_mesh + if weight.device_mesh != device_mesh: + raise ValueError("weight and x must be on the same device mesh") + + ndim_device_mesh = device_mesh.ndim + all_replicate_placements = tuple([Replicate()] * ndim_device_mesh) + if weight.placements != all_replicate_placements: + raise ValueError("weight must be replicated on all device mesh dimensions") + + placements_grad_params = list(weight.placements) + + for i_dim_device_mesh, placement in enumerate(x.placements): + if isinstance(placement, Partial): + raise ValueError("Partial reduction along any input dimension is not supported") + if isinstance(placement, Shard): + if placement.dim >= x.ndim: + raise ValueError("Input placement sharding dimension is out of range") + if x.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} " + "is not supported" + ) + placements_grad_params[i_dim_device_mesh] = Partial("sum") + elif not isinstance(placement, Replicate): + raise ValueError( + f"Unsupported x's placements along {i_dim_device_mesh} axis of the device mesh: {placement}" + ) + + x_local = x.to_local() + weight_local = weight.to_local() + + needs_grad = weight.requires_grad + if needs_grad: + with torch.enable_grad(): + weight_local_detached = weight_local.detach().requires_grad_(True) + output_local = F.embedding( + x_local, + weight_local_detached, + padding_idx=padding_idx, + ) + ctx.save_for_backward(output_local, weight_local_detached) + else: + output_local = F.embedding( + x_local, + weight_local, + padding_idx=padding_idx, + ) + + shape_output = list(output_local.shape) + for i_dim_mesh, placement in enumerate(x.placements): + if isinstance(placement, Shard): + shape_output[placement.dim] *= device_mesh.shape[i_dim_mesh] + + if needs_grad: + ctx.device_mesh = device_mesh + ctx.placements_x = x.placements + ctx.placements_grad_params = placements_grad_params + ctx.weight_shape = weight.shape + ctx.weight_stride = weight.stride() + ctx.output_shape = torch.Size(shape_output) + + stride_output = update_exhaustive_strides(output_local.shape, output_local.stride(), shape_output) + output = DTensor.from_local( + output_local, + device_mesh, + x.placements, + shape=torch.Size(shape_output), + stride=torch.Size(stride_output), + ) + return output + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( # type: ignore[override] + ctx, grad_output: DTensor + ) -> tuple[None, Optional[DTensor], None]: + """ + Backward pass for the distributed embedding operation. + + Assumptions and requirements: + 1. Gradient of the output is on the same device mesh as the output + + Args: + ctx: Context object with stored information from forward pass + grad_output: Gradient of the loss with respect to the output + + Returns: + Tuple of gradients for input, weight, and padding_idx. + """ + if grad_output.device_mesh != ctx.device_mesh: + raise ValueError( + "_EmbeddingParamsReplicatedImpl: different device mesh between grad_output and the forward input" + ) + + if grad_output.shape != ctx.output_shape: + raise ValueError( + "_EmbeddingParamsReplicatedImpl: different shape between grad_output and the forward output" + ) + + grad_weight = None + if ctx.needs_input_grad[1]: + grad_output_local = grad_output.to_local() + output_local, weight_local_detached = ctx.saved_tensors + (grad_weight_local,) = torch.autograd.grad( + outputs=[output_local], + inputs=[weight_local_detached], + grad_outputs=[grad_output_local], + retain_graph=False, + ) + grad_weight = DTensor.from_local( + grad_weight_local, + ctx.device_mesh, + ctx.placements_grad_params, + shape=ctx.weight_shape, + stride=ctx.weight_stride, + ) + + return None, grad_weight, None + + +class EmbeddingParamsReplicated(nn.Module): + """ + Distributed embedding layer with parameters replicated across all device mesh dimensions. + + This is almost equivalent to + ```python + layer = torch.distributed.tensor.distribute_module(layer_local, device_mesh) + ``` + with the exception that the torch.distributed.tensor.distribute_module version will incur + significant overhead due to the unnecessary replication of the output tensor along certain + device mesh dimensions. + + This class avoids such unnecessary overhead by using the custom _EmbeddingParamsReplicatedImpl + autograd function for forward and backward pass computation instead of relying on the distributed + module's forward implementation. + + Key requirements: + 1. Parameters (weight) will be replicated on all device mesh dimensions + 2. Input tensor and parameters must be on the same device mesh + 3. Partial reduction along any input dimension is not supported + 4. Input indices must be integer dtype and must not require gradients + 5. Input and outputs must be on the same device mesh with the same placements + 6. Gradients of the weight have Partial("sum") placements along the input's Shard placements' + dimension so that the all-reduce will be performed along those device-grid dimensions + """ + + def __init__(self, layer_local: nn.Embedding, device_mesh: DeviceMesh): + if not isinstance(layer_local, nn.Embedding): + raise TypeError("layer_local is not an instance of nn.Embedding") + if layer_local.weight.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.weight and device_mesh are not on the same device type: " + f"{layer_local.weight.device.type} != {device_mesh.device_type}" + ) + if layer_local.sparse: + raise ValueError("sparse option is not supported in EmbeddingParamsReplicated") + if layer_local.scale_grad_by_freq: + raise ValueError("scale_grad_by_freq option is not supported in EmbeddingParamsReplicated") + if layer_local.max_norm is not None: + raise ValueError("max_norm option is not supported in EmbeddingParamsReplicated") + + super().__init__() + all_replicate_placements = [Replicate()] * device_mesh.ndim + self.weight = nn.Parameter( + distribute_tensor(layer_local.weight.data, device_mesh, all_replicate_placements), + requires_grad=layer_local.weight.requires_grad, + ) + self.padding_idx = layer_local.padding_idx + self.num_embeddings = layer_local.num_embeddings + self.embedding_dim = layer_local.embedding_dim + + def forward(self, input: DTensor) -> DTensor: + """ + Forward pass for the distributed embedding layer. + + Uses the custom _EmbeddingParamsReplicatedImpl autograd function to perform the computation + efficiently while preserving correct autograd behavior for distributed tensors. + + Args: + input: Input index DTensor + + Returns: + Output DTensor with same placement strategy as input + """ + return cast( + DTensor, + _EmbeddingParamsReplicatedImpl.apply( + input, + self.weight, + self.padding_idx, + ), + ) diff --git a/src/boltz/distributed/model/layers/flatten_and_unflatten.py b/src/boltz/distributed/model/layers/flatten_and_unflatten.py new file mode 100644 index 000000000..beccf14a8 --- /dev/null +++ b/src/boltz/distributed/model/layers/flatten_and_unflatten.py @@ -0,0 +1,864 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import math + +import torch +from torch import Tensor +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Placement, Shard + +from boltz.distributed.utils import LayoutRightMap, update_exhaustive_strides + + +def _shardwise_flatten( + x: DTensor, + start_dim: int = 0, + end_dim: int = -1, + output_placements: tuple | None = None, + input_placements_expected: tuple | None = None, +) -> DTensor: + """Generalized shardwise flattening operation with validation. + + This function performs input validation and flattening operation. + + Parameters + ---------- + x : DTensor + Input DTensor to flatten. + start_dim : int, optional + First dimension to flatten. Default is 0. + end_dim : int, optional + Last dimension to flatten. Default is -1 (last dimension). + output_placements : tuple | None, optional + If provided, use these placements for the output DTensor instead of + computing them from the input placements. Default is None. + input_placements_expected : tuple | None, optional + If provided, skip validation by assuming input has these expected placements. + Must be present if output_placements is present, and absent if output_placements + is absent. Default is None. + + Returns + ------- + DTensor + Flattened DTensor. + + Raises + ------ + ValueError + Checks on the input x and parameters, or if placement argument constraints are violated. + NotImplementedError + If any dimension to be flattened is sharded. + """ + # Validate placement argument constraints + if (output_placements is None) != (input_placements_expected is None): + raise ValueError("input_placements_expected must be present if and only if output_placements is present") + + has_input_output_placements = input_placements_expected is not None and output_placements is not None + + # Normalize dimensions + start_dim_normalized = start_dim if start_dim >= 0 else start_dim + x.ndim + end_dim_normalized = end_dim if end_dim >= 0 else end_dim + x.ndim + + # Validate dimension ranges + if start_dim_normalized < 0 or start_dim_normalized >= x.ndim: + raise ValueError(f"start_dim {start_dim} is out of range for tensor with {x.ndim} dimensions") + if end_dim_normalized < 0 or end_dim_normalized >= x.ndim: + raise ValueError(f"end_dim {end_dim} is out of range for tensor with {x.ndim} dimensions") + if start_dim_normalized > end_dim_normalized: + raise ValueError(f"start_dim {start_dim} must be <= end_dim {end_dim}") + + # Check that no dimension to be flattened is sharded + # and there are no unevenly sharded tensor axes + # Also calculate new placements accounting for dimension changes (if not provided) + if has_input_output_placements: + # Check if input placements match expected placements + if x.placements != input_placements_expected: + raise ValueError( + f"Input placements {x.placements} do not match expected placements {input_placements_expected}" + ) + # Skip validation and use provided placements directly + new_placements = output_placements + else: + dims_removed = end_dim_normalized - start_dim_normalized + new_placements = [] + + for i_dim_device_mesh, placement in enumerate(x.placements): + if isinstance(placement, Shard): + i_dim_tensor = placement.dim + if start_dim_normalized <= i_dim_tensor <= end_dim_normalized: + raise NotImplementedError( + f"Flattening dimension {i_dim_tensor} sharded by device_mesh axis {i_dim_device_mesh} is not supported" + ) + if x.shape[i_dim_tensor] % x.device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {i_dim_tensor} of size {x.shape[i_dim_tensor]} " + f"along device mesh dimension {i_dim_device_mesh} of size {x.device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + + # Calculate new placement for this shard + if i_dim_tensor < start_dim_normalized: + # Dimension before flattened region - unchanged + new_placements.append(Shard(i_dim_tensor)) + elif i_dim_tensor > end_dim_normalized: + # Dimension after flattened region - shift left by dims_removed + new_placements.append(Shard(i_dim_tensor - dims_removed)) + # Dimensions within flattened region are handled by the validation above + elif isinstance(placement, Partial): + raise ValueError(f"Placements of type {Partial} are not supported") + else: + new_placements.append(placement) + + # Perform operation on local tensors + x_local = x.to_local() + output_local: Tensor = torch.flatten(x_local, start_dim=start_dim, end_dim=end_dim) + + # Compute output shape and stride + flattened_size = math.prod(x.shape[start_dim_normalized : (end_dim_normalized + 1)]) + shape_output = x.shape[:start_dim_normalized] + (flattened_size,) + x.shape[end_dim_normalized + 1 :] + # Use update_exhaustive_strides to compute new strides + strides_output = update_exhaustive_strides(output_local.shape, output_local.stride(), shape_output) + + # Create output DTensor using input tensor's device mesh and updated placements + output: DTensor = DTensor.from_local( + output_local, + device_mesh=x.device_mesh, + placements=tuple(new_placements), + shape=shape_output, + stride=strides_output, + ) + + return output + + +def _shardwise_unflatten( + x: DTensor, + dim: int, + sizes: tuple[int, ...], + output_placements: tuple | None = None, + input_placements_expected: tuple | None = None, +) -> DTensor: + """Generalized shardwise unflattening operation with validation. + + This function performs input validation and unflattening operation following + the torch.unflatten API. + + Parameters + ---------- + x : DTensor + Input flattened DTensor to unflatten. + dim : int + Dimension to unflatten. + sizes : tuple[int, ...] + Sizes for the new dimensions that will replace the specified dimension. + output_placements : tuple | None, optional + If provided, use these placements for the output DTensor instead of + computing them from the input placements. Default is None. + input_placements_expected : tuple | None, optional + If provided, skip validation by assuming input has these expected placements. + Must be present if output_placements is present, and absent if output_placements + is absent. Default is None. + + Returns + ------- + DTensor + Unflattened DTensor with expanded dimensions. + + Raises + ------ + ValueError + Checks on the input x and parameters, or if placement argument constraints are violated. + """ + # Validate placement argument constraints + if (output_placements is None) != (input_placements_expected is None): + raise ValueError("input_placements_expected must be present if and only if output_placements is present") + + has_input_output_placements = input_placements_expected is not None and output_placements is not None + + # Normalize dimension + dim_normalized = dim if dim >= 0 else dim + x.ndim + + # Validate dimension range + if dim_normalized < 0 or dim_normalized >= x.ndim: + raise ValueError(f"dim {dim} is out of range for tensor with {x.ndim} dimensions") + + # Also calculate new placements accounting for dimension changes (if not provided) + if has_input_output_placements: + # Check if input placements match expected placements + if x.placements != input_placements_expected: + raise ValueError( + f"Input placements {x.placements} do not match expected placements {input_placements_expected}" + ) + # Skip validation and use provided placements directly + new_placements = output_placements + else: + dims_added = len(sizes) - 1 + new_placements = [] + + for i_dim_device_mesh, placement in enumerate(x.placements): + # Check that the dimension to unflatten is not sharded + # and there are no unevenly sharded tensor axes + if isinstance(placement, Shard): + i_dim_tensor = placement.dim + if x.shape[i_dim_tensor] % x.device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {i_dim_tensor} of size {x.shape[i_dim_tensor]} " + f"along device mesh dimension {i_dim_device_mesh} of size {x.device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + if i_dim_tensor == dim_normalized: + raise NotImplementedError( + f"Unflattening dimension {dim} shared by device_mesh axis {i_dim_device_mesh} is not supported" + ) + elif i_dim_tensor < dim_normalized: + # Dimension before unflattened region - unchanged + new_placements.append(Shard(i_dim_tensor)) + else: + # Dimension after unflattened region - shift right by dims_added + new_placements.append(Shard(i_dim_tensor + dims_added)) + elif isinstance(placement, Partial): + raise ValueError(f"Placements of type {Partial} are not supported") + else: + new_placements.append(placement) + + # Perform operation on local tensors + x_local = x.to_local() + output_local: Tensor = torch.unflatten(x_local, dim=dim, sizes=sizes) + + # compute the output DTensor's shape and stride + shape_output = list(x.shape) + dim_to_insert = dim_normalized + shape_output.pop(dim_to_insert) + shape_output[dim_to_insert:dim_to_insert] = sizes + # torch unflatten enforce contiguous layout, which is LayoutRight + layout_right = LayoutRightMap(shape_output) + strides_output = layout_right.strides + + # Create output DTensor using input tensor's device mesh and updated placements + output: DTensor = DTensor.from_local( + output_local, + device_mesh=x.device_mesh, + placements=tuple(new_placements), + shape=tuple(shape_output), + stride=strides_output, + ) + + return output + + +class _ShardWiseFlattenImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + start_dim: int = 0, + end_dim: int = -1, + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor to flatten. + start_dim : int, optional + First dimension to flatten. Default is 0. + end_dim : int, optional + Last dimension to flatten. Default is -1 (last dimension). + + Returns + ------- + DTensor + Flattened DTensor. + """ + # Normalize dimensions before flattening + start_dim_normalized = start_dim if start_dim >= 0 else start_dim + x.ndim + end_dim_normalized = end_dim if end_dim >= 0 else end_dim + x.ndim + + # Perform flattening with built-in validation + result = _shardwise_flatten(x, start_dim, end_dim) + + # Save metadata for backward pass + # For unflattening, we need the dimension in the flattened tensor and the original sizes + ctx.unflatten_dim = start_dim_normalized # This is the flattened dimension position + ctx.unflatten_sizes = x.shape[start_dim_normalized : end_dim_normalized + 1] # Original sizes + ctx.device_mesh_input = x.device_mesh + ctx.placements_input = x.placements + ctx.placements_output = result.placements + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, None, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None, None] + Gradient with respect to input, None for start_dim and end_dim parameters. + """ + # Use unflatten operation (inverse of flatten) for backward pass with built-in validation + grad_x = _shardwise_unflatten( + grad_output, + ctx.unflatten_dim, + ctx.unflatten_sizes, + output_placements=ctx.placements_input, + input_placements_expected=ctx.placements_output, + ) + + return grad_x, None, None + + +class _ShardWiseUnflattenImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + dim: int, + sizes: tuple[int, ...], + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor to unflatten. + dim : int + Dimension to unflatten. + sizes : tuple[int, ...] + Sizes for the new dimensions that will replace the specified dimension. + + Returns + ------- + DTensor + Unflattened DTensor. + """ + # Normalize dimension before unflattening + dim_normalized = dim if dim >= 0 else dim + x.ndim + + # Perform unflattening with built-in validation + result = _shardwise_unflatten(x, dim, sizes) + + # Save metadata for backward pass + # For flattening, we need the start and end dimensions in the unflattened tensor + ctx.flatten_start_dim = dim_normalized # This is the start dimension for flattening + ctx.flatten_end_dim = dim_normalized + len(sizes) - 1 # This is the end dimension for flattening + ctx.device_mesh_input = x.device_mesh + ctx.placements_input = x.placements + ctx.placements_output = result.placements + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, None, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None, None] + Gradient with respect to input, None for dim and sizes parameters. + """ + # Use flatten operation (inverse of unflatten) for backward pass with built-in validation + grad_x = _shardwise_flatten( + grad_output, + ctx.flatten_start_dim, + ctx.flatten_end_dim, + output_placements=ctx.placements_input, + input_placements_expected=ctx.placements_output, + ) + + return grad_x, None, None + + +def shardwise_flatten(x: DTensor, start_dim: int = 0, end_dim: int = -1) -> DTensor: + """Flatten a DTensor along specified dimensions. + + This function flattens a DTensor from start_dim to end_dim (inclusive). + The dimensions to be flattened must not be sharded on the device mesh. + + Parameters + ---------- + x : DTensor + Input DTensor to flatten. + start_dim : int, optional + First dimension to flatten. Default is 0. + end_dim : int, optional + Last dimension to flatten. Default is -1 (last dimension). + + Returns + ------- + DTensor + Flattened DTensor. + + Raises + ------ + ValueError + If any specified dimension is sharded or other validation errors. + NotImplementedError + If any dimension to be flattened is sharded. + """ + return _ShardWiseFlattenImpl.apply(x, start_dim, end_dim) + + +def shardwise_unflatten(x: DTensor, dim: int, sizes: tuple[int, ...]) -> DTensor: + """Unflatten a DTensor along a specified dimension. + + This function unflattens a DTensor by expanding the specified dimension + into multiple dimensions with the given sizes. The dimension to be + unflattened must not be sharded on the device mesh. + + Parameters + ---------- + x : DTensor + Input DTensor to unflatten. + dim : int + Dimension to unflatten. + sizes : tuple[int, ...] + Sizes for the new dimensions that will replace the specified dimension. + + Returns + ------- + DTensor + Unflattened DTensor. + + Raises + ------ + ValueError + If validation errors occur during unflattening or if placement argument + constraints are violated. + NotImplementedError + If the dimension to be unflattened is sharded. + """ + return _ShardWiseUnflattenImpl.apply(x, dim, sizes) + + +def _shardwise_unflatten_sharded_impl(input: DTensor, dim: int, sizes: tuple[int, ...]) -> DTensor: + """Unflatten a sharded DTensor along a specified dimension. + + This function splits a single dimension into multiple dimensions, similar to + torch.Tensor.unflatten, but designed for sharded DTensors. The input must be + sharded along the specified dimension, and sizes[0] must be evenly divisible + by the device mesh size so that the resulting DTensor is again sharded along + the same dimension. + + This is the inverse operation of _shardwise_flatten_sharded_impl. + + Args: + input: Input DTensor to unflatten. Must be sharded along `dim`. + dim: The dimension to unflatten. Must correspond to a sharded dimension + in the input's placements. If negative, wraps around. + sizes: Tuple of integers specifying the shape to unflatten into. + The product of sizes must equal input.shape[dim]. + sizes[0] must be evenly divisible by the device mesh size. + + Returns: + DTensor with the specified dimension unflattened into multiple dimensions. + The output shape is input.shape[:dim] + sizes + input.shape[dim+1:]. + + Raises: + TypeError: If input is not a DTensor, dim is not an int, or sizes + is not a tuple. + ValueError: If input has Partial placements, if sizes[0] is not + evenly shardable, if the product of sizes doesn't match the + dim size, or if input is not sharded along dim. + """ + if not isinstance(input, DTensor): + raise TypeError(f"Expected DTensor, got {type(input)}") + if not isinstance(dim, int): + raise TypeError(f"Expected int for dim, got {type(dim)}") + if not isinstance(sizes, tuple): + raise TypeError(f"Expected tuple for sizes, got {type(sizes)}") + if len(sizes) < 2: + raise ValueError("Must provide at least two dimensions for unflattening") + + ndim = input.ndim + + # Normalize dimension + if dim < 0: + dim = ndim + dim + + if not (0 <= dim < ndim): + raise ValueError(f"dim {dim} out of range for {ndim}D tensor") + + device_mesh = input.device_mesh + placements = input.placements + + i_mesh_dim_shard_dim = None + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if input.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == dim: + i_mesh_dim_shard_dim = i_dim_device_mesh + + if i_mesh_dim_shard_dim is None: + raise ValueError(f"input is not sharded along dim {dim}") + + size_expected = math.prod(sizes) + if size_expected != input.shape[dim]: + raise ValueError(f"Expected size {size_expected} but got {input.shape[dim]}") + + size_group = device_mesh.size(i_mesh_dim_shard_dim) + + # sizes[0] will become the new shape[dim] and should be evenly sharded + if sizes[0] % size_group != 0: + raise ValueError( + f"sizes[0] {sizes[0]} must be evenly sharded along device mesh dimension {i_mesh_dim_shard_dim} of size {size_group}" + ) + + # input.shape[dim] // size_group flattened and sharded into: (sizes[0] // size_group, sizes[1:]) + output_local = input.to_local().unflatten(dim, (sizes[0] // size_group, *sizes[1:])) + + shape_output = input.shape[:dim] + sizes + input.shape[dim + 1 :] + strides_output = update_exhaustive_strides(output_local.shape, output_local.stride(), shape_output) + + # Adjust Shard dim indices for dimensions shifted by the unflatten. + # Splitting dim into len(sizes) parts adds (len(sizes) - 1) new dims, + # so any Shard(d) with d > dim must shift up by that amount. + n_dims_added = len(sizes) - 1 + output_placements: tuple[Placement, ...] = tuple( + Shard(p.dim + n_dims_added) if isinstance(p, Shard) and p.dim > dim else p for p in placements + ) + + output: DTensor = DTensor.from_local( + output_local, device_mesh, output_placements, shape=shape_output, stride=strides_output + ) + + return output + + +def _shardwise_flatten_sharded_impl(input: DTensor, start_dim: int, end_dim: int) -> DTensor: + """Flatten consecutive dimensions of a sharded DTensor. + + This function flattens dimensions from start_dim to end_dim (inclusive) into + a single dimension. The input must be sharded along start_dim, and the + sharding is preserved on the flattened output dimension. + + This is the inverse operation of _shardwise_unflatten_sharded_impl. + + Args: + input: Input DTensor to flatten. Must be sharded along `start_dim`. + start_dim: First dimension to flatten. + end_dim: Last dimension to flatten (inclusive). If negative, wraps around. + + Returns: + DTensor with dimensions [start_dim, end_dim] flattened into a single + dimension at position start_dim. + + Raises: + TypeError: If input is not a DTensor, or start_dim/end_dim are not int. + ValueError: If input has Partial placements, if start_dim is not sharded, + or if dimension indices are invalid. + """ + if not isinstance(input, DTensor): + raise TypeError(f"Expected DTensor, got {type(input)}") + if not isinstance(start_dim, int): + raise TypeError(f"Expected int for start_dim, got {type(start_dim)}") + if not isinstance(end_dim, int): + raise TypeError(f"Expected int for end_dim, got {type(end_dim)}") + + ndim = input.ndim + + # Normalize dimensions + if start_dim < 0: + start_dim = ndim + start_dim + if end_dim < 0: + end_dim = ndim + end_dim + + if not (0 <= start_dim < ndim): + raise ValueError(f"start_dim {start_dim} out of range for {ndim}D tensor") + if not (0 <= end_dim < ndim): + raise ValueError(f"end_dim {end_dim} out of range for {ndim}D tensor") + if start_dim > end_dim: + raise ValueError(f"start_dim {start_dim} must be <= end_dim {end_dim}") + + device_mesh = input.device_mesh + placements = input.placements + + i_mesh_dim_shard_start = None + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if input.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == start_dim: + i_mesh_dim_shard_start = i_dim_device_mesh + + if i_mesh_dim_shard_start is None: + raise ValueError(f"input is not sharded along start_dim {start_dim}") + + # Flatten locally + output_local = input.to_local().flatten(start_dim=start_dim, end_dim=end_dim) + + # Compute global output shape + flattened_size = math.prod(input.shape[start_dim : end_dim + 1]) + shape_output = input.shape[:start_dim] + (flattened_size,) + input.shape[end_dim + 1 :] + strides_output = update_exhaustive_strides(output_local.shape, output_local.stride(), shape_output) + + # Adjust Shard dim indices for dimensions shifted by the flatten. + # Merging dims [start_dim, end_dim] removes (end_dim - start_dim) dims, + # so any Shard(d) with d > end_dim must shift down by that amount. + n_dims_removed = end_dim - start_dim + output_placements: tuple[Placement, ...] = tuple( + Shard(p.dim - n_dims_removed) if isinstance(p, Shard) and p.dim > end_dim else p for p in placements + ) + + output: DTensor = DTensor.from_local( + output_local, device_mesh, output_placements, shape=shape_output, stride=strides_output + ) + + return output + + +class ShardwiseUnflattenShardedImpl(torch.autograd.Function): + """Autograd function to unflatten a sharded DTensor along a specified axis. + + This function performs an unflatten operation on a DTensor while preserving + the sharding semantics. The input must be sharded along the specified axis, + and the first element of `sizes` must be evenly divisible by the device mesh + size along the sharding dimension. + + Example: + If input has global shape (B, N) sharded along axis=1 across 2 ranks, + and sizes=(K, W), the output will have global shape (B, K, W) still + sharded along axis=1. Each rank holds (B, K//2, W) locally. + """ + + @staticmethod + def forward(ctx: FunctionCtx, input: DTensor, axis: int, sizes: tuple[int, ...]) -> DTensor: + """Forward pass: unflatten the DTensor along the specified axis. + + Args: + ctx: Autograd context for saving tensors and metadata for backward. + input: Input DTensor to unflatten. Must be sharded along `axis`. + axis: The axis along which to unflatten. Must correspond to a + sharded dimension in the input's placements. + sizes: Tuple of integers specifying the new shape for the unflattened + dimension. The product of sizes must equal input.shape[axis]. + sizes[0] must be evenly divisible by the device mesh size. + + Returns: + DTensor with the specified axis unflattened into multiple dimensions. + The output shape is input.shape[:axis] + sizes + input.shape[axis+1:]. + Sharding is preserved along the first unflattened dimension. + + Raises: + TypeError: If input is not a DTensor, axis is not an int, or sizes + is not a tuple. + ValueError: If input has Partial placements, if sizes[0] is not + evenly shardable, if the product of sizes doesn't match the + axis dimension, or if input is not sharded along axis. + """ + output = _shardwise_unflatten_sharded_impl(input, dim=axis, sizes=sizes) + ctx.axis = axis + ctx.sizes = sizes + return output + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: DTensor) -> tuple[DTensor, None, None]: + """Backward pass: flatten the gradient back to the original input shape. + + Args: + ctx: Autograd context containing saved tensors and metadata from forward. + grad_output: Gradient DTensor with respect to the forward output. + Must have the same device_mesh, placements, and shape as the + forward output. + + Returns: + Tuple of (grad_input, None, None) where grad_input is the gradient + with respect to the input DTensor, and the None values correspond + to the non-differentiable axis and sizes arguments. + + Raises: + TypeError: If grad_output is not a DTensor. + ValueError: If grad_output has mismatched device_mesh, placements, + or shape compared to the forward output. + """ + axis = ctx.axis + sizes = ctx.sizes + end_dim = axis + len(sizes) - 1 + grad_input = _shardwise_flatten_sharded_impl(grad_output, start_dim=axis, end_dim=end_dim) + return grad_input, None, None + + +def shardwise_unflatten_sharded(input: DTensor, axis: int, sizes: tuple[int, ...]) -> DTensor: + """Unflatten a sharded DTensor along a specified axis while preserving sharding. + + This function reshapes a DTensor by splitting a single dimension into multiple + dimensions, similar to torch.Tensor.unflatten, but designed to work correctly + with distributed tensors that are sharded along the unflattened axis. + + The key constraint is that the first element of `sizes` must be evenly + divisible by the number of ranks sharding the axis, ensuring the resulting + tensor can maintain valid sharding semantics. + + Args: + input: Input DTensor to unflatten. Must be sharded along `axis` with + even sharding (no remainder when dividing by device mesh size). + axis: The axis along which to unflatten. Must be a sharded dimension. + sizes: Tuple of integers specifying the shape to unflatten into. + Must satisfy: prod(sizes) == input.shape[axis] and + sizes[0] % device_mesh_size == 0. + + Returns: + DTensor with shape input.shape[:axis] + sizes + input.shape[axis+1:]. + The tensor remains sharded along the same device mesh dimension, + with the sharding now applying to the first element of sizes. + + Example: + >>> # input: DTensor of global shape (4, 128) sharded on axis=1 across 2 ranks + >>> # Each rank holds (4, 64) locally + >>> output = shardwise_unflatten_sharded(input, axis=1, sizes=(16, 8)) + >>> # output: DTensor of global shape (4, 16, 8) sharded on axis=1 + >>> # Each rank holds (4, 8, 8) locally + """ + return ShardwiseUnflattenShardedImpl.apply(input, axis, sizes) + + +class ShardwiseFlattenShardedImpl(torch.autograd.Function): + """Autograd function to flatten consecutive dimensions of a sharded DTensor. + + This function performs a flatten operation on a DTensor while preserving + the sharding semantics. The input must be sharded along start_dim, and the + sharding is preserved on the flattened output dimension. + + Example: + If input has global shape (B, K, W) sharded along axis=1 across 2 ranks, + and we flatten dims 1 and 2, the output will have global shape (B, K*W) + still sharded along axis=1. Each rank holds (B, K*W//2) locally. + """ + + @staticmethod + def forward(ctx: FunctionCtx, input: DTensor, start_dim: int, end_dim: int) -> DTensor: + """Forward pass: flatten the DTensor along the specified dimensions. + + Args: + ctx: Autograd context for saving tensors and metadata for backward. + input: Input DTensor to flatten. Must be sharded along `start_dim`. + start_dim: First dimension to flatten. + end_dim: Last dimension to flatten (inclusive). If negative, wraps around. + + Returns: + DTensor with dimensions [start_dim, end_dim] flattened into a single + dimension at position start_dim. + + Raises: + TypeError: If input is not a DTensor, or start_dim/end_dim are not int. + ValueError: If input has Partial placements, if start_dim is not sharded, + or if dimension indices are invalid. + """ + # Normalize end_dim for storing in context + ndim = input.ndim + if end_dim < 0: + end_dim = ndim + end_dim + + output = _shardwise_flatten_sharded_impl(input, start_dim=start_dim, end_dim=end_dim) + + # Save the sizes of the flattened dimensions for backward (unflatten) + ctx.start_dim = start_dim if start_dim >= 0 else ndim + start_dim + ctx.sizes = tuple(input.shape[ctx.start_dim : end_dim + 1]) + return output + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: DTensor) -> tuple[DTensor, None, None]: + """Backward pass: unflatten the gradient back to the original input shape. + + Args: + ctx: Autograd context containing saved tensors and metadata from forward. + grad_output: Gradient DTensor with respect to the forward output. + Must have the same device_mesh, placements, and shape as the + forward output. + + Returns: + Tuple of (grad_input, None, None) where grad_input is the gradient + with respect to the input DTensor, and the None values correspond + to the non-differentiable start_dim and end_dim arguments. + + Raises: + TypeError: If grad_output is not a DTensor. + ValueError: If grad_output has mismatched device_mesh, placements, + or shape compared to the forward output. + """ + start_dim = ctx.start_dim + sizes = ctx.sizes + grad_input = _shardwise_unflatten_sharded_impl(grad_output, dim=start_dim, sizes=sizes) + return grad_input, None, None + + +def shardwise_flatten_sharded(input: DTensor, start_dim: int, end_dim: int) -> DTensor: + """Flatten consecutive dimensions of a sharded DTensor while preserving sharding. + + This function reshapes a DTensor by merging multiple dimensions into a single + dimension, similar to torch.Tensor.flatten, but designed to work correctly + with distributed tensors that are sharded along the start_dim. + + This is the inverse operation of shardwise_unflatten_sharded. + + Args: + input: Input DTensor to flatten. Must be sharded along `start_dim` with + even sharding (no remainder when dividing by device mesh size). + start_dim: First dimension to flatten. Must be a sharded dimension. + end_dim: Last dimension to flatten (inclusive). If negative, wraps around. + + Returns: + DTensor with dimensions [start_dim, end_dim] merged into a single dimension. + The tensor remains sharded along the same device mesh dimension. + + Example: + >>> # input: DTensor of global shape (4, 16, 8) sharded on axis=1 across 2 ranks + >>> # Each rank holds (4, 8, 8) locally + >>> output = shardwise_flatten_sharded(input, start_dim=1, end_dim=2) + >>> # output: DTensor of global shape (4, 128) sharded on axis=1 + >>> # Each rank holds (4, 64) locally + """ + return ShardwiseFlattenShardedImpl.apply(input, start_dim, end_dim) diff --git a/src/boltz/distributed/model/layers/gather.py b/src/boltz/distributed/model/layers/gather.py new file mode 100644 index 000000000..f5c059715 --- /dev/null +++ b/src/boltz/distributed/model/layers/gather.py @@ -0,0 +1,490 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Shard +from torch.distributed.tensor import Partial + +from boltz.distributed.model.layers.outer_gather import get_overlap_from_peers +from boltz.distributed.utils import update_exhaustive_strides + + +class DistributedGather(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x_dtensor: DTensor, + idx_dtensor: DTensor, + axis: int, + are_ids_contiguous: bool, + idx_mask: DTensor | None = None, + ) -> DTensor: + """Distributed 1D gather. + + Args: + x_dtensor: DTensor with shape ``(*batch, N, *features)``. + idx_dtensor: DTensor with shape ``(*batch, K, W)`` that provides gather + indices into the ``N`` dimension of ``x_dtensor``. + axis: Axis in ``x_dtensor`` corresponding to ``N`` (and to ``K`` in + ``idx_dtensor``). + are_ids_contiguous: This is a heuristic for selecting the underlying + send/recv strategy for performance purpose. Currently only True is supported, + which means that the idx_dtensor maps to a contiguous block of x along (axis) + dimensions for all the shards and for all the leading (batch) dimensions. + When True, the underlying strategy will use the min/max of idx_dtensor to + compute the needed interval, assuming the resulting buffer to be communicated + across the ranks is fully (or approximately so) utilized. In the case the + data inside idx_dtensor doesn't mapped to contiguous blocks, the result will + still be correct by setting are_ids_contiguous=True but the buffer of x chunks + communicated will contain a lot of unused elements, making the sed/recv inefficient. + idx_mask: Optional DTensor with shape ``(*batch, K, W)`` and same device_mesh + and placements as ``idx_dtensor``. Elements with True indicate valid indices, + elements with False indicate invalid indices that should be ignored. + """ + if not are_ids_contiguous: + raise NotImplementedError("DistributedGather currently only supports are_ids_contiguous=True") + + if not isinstance(x_dtensor, DTensor) or not isinstance(idx_dtensor, DTensor): + raise TypeError("x_dtensor and idx_dtensor must be DTensors") + + batch_dims_x = x_dtensor.shape[:axis] + if batch_dims_x != idx_dtensor.shape[:axis]: + raise ValueError(f"Batch dimensions must match: x {batch_dims_x} vs idx {idx_dtensor.shape[:axis]}") + + mesh = x_dtensor.device_mesh + if idx_dtensor.device_mesh != mesh: + raise ValueError("x and idx must be on the same DeviceMesh") + if idx_dtensor.placements != x_dtensor.placements: + raise ValueError("x and idx must have identical placements") + + # Validate idx_mask if provided + if idx_mask is not None: + if not isinstance(idx_mask, DTensor): + raise TypeError("idx_mask must be a DTensor") + if idx_mask.shape != idx_dtensor.shape: + raise ValueError(f"idx_mask shape {idx_mask.shape} must match idx_dtensor shape {idx_dtensor.shape}") + if idx_mask.device_mesh != idx_dtensor.device_mesh: + raise ValueError("idx_mask must have the same device_mesh as idx_dtensor") + if idx_mask.placements != idx_dtensor.placements: + raise ValueError("idx_mask must have the same placements as idx_dtensor") + if idx_mask.dtype != torch.bool: + raise TypeError( + f"idx_mask must have dtype torch.bool, got {idx_mask.dtype}. Use mask.bool() to convert." + ) + + x_placements = x_dtensor.placements + idx_placements = idx_dtensor.placements + + ndim_x = x_dtensor.ndim + if axis < 0: + axis += ndim_x + if axis < 0 or axis >= ndim_x: + raise ValueError(f"axis {axis} out of range for x.ndim={ndim_x}") + + # Identify shard axis on mesh for the gather dimension + mesh_dim_axis = None + for i_mesh_dim, p in enumerate(x_placements): + if isinstance(p, Partial) or isinstance(idx_placements[i_mesh_dim], Partial): + raise ValueError("Partial placements are not supported") + if isinstance(p, Shard): + if p.dim == axis: + mesh_dim_axis = i_mesh_dim + # Enforce even sharding on axis for both x and idx since we require identical device_mesh + # and placements between the two + if x_dtensor.shape[p.dim] % mesh.size(i_mesh_dim) != 0: + raise ValueError( + f"x_dtensor axis {p.dim} size {x_dtensor.shape[p.dim]} not evenly divisible by mesh dim {i_mesh_dim}" + ) + if idx_dtensor.shape[p.dim] % mesh.size(i_mesh_dim) != 0: + raise ValueError( + f"idx_dtensor axis {p.dim} size {idx_dtensor.shape[p.dim]} not evenly divisible by mesh dim {i_mesh_dim}" + ) + + if mesh_dim_axis is None: + raise ValueError(f"x must be sharded along axis {axis}") + + x_local = x_dtensor.to_local() + idx_local = idx_dtensor.to_local() + idx_mask_local = idx_mask.to_local() if idx_mask is not None else None + device = x_local.device + cpu_device = torch.device("cpu") + + # Determine needed interval along axis from local idx + # need interval is required to have a singleton axis of 1 (representing 1-d gathering) + # to be used in get_overlap_from_peers + if idx_local.numel() > 0: + if idx_mask_local is not None: + # Only consider valid indices for interval computation + if idx_mask_local.any(): + valid_idx = idx_local[idx_mask_local] + need_interval = torch.stack(valid_idx.aminmax()).to(dtype=torch.long).unsqueeze(0) # (1,2) + need_interval[:, -1] += 1 + else: + # All indices are masked out + need_interval = torch.tensor([[0, 0]], device=device, dtype=torch.long) # (1,2) + else: + # aminmax return end-inclusive interval of shape (2,) + need_interval = torch.stack(idx_local.aminmax()).to(dtype=torch.long).unsqueeze(0) # (1,2) + need_interval[:, -1] += 1 + else: + need_interval = torch.tensor([[0, 0]], device=device, dtype=torch.long) # (1,2) + need_start = need_interval[0, 0] + need_end = need_interval[0, 1] + need_start_cpu = need_start.to(cpu_device) + + # Owned chunk interval + coord_axis = mesh.get_local_rank(mesh_dim_axis) + chunk_size = x_dtensor.shape[axis] // mesh.size(mesh_dim_axis) + own_start = torch.tensor(coord_axis * chunk_size, device=cpu_device, dtype=torch.long) + own_end = own_start + chunk_size + own_interval = torch.stack([own_start, own_end]).unsqueeze(0) # (1,2) + + # All-gather need intervals along sharded mesh dim (metadata only) + group_axis = mesh.get_group(mesh_dim_axis) + need_range = [torch.zeros_like(need_interval) for _ in range(mesh.size(mesh_dim_axis))] + dist.all_gather(need_range, need_interval, group=group_axis) + need_range = torch.stack(need_range) # (size_group,1,2) + need_range_cpu = need_range.cpu() + + ranks_global_on_mesh = mesh.mesh + my_coords = mesh.get_coordinate() + index_list_submesh = [] + for dim in range(ranks_global_on_mesh.ndim): + if dim == mesh_dim_axis: + index_list_submesh.append(slice(None)) + else: + index_list_submesh.append(torch.tensor(my_coords[dim], device=cpu_device)) + ranks_global_on_submesh = ranks_global_on_mesh[tuple(index_list_submesh)] # (size_group,) + + size_group = mesh.size(mesh_dim_axis) + + # RECEIVE PLAN + if need_start >= need_end: + needed_chunks: List[Dict[str, torch.Tensor | int]] = [] + else: + # peers own intervals + start_peers_own = torch.arange(size_group, device=cpu_device, dtype=torch.long) * chunk_size + end_peers_own = start_peers_own + chunk_size + interval_peers_own = torch.stack([start_peers_own, end_peers_own], dim=-1).unsqueeze(1) # (size_group,1,2) + + needed_chunks = get_overlap_from_peers( + ranks_global_on_submesh, interval_peers_own, need_interval.to(cpu_device) + ) + + ops = [] + recv_bufs = {} + recv_metadata_for_bwd = [] + + for item in needed_chunks: + peer = item["peer"] + interval = item["interval"] # (1,2) + start_global = interval[0, 0] + length = interval[0, 1] - interval[0, 0] + + shape = list(x_local.shape) + shape[axis] = length.item() + buf = torch.empty(shape, dtype=x_local.dtype, device=device) + + if peer == dist.get_rank(): + start_local = start_global - own_start + buf.copy_(x_local.narrow(axis, start_local.item(), length.item())) + recv_bufs[peer] = buf + recv_metadata_for_bwd.append((peer, interval, shape)) + else: + ops.append(dist.P2POp(dist.irecv, buf, peer)) + recv_bufs[peer] = buf + recv_metadata_for_bwd.append((peer, interval, shape)) + + # SEND PLAN + send_metadata_for_bwd = [] + send_chunks = get_overlap_from_peers( + ranks_global_on_submesh, + need_range_cpu.view(size_group, 1, 2), + own_interval.view(1, 1, 2), + ) + + my_rank = dist.get_rank() + for item in send_chunks: + peer = item["peer"] + if peer == my_rank: + continue + interval = item["interval"] + start_global = interval[0, 0] + length = interval[0, 1] - interval[0, 0] + start_local = start_global - own_start + chunk = x_local.narrow(axis, start_local.item(), length.item()).contiguous() + ops.append(dist.P2POp(dist.isend, chunk, peer)) + send_metadata_for_bwd.append( + ( + peer, + start_local.to(cpu_device, dtype=torch.long), + length.to(cpu_device, dtype=torch.long), + ) + ) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # Assemble buffer + if need_start >= need_end: + buffer_shape = list(x_local.shape) + buffer_shape[axis] = 0 + x_buffer = torch.empty(buffer_shape, dtype=x_local.dtype, device=device) + else: + buffer_shape = list(x_local.shape) + buffer_shape[axis] = (need_end - need_start).item() + x_buffer = torch.zeros(buffer_shape, dtype=x_local.dtype, device=device) + + for item in needed_chunks: + peer = item["peer"] + interval = item["interval"] + buf = recv_bufs.get(peer) + if buf is None: + raise RuntimeError(f"Missing recv buffer for peer {peer}") + + start_global = interval[0, 0] + length = interval[0, 1] - interval[0, 0] + start_local_buf = start_global - need_start_cpu + target = x_buffer.narrow(axis, start_local_buf.item(), length.item()) + target.copy_(buf) + + # Adjust idx into buffer coords + idx_local_buffer = idx_local - need_start.item() + + # Local computation using linearized gather over (axis, feature) block + shape_trailing = x_buffer.shape[axis + 1 :] + shape_trailing_flat = torch.Size(shape_trailing).numel() + + shape_leading = x_buffer.shape[:axis] + shape_leading_flat = torch.Size(shape_leading).numel() + L = x_buffer.shape[axis] + K = idx_local_buffer.shape[-2] + W = idx_local_buffer.shape[-1] + + # Handle edge case where buffer is empty (all indices masked out) + if L == 0: + out_shape = list(shape_leading) + [K, W] + list(shape_trailing) + out_local = torch.zeros(out_shape, dtype=x_local.dtype, device=device) + else: + # For masked indices, clamp to valid buffer range to avoid index errors + if idx_mask_local is not None: + idx_local_buffer = torch.where(idx_mask_local, idx_local_buffer, torch.zeros_like(idx_local_buffer)) + + x_flat = x_buffer.reshape(shape_leading_flat, L, shape_trailing_flat) + idx_flat = idx_local_buffer.reshape(shape_leading_flat, K, W) + + # Advanced indexing: + # dim 0: arange(B) reshaped to (B, 1, 1) to broadcast against (B, K, W) + batch_idx = torch.arange(shape_leading_flat, device=device).reshape(shape_leading_flat, 1, 1) + + # This performs the gather + # x_flat[ (B,1,1), (B,K,W), : ] -> (B, K, W, F) + out_flat = x_flat[batch_idx, idx_flat, :] + + # Zero out invalid positions + if idx_mask_local is not None: + mask_flat = idx_mask_local.reshape(shape_leading_flat, K, W, 1).to(out_flat.dtype) + out_flat = out_flat * mask_flat + + if shape_trailing: + out_local = out_flat.reshape(*shape_leading, K, W, *shape_trailing) + else: + out_local = out_flat.reshape(*shape_leading, K, W) + + out_global_shape = list(idx_dtensor.shape) + list(shape_trailing) + final_global_shape = tuple(out_global_shape) + + strides_out = update_exhaustive_strides(out_local.shape, out_local.stride(), final_global_shape) + + out_dtensor = DTensor.from_local( + out_local, idx_dtensor.device_mesh, idx_dtensor.placements, shape=final_global_shape, stride=strides_out + ) + + if idx_mask_local is not None: + ctx.save_for_backward(idx_local, idx_mask_local) + else: + ctx.save_for_backward(idx_local) + ctx.has_mask = idx_mask_local is not None + ctx.comm_meta = { + "recv_metadata_for_bwd": recv_metadata_for_bwd, + "send_metadata_for_bwd": send_metadata_for_bwd, + "x_local_shape": x_local.shape, + "x_buffer_shape": x_buffer.shape, + "need_interval": need_interval, + "axis": axis, + "x_placements": x_placements, + "x_global_shape": x_dtensor.shape, + "output_placements": out_dtensor.placements, + "own_interval": own_interval, + "device_mesh_output": out_dtensor.device_mesh, + } + + return out_dtensor + + @staticmethod + def backward(ctx, grad_output: DTensor) -> Tuple[DTensor, None, None, None, None]: + if ctx.has_mask: + idx_local, idx_mask_local = ctx.saved_tensors + else: + (idx_local,) = ctx.saved_tensors + idx_mask_local = None + meta = ctx.comm_meta + recv_meta = meta["recv_metadata_for_bwd"] + send_meta = meta["send_metadata_for_bwd"] + x_local_shape = meta["x_local_shape"] + x_buffer_shape = meta["x_buffer_shape"] + need_interval = meta["need_interval"] + need_start = need_interval[0, 0] + axis = meta["axis"] + x_placements = meta["x_placements"] + output_placements = meta["output_placements"] + x_global_shape = meta["x_global_shape"] + own_interval = meta["own_interval"] + device_mesh_output = meta["device_mesh_output"] + + if device_mesh_output != grad_output.device_mesh: + raise ValueError( + f"grad_output device_mesh mismatch: expected {device_mesh_output}, got {grad_output.device_mesh}" + ) + if output_placements != grad_output.placements: + raise ValueError( + f"grad_output placements mismatch: expected {output_placements}, got {grad_output.placements}" + ) + + grad_local = grad_output.to_local().contiguous() + + local_idx = idx_local - need_start.item() + shape_trailing = x_buffer_shape[axis + 1 :] + shape_trailing_flat = torch.Size(shape_trailing).numel() + + shape_leading = x_buffer_shape[:axis] + shape_leading_flat = torch.Size(shape_leading).numel() + L = x_buffer_shape[axis] + K = local_idx.shape[-2] + W = local_idx.shape[-1] + + grad_flat = grad_local.reshape(shape_leading_flat, K * W, shape_trailing_flat) + + # Handle edge case where buffer is empty (all indices masked out) + if L == 0: + grad_x_buffer = torch.zeros( + *shape_leading, 0, *shape_trailing, dtype=grad_local.dtype, device=grad_local.device + ) + else: + # For masked indices, clamp to valid buffer range and zero out gradients + if idx_mask_local is not None: + local_idx = torch.where(idx_mask_local, local_idx, torch.zeros_like(local_idx)) + mask_flat = idx_mask_local.reshape(shape_leading_flat, K * W, 1).to(grad_flat.dtype) + grad_flat = grad_flat * mask_flat + + idx_flat = local_idx.reshape(shape_leading_flat, K * W, 1).expand( + shape_leading_flat, K * W, shape_trailing_flat + ) + + grad_buf = torch.zeros( + shape_leading_flat, L, shape_trailing_flat, dtype=grad_local.dtype, device=grad_local.device + ) + grad_buf.scatter_add_(1, idx_flat, grad_flat) + + grad_x_buffer = grad_buf.reshape(*shape_leading, L, *shape_trailing) + + ops = [] + grad_x_local = torch.zeros(x_local_shape, dtype=grad_local.dtype, device=grad_local.device) + + # Backward send (reverse of recv) + for peer, interval, _shape in recv_meta: + start_local_buf = interval[0, 0] - need_start.to(interval.device) + length = interval[0, 1] - interval[0, 0] + grad_chunk = grad_x_buffer.narrow(axis, start_local_buf.item(), length.item()) + + if peer == dist.get_rank(): + # self accumulate + start_local = interval[0, 0] - own_interval[0, 0] + target = grad_x_local.narrow(axis, start_local.item(), length.item()) + target.add_(grad_chunk) + else: + ops.append(dist.P2POp(dist.isend, grad_chunk.contiguous(), peer)) + + # Backward recv (reverse of send) + bwd_recv_bufs = [] + for peer, start_local, length in send_meta: + shape = list(x_local_shape) + shape[axis] = length.item() + buf = torch.empty(shape, dtype=grad_local.dtype, device=grad_local.device) + ops.append(dist.P2POp(dist.irecv, buf, peer)) + bwd_recv_bufs.append((buf, start_local, length)) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + for buf, start_local, length in bwd_recv_bufs: + target = grad_x_local.narrow(axis, start_local.item(), length.item()) + target.add_(buf) + + grad_x_dtensor = DTensor.from_local( + grad_x_local, + grad_output.device_mesh, + x_placements, + shape=x_global_shape, + stride=update_exhaustive_strides(grad_x_local.shape, grad_x_local.stride(), x_global_shape), + ) + + return grad_x_dtensor, None, None, None, None + + +def distributed_gather( + x_dtensor: DTensor, + idx_dtensor: DTensor, + axis: int = 1, + are_ids_contiguous: bool = False, + idx_mask: DTensor | None = None, +) -> DTensor: + """Distributed 1D gather. + + Args: + x_dtensor: DTensor with shape ``(*batch, N, *features)``. + idx_dtensor: DTensor with shape ``(*batch, K, W)`` that provides gather + indices into the ``N`` dimension of ``x_dtensor``. + axis: Axis in ``x_dtensor`` corresponding to ``N`` (and to ``K`` in + ``idx_dtensor``). + are_ids_contiguous: This is a heuristic for selecting the underlying + send/recv strategy for performance purpose. Currently only True is supported, + which means that the idx_dtensor maps to a contiguous block of x along (axis) + dimensions for all the shards and for all the leading (batch) dimensions. + When True, the underlying strategy will use the min/max of idx_dtensor to + compute the needed interval, assuming the resulting buffer to be communicated + across the ranks is fully (or approximately so) utilized. In the case the + data inside idx_dtensor doesn't mapped to contiguous blocks, the result will + still be correct by setting are_ids_contiguous=True but the buffer of x chunks + communicated will contain a lot of unused elements, making the sed/recv inefficient. + idx_mask: Optional DTensor with shape ``(*batch, K, W)`` and same device_mesh + and placements as ``idx_dtensor``. Elements with True indicate valid indices, + elements with False indicate invalid indices that should be ignored. + """ + return DistributedGather.apply(x_dtensor, idx_dtensor, axis, are_ids_contiguous, idx_mask) diff --git a/src/boltz/distributed/model/layers/layernorm.py b/src/boltz/distributed/model/layers/layernorm.py new file mode 100644 index 000000000..81e2647bd --- /dev/null +++ b/src/boltz/distributed/model/layers/layernorm.py @@ -0,0 +1,577 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard, distribute_tensor + +_shape_t = Union[int, list[int], torch.Size] + + +class _ContextParallelLayerNormImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + normalized_shape: list[int], + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, + reduce_group: dist.ProcessGroup, + ) -> Tensor: + """Forward pass of the layer normalization. + + Args: + ctx: context + x (Tensor): input tensor + normalized_shape (list[int]): shape of the input tensor + weight (Optional[Tensor]): weight tensor + bias (Optional[Tensor]): bias tensor + eps (float): a value added to the denominator for numerical stability + reduce_group (dist.ProcessGroup): process group for all-reduce + + Returns: + output tensor (Tensor) + """ + weight_needs_grad = weight is not None and weight.requires_grad + bias_needs_grad = bias is not None and bias.requires_grad + # For unknown reasons, using ctx.needs_input_grad in the forward pass can occasionally + # cause NCCL hanging. ctx.need_input_grad should not be accessed during the forward pass + # according to this discussion on pytorch forum: + # https://discuss.pytorch.org/t/is-there-a-diffrence-between-ctx-needs-input-grad-behaviour-vs-input-tensor-requires-grad/195063/2 + if x.requires_grad or weight_needs_grad or bias_needs_grad: + ctx.reduce_group = reduce_group + ctx.eps = eps + ctx.normalized_shape = normalized_shape + + if not x.requires_grad: + weight = None + + ctx.save_for_backward(x, weight) + + return F.layer_norm(x, normalized_shape, weight, bias, eps) + + @staticmethod + def backward( + ctx, grad_output: Tensor + ) -> tuple[Optional[Tensor], None, Optional[Tensor], Optional[Tensor], None, None]: + """Backward pass of the layer normalization. + + Although the output between tokens is independent, the backward pass on the weight and bias tensors involves a summation over all tokens, and thus requires all-reduce due to context parallelism. + + Args: + ctx: context + grad_output (Tensor): gradient of the output tensor + + Returns: + gradient for input, weight, bias, eps, and reduce_group (tuple[Optional[Tensor], None, Optional[Tensor], Optional[Tensor], None, None]) + """ + x, weight = ctx.saved_tensors + eps = ctx.eps + normalized_shape = ctx.normalized_shape + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[2]: + dims = tuple(-(i + 1) for i in range(len(normalized_shape))) + mean = x.mean(dim=dims, keepdim=True) + var = x.var(dim=dims, unbiased=False, keepdim=True) + x_norm = (x - mean) / torch.sqrt(var + eps) + + if ctx.needs_input_grad[0]: + if weight is not None: + dy = grad_output * weight.view(*([1] * (grad_output.ndim - len(normalized_shape))), *weight.shape) + else: + dy = grad_output + + dims = tuple(-(i + 1) for i in range(len(normalized_shape))) + dy_mean = dy.mean(dim=dims, keepdim=True) + dy_x_norm_mean = (dy * x_norm).mean(dim=dims, keepdim=True) + grad_input = (dy - dy_mean - x_norm * dy_x_norm_mean) / torch.sqrt(var + eps) + else: + grad_input = None + + if ctx.needs_input_grad[2]: + reduce_dims = list(range(grad_output.ndim - len(normalized_shape))) + grad_weight = (grad_output * x_norm).sum(dim=reduce_dims) + grad_weight = grad_weight.contiguous() + grad_weight_work = dist.all_reduce(grad_weight, op=dist.ReduceOp.SUM, group=ctx.reduce_group, async_op=True) + else: + grad_weight = None + + if ctx.needs_input_grad[3]: + reduce_dims = list(range(grad_output.ndim - len(normalized_shape))) + grad_bias = grad_output.sum(dim=reduce_dims) + grad_bias_work = dist.all_reduce(grad_bias, op=dist.ReduceOp.SUM, group=ctx.reduce_group, async_op=True) + else: + grad_bias = None + + # collect all_reduce results at the end + if ctx.needs_input_grad[2]: + grad_weight_work.wait() + if ctx.needs_input_grad[3]: + grad_bias_work.wait() + + return grad_input, None, grad_weight, grad_bias, None, None + + +class ContextParallelLayerNorm(nn.LayerNorm): + def __init__( + self, + normalized_shape: _shape_t, + reduce_group: dist.ProcessGroup, + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + ): + """Context parallel layer normalization, a wrapper around nn.LayerNorm that supports distributed training. + + Although the output between tokens is independent, the backward pass on the weight and bias tensors involves a summation over all tokens, and thus requires all-reduce due to context parallelism. This means we need a dedicated ContextParallelLayerNorm class for training. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + reduce_group (dist.ProcessGroup): The process group to use for gradient all-reduce. + eps (float): a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine (bool): a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + bias (bool): If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`elementwise_affine` is ``True``). Default: ``True``. + device (torch.device, optional): device on which the module is allocated. Defaults to None. + dtype (torch.dtype, optional): dtype of the module. Defaults to None. + """ + super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) + assert reduce_group is not None, "reduce_group must be provided" + self.reduce_group = reduce_group + + if isinstance(normalized_shape, int): + self._normalized_shape_list = [normalized_shape] + else: + self._normalized_shape_list = list(normalized_shape) + + def forward(self, input: Tensor) -> Tensor: + return _ContextParallelLayerNormImpl.apply( + input, self._normalized_shape_list, self.weight, self.bias, self.eps, self.reduce_group + ) + + +def get_cp_layernorm( + normalized_shape: _shape_t, + reduce_group: dist.ProcessGroup | None = None, + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> nn.LayerNorm | ContextParallelLayerNorm: + """Get a layer normalization module that is optimized for distributed training. + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + reduce_group (dist.ProcessGroup, optional): The process group to use for gradient all-reduce. + Defaults to None. + eps (float, optional): a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine (bool, optional): a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`elementwise_affine` is ``True``). Default: ``True``. + device (torch.device, optional): device on which the module is allocated. Defaults to None. + dtype (torch.dtype, optional): dtype of the module. Defaults to None. + + Returns: + nn.LayerNorm | ContextParallelLayerNorm + """ + if reduce_group is None: + return nn.LayerNorm( + normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + bias=bias, + device=device, + dtype=dtype, + ) + else: + return ContextParallelLayerNorm( + normalized_shape, + reduce_group=reduce_group, + eps=eps, + elementwise_affine=elementwise_affine, + bias=bias, + ) + + +class _LayerNormParamsReplicatedImpl(torch.autograd.Function): + """ + A custom implementation of LayerNorm with replicated parameters for distributed training. + + This class provides a forward and backward implementation of LayerNorm, ensuring compatibility + with distributed tensor placements and device meshes. It supports replicated and sharded + placements for input tensors and replicated placements for weight and bias tensors. + + NOTE: by default, avg reduce over the Replicate placements of the weight and bias gradients + is performed. This is to ensure identical parameter updates across all ranks and avoid + gradual divergence during training. This can be disabled by setting + avg_over_replicate_param_grad to False. + + Methods: + forward(ctx, x, normalized_shape, weight, bias, eps): + Computes the forward pass of LayerNorm. + + backward(ctx, grad_output): + Computes the backward pass of LayerNorm. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + x: DTensor, + normalized_shape: list[int], + weight: Optional[DTensor], + bias: Optional[DTensor], + eps: float, + cast_params_dtype_to_x: Optional[bool] = False, + avg_over_replicate_param_grad: bool = True, + ) -> DTensor: + """ + Forward pass of LayerNorm with replicated parameters. + + Args: + ctx: Context for saving tensors for backward computation. + x (DTensor): Input tensor. + normalized_shape (list[int]): Shape of the input tensor to normalize. + weight (Optional[DTensor]): Weight tensor for affine transformation. + bias (Optional[DTensor]): Bias tensor for affine transformation. + eps (float): A small value added for numerical stability. + cast_params_dtype_to_x (Optional[bool]): whether to cast the dtype of + the weights and bias to the dtype of the input tensor + avg_over_replicate_param_grad (bool): Whether to perform avg reduce over the + Replicate placements of the weight and bias gradients. For example, + if the input DTensor x.placements = (Shard(0), Replicate()), this layer's + parameters' gradients.placements = (Partial("sum"), Replicate()) if + self._avg_over_replicate_param_grad is False; otherwise, it will be + (Partial("sum"), Partial("avg")). The motivation is to ensure identical + parameter updates across all ranks and avoid gradual divergence during + training. + + Returns: + DTensor: The normalized output tensor. + """ + if not isinstance(x, DTensor): + dtensor_instance = x + raise TypeError( + ", ".join( + [ + f"DTensor instance '{dtensor_instance}' should have type {DTensor}", + f"but instead has type {type(dtensor_instance)}.", + ] + ) + ) + device_mesh = x.device_mesh + ndim_device_mesh = device_mesh.ndim + all_replicate_placements = tuple([Replicate()] * ndim_device_mesh) + if weight is not None: + if not isinstance(weight, DTensor): + dtensor_instance = weight + raise TypeError( + ", ".join( + [ + f"DTensor instance '{dtensor_instance}' should have type {DTensor}", + f"but instead has type {type(dtensor_instance)}.", + ] + ) + ) + if weight.device_mesh != device_mesh: + raise ValueError("weight and x must be on the same device mesh") + if weight.placements != all_replicate_placements: + raise ValueError("weight must be replicated on all device mesh dimensions") + if bias is not None: + if not isinstance(bias, DTensor): + dtensor_instance = bias + raise TypeError( + ", ".join( + [ + f"DTensor instance '{dtensor_instance}' should have type {DTensor}", + f"but instead has type {type(dtensor_instance)}.", + ] + ) + ) + if bias.device_mesh != device_mesh: + raise ValueError("bias and x must be on the same device mesh") + if bias.placements != all_replicate_placements: + raise ValueError("bias must be replicated on all device mesh dimensions") + if weight is not None or bias is not None: + if avg_over_replicate_param_grad: + placements_grad_params = [Partial("avg")] * ndim_device_mesh + else: + placements_grad_params = list(weight.placements) if weight is not None else None + else: + placements_grad_params = None + n_dim_norm = len(normalized_shape) + for i_dim_device_mesh, p in enumerate(x.placements): + if isinstance(p, Partial): + # partial reduction along any input dimension requires complicated backward pass + raise ValueError("Partial reduction along any input dimension is not supported") + if isinstance(p, Shard): + if p.dim >= x.ndim - n_dim_norm: + # the normalized dimensions must not be sharded by the device mesh + raise ValueError("LayerNorm's normalizing dimensions must not be sharded by the device mesh") + if x.shape[p.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {p.dim} of size {x.shape[p.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size " + f"{device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + # the only supported placement for the input is Shard, which corresponding + # to the backward's grad partial sum. Otherwise, we can only support Replicate + # placements for other device mesh dimensions. Also, by using the Partial("sum") + # placement on the params, the all_reduce is postponed for the params' gradients + # until needed + if weight is not None or bias is not None: + placements_grad_params[i_dim_device_mesh] = Partial("sum") + elif not isinstance(p, Replicate): + raise ValueError(f"Unsupported x's placements along {i_dim_device_mesh} axis of the device mesh: {p}") + ctx.device_mesh = device_mesh + # will use x.placements for the x.grad in the backward pass, i.e., this function + # enforces consistent placements for the input and its gradient + ctx.placements_x = x.placements + ctx.placements_grad_params = placements_grad_params + + # Save weight and bias shapes and strides for backward pass + if weight is not None: + ctx.weight_shape = weight.shape + ctx.weight_stride = weight.stride() + if bias is not None: + ctx.bias_shape = bias.shape + ctx.bias_stride = bias.stride() + + weight_needs_grad = weight is not None and weight.requires_grad + bias_needs_grad = bias is not None and bias.requires_grad + # IMPORTANT: no modification on *_local for the rest of the code + x_local = x.to_local() + if weight is None: + weight_local = None + else: + weight_local = weight.to_local() + if cast_params_dtype_to_x: + weight_local = weight_local.to(x.dtype) + if bias is None: + bias_local = None + else: + bias_local = bias.to_local() + if cast_params_dtype_to_x: + bias_local = bias_local.to(x.dtype) + # For unknown reasons, using ctx.needs_input_grad in the forward pass can occasionally + # cause NCCL hanging. ctx.need_input_grad should not be accessed during the forward pass + # according to this discussion on pytorch forum: + # https://discuss.pytorch.org/t/is-there-a-diffrence-between-ctx-needs-input-grad-behaviour-vs-input-tensor-requires-grad/195063/2 + if x.requires_grad or weight_needs_grad or bias_needs_grad: + ctx.eps = eps + ctx.normalized_shape = normalized_shape + + if not x.requires_grad: + weight = None + + ctx.save_for_backward(x_local, weight_local) + + output_local = F.layer_norm(x_local, normalized_shape, weight_local, bias_local, eps) + # LayerNorm does not change input's shape + output = DTensor.from_local( + output_local, + device_mesh=device_mesh, + placements=x.placements, + shape=x.shape, + stride=x.stride(), + ) + return output + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx, grad_output: DTensor + ) -> tuple[Optional[DTensor], None, Optional[DTensor], Optional[DTensor], None, None]: + """ + Backward pass of LayerNorm with replicated parameters. + + Args: + ctx: Context containing saved tensors and attributes from the forward pass. + grad_output (DTensor): Gradient of the output tensor. + + Returns: + tuple: Gradients for input, weight, bias, and other parameters. + """ + x_local, weight_local = ctx.saved_tensors + eps = ctx.eps + normalized_shape = ctx.normalized_shape + + # IMPORTANT: no modification on *_local for the rest of the code + grad_output_local = grad_output.to_local() + + ids_dim_norm = tuple(-(i + 1) for i in range(len(normalized_shape))) + if ctx.needs_input_grad[0] or ctx.needs_input_grad[2]: + mean_local = x_local.mean(dim=ids_dim_norm, keepdim=True) + var_local = x_local.var(dim=ids_dim_norm, unbiased=False, keepdim=True) + x_norm_local = (x_local - mean_local) / torch.sqrt(var_local + eps) + + if ctx.needs_input_grad[0]: + if weight_local is not None: + dy_local = grad_output_local * weight_local.view( + *([1] * (grad_output_local.ndim - len(normalized_shape))), *weight_local.shape + ) + else: + dy_local = grad_output_local + + dy_mean_local = dy_local.mean(dim=ids_dim_norm, keepdim=True) + dy_x_norm_mean_local = (dy_local * x_norm_local).mean(dim=ids_dim_norm, keepdim=True) + grad_input_local = (dy_local - dy_mean_local - x_norm_local * dy_x_norm_mean_local) / torch.sqrt( + var_local + eps + ) + # LayerNorm does not change input's shape in both forward and backward passes + grad_input = DTensor.from_local( + grad_input_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements_x, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + else: + grad_input = None + + reduce_dims = list(range(grad_output_local.ndim - len(normalized_shape))) + if ctx.needs_input_grad[2]: + grad_weight_local = (grad_output_local * x_norm_local).sum(dim=reduce_dims) + # all-replicate weight implies identical shape and stride across all ranks + grad_weight = DTensor.from_local( + grad_weight_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements_grad_params, + shape=ctx.weight_shape, + stride=ctx.weight_stride, + ) + else: + grad_weight = None + + if ctx.needs_input_grad[3]: + grad_bias_local = grad_output_local.sum(dim=reduce_dims) + # all-replicate weight implies identical shape and stride across all ranks + grad_bias = DTensor.from_local( + grad_bias_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements_grad_params, + shape=ctx.bias_shape, + stride=ctx.bias_stride, + ) + else: + grad_bias = None + + return grad_input, None, grad_weight, grad_bias, None, None, None + + +class LayerNormParamsReplicated(nn.Module): + """ + A LayerNorm module with replicated parameters for distributed training. + + This module wraps around `_LayerNormParamsReplicatedImpl` to provide a user-friendly interface + for LayerNorm operations using the DTensor API. It supports distributed training with replicated + and sharded placements for input tensors and replicated placements for weight and bias tensors. + + NOTE: by default, avg reduce over the Replicate placements of the weight and bias gradients + is performed. This is to ensure identical parameter updates across all ranks and avoid + gradual divergence during training. This can be disabled by setting + avg_over_replicate_param_grad to False. + + Args: + layer_local (nn.LayerNorm): An already-initialized nn.LayerNorm instance. + device_mesh (DeviceMesh): The device mesh for distributed training. + avg_over_replicate_param_grad (bool): Whether to perform avg reduce over the + Replicate placements of the weight and bias gradients. For example, + if the input DTensor x.placements = (Shard(0), Replicate()), this layer's + parameters' gradients.placements = (Partial("sum"), Replicate()) if + self._avg_over_replicate_param_grad is False; otherwise, it will be + (Partial("sum"), Partial("avg")). The motivation is to ensure identical + parameter updates across all ranks and avoid gradual divergence during + training. + """ + + def __init__( + self, layer_local: nn.LayerNorm, device_mesh: DeviceMesh, avg_over_replicate_param_grad: bool = True + ) -> None: + if not isinstance(layer_local, nn.LayerNorm): + raise TypeError("layer_local is not an instance of nn.LayerNorm") + if layer_local.weight is not None and layer_local.weight.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.weight and device_mesh are not on the same device type: " + f"{layer_local.weight.device.type} != {device_mesh.device_type}" + ) + if layer_local.bias is not None and layer_local.bias.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.bias and device_mesh are not on the same device type: " + f"{layer_local.bias.device.type} != {device_mesh.device_type}" + ) + + super().__init__() + self.normalized_shape = layer_local.normalized_shape + self.eps = layer_local.eps + self.device_mesh = device_mesh + self.elementwise_affine = layer_local.elementwise_affine + self._avg_over_replicate_param_grad = avg_over_replicate_param_grad + + all_replicate_placements = [Replicate()] * device_mesh.ndim + + if layer_local.weight is None: + self.register_parameter("weight", None) + else: + self.weight = nn.Parameter( + distribute_tensor(layer_local.weight.data, device_mesh, all_replicate_placements) + ) + if layer_local.bias is None: + self.register_parameter("bias", None) + else: + self.bias = nn.Parameter(distribute_tensor(layer_local.bias.data, device_mesh, all_replicate_placements)) + + def forward(self, x: DTensor) -> DTensor: + """ + Forward pass of LayerNormParamsReplicated. + + Args: + x (DTensor): Input tensor. + + Returns: + DTensor: The normalized output tensor. + """ + return _LayerNormParamsReplicatedImpl.apply( + x, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + True, + self._avg_over_replicate_param_grad, + ) diff --git a/src/boltz/distributed/model/layers/linear.py b/src/boltz/distributed/model/layers/linear.py new file mode 100644 index 000000000..e828988f2 --- /dev/null +++ b/src/boltz/distributed/model/layers/linear.py @@ -0,0 +1,500 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor + +from boltz.distributed.utils import update_exhaustive_strides + + +class _ContextParallelLinearImpl(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, weight: Tensor, bias: Optional[Tensor], reduce_group: dist.ProcessGroup) -> Tensor: + """Forward pass of the linear layer. + + Args: + ctx: context + x (Tensor): input tensor + weight (Tensor): weight tensor + bias (Optional[Tensor]): bias tensor + reduce_group (dist.ProcessGroup): process group for all-reduce + + Returns: + output tensor (Tensor) + """ + # For unknown reasons, using ctx.needs_input_grad in the forward pass can occasionally + # cause NCCL hanging. ctx.need_input_grad should not be accessed during the forward pass + # according to this discussion on pytorch forum: + # https://discuss.pytorch.org/t/is-there-a-diffrence-between-ctx-needs-input-grad-behaviour-vs-input-tensor-requires-grad/195063/2 + if x.requires_grad or weight.requires_grad or (bias is not None and bias.requires_grad): + ctx.reduce_group = reduce_group + ctx.save_for_backward( + x if weight.requires_grad else None, + weight if x.requires_grad else None, + ) + return F.linear(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output: Tensor) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None]: + """Backward pass of the linear layer. + + Although the output between tokens is independent, the backward pass on the weight and bias tensors involves a summation over all tokens, and thus requires all-reduce due to context parallelism. + + Args: + ctx: context + grad_output (Tensor): gradient of the output tensor + + Returns: + gradient for input, weight, bias, and reduce_group (tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None]) + """ + x, weight = ctx.saved_tensors + + if ctx.needs_input_grad[1]: + dw = torch.einsum("...i,...o->io", grad_output, x) + dw = dw.contiguous() + dw_work = dist.all_reduce(dw, op=dist.ReduceOp.SUM, group=ctx.reduce_group, async_op=True) + else: + dw = None + dw_work = None + + if ctx.needs_input_grad[2]: + dims = list(range(grad_output.ndim - 1)) # aggregate over all but the last dimension + db = grad_output.sum(dim=dims) + db = db.contiguous() + db_work = dist.all_reduce(db, op=dist.ReduceOp.SUM, group=ctx.reduce_group, async_op=True) + else: + db = None + db_work = None + + if ctx.needs_input_grad[0]: + grad_input = torch.einsum("...i,io->...o", grad_output, weight) + else: + grad_input = None + + # collect all work + if dw_work is not None: + dw_work.wait() + if db_work is not None: + db_work.wait() + + return grad_input, dw, db, None + + +class _LinearParamsReplicatedImpl(torch.autograd.Function): + """ + Custom autograd Function implementation for distributed linear operation with replicated parameters. + + The main purpose of this implementation is to avoid the unnecessary overhead seen the the + equivalent distribute_module-wrapped linear layer, where the output tensors have nonsensical Replicate + placements along device mesh dimensions that are not intended + + This implementation handles the forward and backward passes for a distributed linear layer where + parameters (weight and bias) are replicated across the device mesh. The input tensor can have + various placement strategies. + + NOTE: by default, avg reduce over the Replicate placements of the weight and bias gradients + is performed. This is to ensure identical parameter updates across all ranks and avoid + gradual divergence during training. This can be disabled by setting + avg_over_replicate_param_grad to False. + + Assumptions and requirements: + (see the respective docstring for forward and backward) + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + x: DTensor, + weight: DTensor, + bias: Optional[DTensor], + cast_params_dtype_to_x: bool = False, + avg_over_replicate_param_grad: bool = True, + ) -> DTensor: + """ + Forward pass for the distributed linear operation. + + Assumptions and requirements: + 1. Parameters (weight and bias) must be replicated on all device mesh dimensions + 2. Input tensor and parameters must be on the same device mesh + 3. Feature/hidden dimension of the input must not be sharded across the device mesh + 4. Partial reduction along any input dimension is not supported + 5. Input and outputs must be on the same device mesh with the same placements + + Args: + ctx: Context object to store information for backward pass + x: Input tensor with arbitrary placement strategy + weight: Weight tensor (must be replicated across all device mesh dimensions) + bias: Optional bias tensor (must be replicated if provided) + cast_params_dtype_to_x: whether to cast the dtype of the weight and bias + to the dtype of the input tensor + avg_over_replicate_param_grad: whether to perform avg reduce over the + Replicate placements of the weight and bias gradients. For example, + if the input DTensor x.placements = (Shard(0), Replicate()), this layer's + parameters' gradients.placements = (Partial("sum"), Replicate()) if + self._avg_over_replicate_param_grad is False; otherwise, it will be + (Partial("sum"), Partial("avg")). The motivation is to ensure identical + parameter updates across all ranks and avoid gradual divergence during + training. + + Returns: + Output tensor with same placement strategy as input + + Raises: + ValueError: If any of the placement requirements are violated + """ + device_mesh = x.device_mesh + if weight.device_mesh != device_mesh: + raise ValueError("weight and x must be on the same device mesh") + if bias is not None and bias.device_mesh != device_mesh: + raise ValueError("bias and x must be on the same device mesh") + ndim_device_mesh = device_mesh.ndim + all_replicate_placements = tuple([Replicate()] * ndim_device_mesh) + if weight.placements != all_replicate_placements: + raise ValueError("weight must be replicated on all device mesh dimensions") + if bias is not None and bias.placements != all_replicate_placements: + raise ValueError("bias must be replicated on all device mesh dimensions") + if avg_over_replicate_param_grad: + placements_grad_params = [Partial("avg")] * ndim_device_mesh + else: + # all-replicate placements + placements_grad_params = list(weight.placements) + for i_dim_device_mesh, p in enumerate(x.placements): + if isinstance(p, Partial): + # partial reduction along any input dimension requires complicated backward pass + raise ValueError("Partial reduction along any input dimension is not supported") + if isinstance(p, Shard): + if p.dim == x.ndim - 1: + # the feature or hidden dimension must not be a part of the device mesh + raise ValueError("feature or hidden dimension must not be a part of the device mesh") + if x.shape[p.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {p.dim} of size {x.shape[p.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size " + f"{device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + # the only supported placement for the input is Shard, which corresponding + # to the backward's grad partial sum. Otherwise, we can only support Replicate + # placements for other device mesh dimensions. Also, by using the Partial("sum") + # placement on the params, the all_reduce is postponed for the params' gradients + # until needed + placements_grad_params[i_dim_device_mesh] = Partial("sum") + elif not isinstance(p, Replicate): + raise ValueError(f"Unsupported x's placements along {i_dim_device_mesh} axis of the device mesh: {p}") + ctx.device_mesh = device_mesh + # will use x.placements for the x.grad in the backward pass, i.e., this function + # enforces consistent placements for the input and its gradient + ctx.placements_x = x.placements + ctx.placements_grad_params = placements_grad_params + ctx.shape_input = x.shape + ctx.stride_input = x.stride() + ctx.weight_shape = weight.shape + ctx.weight_stride = weight.stride() + ctx.dtype_input = x.dtype + ctx.dtype_weight = weight.dtype + if bias is not None: + ctx.bias_shape = bias.shape + ctx.bias_stride = bias.stride() + ctx.dtype_bias = bias.dtype + else: + ctx.dtype_bias = None + x_local = x.to_local() + weight_local = weight.to_local() + bias_local = None if bias is None else bias.to_local() + + # Save original-precision locals for backward *before* any dtype cast. + # Native autocast saves fp32 weights and lets the backward autocast + # context handle further casts. Saving the bf16-cast version would + # bake in bf16 rounding on CPU (where custom_bwd does NOT restore + # autocast), silently lowering gradient precision. + if x.requires_grad or weight.requires_grad or (bias is not None and bias.requires_grad): + ctx.save_for_backward( + x_local.detach().clone() if weight.requires_grad else None, + weight_local.detach().clone() if x.requires_grad else None, + ) + + if cast_params_dtype_to_x: + weight_local = weight_local.to(x.dtype) + if bias_local is not None: + bias_local = bias_local.to(x.dtype) + # Extract the local shard to perform the linear operation. + # This enforces local matrix multiplication without any communication given that: + # 1. the linear operation is performed locally on each rank along the hidden dimension, + # which is agnostic to the device mesh dimensions + # 2. the weight and bias are replicated on all device mesh dimensions + # 3. the output has the same placements as the input + output_local = torch.nn.functional.linear(x_local, weight_local, bias_local) + # linear only change the last dimension of the input so we need to + # modify the output shape and strides accordingly + shape_output = tuple(x.shape[:-1]) + (output_local.shape[-1],) + strides_output = update_exhaustive_strides(x.shape, x.stride(), shape_output) + output = DTensor.from_local(output_local, device_mesh, x.placements, shape=shape_output, stride=strides_output) + return output + + @staticmethod + def _all_reduce_grad_gteqfp32( + grad: torch.Tensor, + device_mesh: DeviceMesh, + placements: list, + target_dtype: torch.dtype, + ) -> torch.Tensor: + """All-reduce a parameter gradient in at least fp32 across mesh dims. + + For each mesh dimension with a ``Partial`` placement, performs an + all-reduce in at least float32 to avoid bf16/fp16 accumulation errors. + Only the parameter-sized gradient is promoted — not the large + activation tensors. If the gradient is already >=fp32, it is reduced + in its native dtype. + """ + needs_reduce = any(isinstance(p, Partial) and device_mesh.size(dim) > 1 for dim, p in enumerate(placements)) + if not needs_reduce: + return grad.to(target_dtype) + + reduce_dtype = torch.promote_types(grad.dtype, torch.float32) + grad = grad.to(reduce_dtype).contiguous() + for mesh_dim, p in enumerate(placements): + if not isinstance(p, Partial) or device_mesh.size(mesh_dim) <= 1: + continue + group = device_mesh.get_group(mesh_dim) + if p.reduce_op == "sum": + dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group) + elif p.reduce_op == "avg": + dist.all_reduce(grad, op=dist.ReduceOp.AVG, group=group) + else: + raise ValueError(f"Unsupported reduce_op {p.reduce_op!r} in _all_reduce_grad_gteqfp32") + return grad.to(target_dtype) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx, grad_output: DTensor + ) -> tuple[Optional[DTensor], Optional[DTensor], Optional[DTensor], None, None]: + """Backward pass for the distributed linear operation. + + Local einsum stays in the compute dtype (bf16 MMA accumulates in + fp32 internally on CUDA, so local results are accurate). The + precision hazard is cross-rank reduction: implicit ``Partial("SUM")`` + would reduce in bf16, accumulating errors. We manually all-reduce + ``dw``/``db`` in fp32 via ``_all_reduce_grad_gteqfp32`` and return + them with ``Replicate`` placements. + + On CPU (unit tests), ``custom_bwd`` does not restore autocast, so + we explicitly cast operands to the compute dtype. + """ + if grad_output.device_mesh != ctx.device_mesh: + raise ValueError( + "_LinearParamsReplicatedImpl: different device mesh between grad_output and the forward input" + ) + x_local, weight_local = ctx.saved_tensors + + if grad_output.placements != ctx.placements_x: + # DTensor's backward may spuriously all_gather to Replicate(); + # redistribute back to the input's placements. + grad_output = grad_output.redistribute(ctx.device_mesh, ctx.placements_x) + + grad_output_local = grad_output.to_local() + all_replicate = tuple([Replicate()] * ctx.device_mesh.ndim) + + # Compute dtype (e.g. bf16 under mixed precision). We cast operands + # to this dtype explicitly for CPU compatibility — on CUDA, custom_bwd + # restores autocast which handles this automatically. + go_dtype = grad_output_local.dtype + + if ctx.needs_input_grad[1]: + dw_local = torch.einsum("...i,...o->io", grad_output_local, x_local.to(go_dtype)) + dw_local = _LinearParamsReplicatedImpl._all_reduce_grad_gteqfp32( + dw_local, ctx.device_mesh, ctx.placements_grad_params, ctx.dtype_weight + ) + dw = DTensor.from_local( + dw_local, ctx.device_mesh, all_replicate, shape=ctx.weight_shape, stride=ctx.weight_stride + ) + else: + dw = None + + if ctx.needs_input_grad[2]: + dims = list(range(grad_output_local.ndim - 1)) + if ctx.dtype_bias is None: + raise RuntimeError("bias gradient requested but bias dtype metadata is missing") + db_local = grad_output_local.sum(dim=dims) + db_local = _LinearParamsReplicatedImpl._all_reduce_grad_gteqfp32( + db_local, ctx.device_mesh, ctx.placements_grad_params, ctx.dtype_bias + ) + db = DTensor.from_local( + db_local, ctx.device_mesh, all_replicate, shape=ctx.bias_shape, stride=ctx.bias_stride + ) + else: + db = None + + if ctx.needs_input_grad[0]: + if weight_local is None: + raise RuntimeError("input gradient requested but saved weight tensor is missing") + grad_input_local = torch.einsum("...i,io->...o", grad_output_local, weight_local.to(go_dtype)) + grad_input = DTensor.from_local( + grad_input_local, ctx.device_mesh, ctx.placements_x, shape=ctx.shape_input, stride=ctx.stride_input + ) + else: + grad_input = None + + return grad_input, dw, db, None, None + + +class LinearParamsReplicated(nn.Module): + """ + Distributed linear layer with parameters replicated across all device mesh dimensions. + + This is almost equivalent to + ```python + layer = torch.distributed.tensor.distribute_module(layer_local, device_mesh) + ``` + with the exception that the torch.distributed.tensor.distribute_module version will incur + significant overhead due to the unnecessary replication of the output tensor along certain + device mesh dimensions. + + This class avoids such unnecessary overhead by using the custom _LinearParamsReplicatedImpl + autograd function for forward and backward pass computation instead of relying on the distributed + module's forward implementation. + + NOTE: by default, avg reduce over the Replicate placements of the weight and bias gradients + is performed. This is to ensure identical parameter updates across all ranks and avoid + gradual divergence during training. This can be disabled by setting + avg_over_replicate_param_grad to False. + + Key requirements: + 1. Parameters (weight and bias) will replicated on all device mesh dimensions + 2. Input tensor and parameters must be on the same device mesh + 3. Feature/hidden dimension of the input must not be sharded across the device mesh + 4. Partial reduction along any input dimension is not supported + 5. Input and outputs must be on the same device mesh with the same placements + 6. Gradients of the input have the same placements on the same device mesh as the input + 7. Gradients of the weight and bias have Partial("sum") placements along the input's Shard placements' + dimension so that the all-reduce will be performed along those device-grid dimensions + + """ + + def __init__(self, layer_local: nn.Linear, device_mesh: DeviceMesh, avg_over_replicate_param_grad: bool = True): + """ + Initialize the distributed linear layer. + + Args: + layer_local: nn.Linear to be distributed + device_mesh: Device mesh for distributed computation + avg_over_replicate_param_grad: whether to perform avg reduce over the + Replicate placements of the weight and bias gradients. For example, + if the input DTensor x.placements = (Shard(0), Replicate()), this layer's + parameters' gradients.placements = (Partial("sum"), Replicate()) if + self._avg_over_replicate_param_grad is False; otherwise, it will be + (Partial("sum"), Partial("avg")). The motivation is to ensure identical + parameter updates across all ranks and avoid gradual divergence during + training. + """ + if not isinstance(layer_local, nn.Linear): + raise ValueError("layer_local is not an instance of nn.Linear") + if layer_local.weight.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.weight and device_mesh are not on the same device type: " + f"{layer_local.weight.device.type} != {device_mesh.device_type}" + ) + if layer_local.bias is not None and layer_local.bias.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.bias and device_mesh are not on the same device type: " + f"{layer_local.bias.device.type} != {device_mesh.device_type}" + ) + super().__init__() + all_replicate_placements = [Replicate()] * device_mesh.ndim + self.weight = nn.Parameter( + distribute_tensor(layer_local.weight.data, device_mesh, all_replicate_placements), + requires_grad=layer_local.weight.requires_grad, + ) + if layer_local.bias is None: + self.register_parameter("bias", None) + else: + self.bias = nn.Parameter( + distribute_tensor(layer_local.bias.data, device_mesh, all_replicate_placements), + requires_grad=layer_local.bias.requires_grad, + ) + self._avg_over_replicate_param_grad = avg_over_replicate_param_grad + + def forward(self, input: DTensor) -> DTensor: + """ + Forward pass for the distributed linear layer. + + Uses the custom _LinearParamsReplicatedImpl autograd function to perform the computation + efficiently while preserving correct autograd behavior for distributed tensors. + + Args: + input: Input DTensor with appropriate placement strategy + + Returns: + Output DTensor with same placement strategy as input + """ + return _LinearParamsReplicatedImpl.apply( + input, + self.weight, + self.bias, + True, # cast_params_dtype_to_x: under bf16-mixed autocast, upstream + # ops produce bf16 activations while weights stay fp32. custom_fwd + # disables autocast inside the function, so F.linear would get + # mismatched dtypes. Casting weight to input dtype matches what + # native autocast does for F.linear. No-op when dtypes already match. + self._avg_over_replicate_param_grad, + ) + + +class ContextParallelLinear(nn.Linear): + def __init__(self, in_features, out_features, reduce_group, bias=True): + """Context parallel linear layer, a wrapper around nn.Linear that supports distributed training. + + Although the output between tokens is independent, the backward pass on the weight and bias tensors involves a summation over all tokens, and thus requires all-reduce due to context parallelism. This means we need a dedicated ContextParallelLinear class for training. + + Args: + in_features: The number of input features. + out_features: The number of output features. + reduce_group: The process group to use for the all-reduce in backward pass. + bias: Whether to use a bias. + + If group is not provided, the layer will behave like a normal nn.Linear. + """ + super().__init__(in_features, out_features, bias) + assert reduce_group is not None, "reduce_group must be provided" + self.reduce_group = reduce_group + + def forward(self, input: Tensor) -> Tensor: + return _ContextParallelLinearImpl.apply(input, self.weight, self.bias, self.reduce_group) + + +def get_cp_linear( + *args, reduce_group: Optional[dist.ProcessGroup] = None, **kwargs +) -> nn.Linear | ContextParallelLinear: + """Get a context parallel linear layer. + + If group is not provided, the returned layer will fall back to nn.Linear. + """ + if reduce_group is None: + return nn.Linear(*args, **kwargs) + return ContextParallelLinear(*args, reduce_group=reduce_group, **kwargs) diff --git a/src/boltz/distributed/model/layers/outer_gather.py b/src/boltz/distributed/model/layers/outer_gather.py new file mode 100644 index 000000000..e48cc704d --- /dev/null +++ b/src/boltz/distributed/model/layers/outer_gather.py @@ -0,0 +1,1309 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import math +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Shard +from torch.distributed.tensor import Partial, Replicate + +from boltz.distributed.utils import update_exhaustive_strides + + +def outer_gather_backward(grad_output, z_shape, idx_q, idx_k, axis, idx_q_mask=None, idx_k_mask=None): + """ + Backward pass for outer_gather. + + Args: + grad_output: (..., K, W, H, ...) + z_shape: Shape of original input z + idx_q: (..., K, W) + idx_k: (..., K, H) + axis: Axis where N starts in z + idx_q_mask: Optional mask for idx_q of shape (..., K, W). True = valid, False = invalid. + idx_k_mask: Optional mask for idx_k of shape (..., K, H). True = valid, False = invalid. + + Returns: + Gradient w.r.t. z + """ + # Early exit for empty z + if math.prod(z_shape) == 0: + # z was empty, gradient is zeros of z_shape + return torch.zeros(z_shape, dtype=grad_output.dtype, device=grad_output.device) + + batch_shape_z = z_shape[:axis] + feature_shape_z = z_shape[axis + 2 :] + if grad_output.shape[:axis] != batch_shape_z: + # in the forward pass, z's 'axis' is replaced in the output by K, W, H so grad_output.shape[axis:axis+3] is (K, W, H) + raise ValueError( + f"grad_output.shape[:axis] must match z_shape[:axis] but got {grad_output.shape[:axis]} vs {z_shape[:axis]}" + ) + + if grad_output.shape[axis + 3 :] != feature_shape_z: + raise ValueError( + f"grad_output.shape[axis + 3 :] must match z_shape[axis + 2 :] but got {grad_output.shape[axis + 3 :]} vs {feature_shape_z}" + ) + + if grad_output.shape[: axis + 2] != idx_q.shape: + # (..., K, W) must match + raise ValueError( + f"grad_output.shape[:axis + 2] must match idx_q.shape but got {grad_output.shape[: axis + 2]} vs {idx_q.shape}" + ) + + if grad_output.shape[: axis + 1] + (grad_output.shape[axis + 2],) != idx_k.shape: + raise ValueError( + f"grad_output.shape[:axis + 1] + (grad_output.shape[axis + 2],) must match idx_k.shape but got " + f"{grad_output.shape[:axis] + (grad_output.shape[axis + 2],)} vs {idx_k.shape}" + ) + + # Validate mask shapes if provided + if idx_q_mask is not None and idx_q_mask.shape != idx_q.shape: + raise ValueError(f"idx_q_mask shape {idx_q_mask.shape} must match idx_q shape {idx_q.shape}") + if idx_k_mask is not None and idx_k_mask.shape != idx_k.shape: + raise ValueError(f"idx_k_mask shape {idx_k_mask.shape} must match idx_k shape {idx_k.shape}") + + ndim_z = len(z_shape) + flatten_leading_dims = len(batch_shape_z) >= 2 # w at least two leading axes + flatten_trailing_dims = len(feature_shape_z) >= 2 # w at least two trailing axes + has_leading_dims = len(batch_shape_z) > 0 + has_trailing_dims = len(feature_shape_z) > 0 + + # 1. Normalize Axis + ndim_z = len(z_shape) + if axis < 0: + axis += ndim_z + + # Re-construct broadcasted indices (avoid flattening if possible) + # idx_q: (B, K, W) + # idx_k: (B, K, H) + + # Broadcast to grid (B, K, W, H) + K = idx_q.shape[-2] + W = idx_q.shape[-1] + H = idx_k.shape[-1] + + # Reshape indices to (B, K, W) and (B, K, H) + idx_q_flat = idx_q + idx_k_flat = idx_k + idx_q_mask_flat = idx_q_mask + idx_k_mask_flat = idx_k_mask + if flatten_leading_dims: + idx_q_flat = idx_q_flat.flatten(0, -3) # (B, K, W) + idx_k_flat = idx_k_flat.flatten(0, -3) # (B, K, H) + if idx_q_mask_flat is not None: + idx_q_mask_flat = idx_q_mask_flat.flatten(0, -3) + if idx_k_mask_flat is not None: + idx_k_mask_flat = idx_k_mask_flat.flatten(0, -3) + + if not has_leading_dims: + idx_q_flat = idx_q_flat.unsqueeze(0) + idx_k_flat = idx_k_flat.unsqueeze(0) + if idx_q_mask_flat is not None: + idx_q_mask_flat = idx_q_mask_flat.unsqueeze(0) + if idx_k_mask_flat is not None: + idx_k_mask_flat = idx_k_mask_flat.unsqueeze(0) + + B = idx_q_flat.shape[0] + + # For masked indices, clamp to valid range (0) to avoid index errors + if idx_q_mask_flat is not None: + idx_q_flat = torch.where(idx_q_mask_flat, idx_q_flat, torch.zeros_like(idx_q_flat)) + if idx_k_mask_flat is not None: + idx_k_flat = torch.where(idx_k_mask_flat, idx_k_flat, torch.zeros_like(idx_k_flat)) + + # Broadcast to (B, K, W, H) + q_broad = idx_q_flat.unsqueeze(-1).expand(-1, -1, -1, H) + k_broad = idx_k_flat.unsqueeze(-2).expand(-1, -1, W, -1) + + # Batch indices (B, K, W, H) + batch_idx = torch.arange(B, device=grad_output.device).reshape(B, 1, 1, 1).expand(-1, K, W, H) + + # Compute linear indices for (B, N, M) -> (B*N*M) + # We are updating (B, N, M, D) + N = z_shape[axis] + M = z_shape[axis + 1] + linear_idx = batch_idx * (N * M) + q_broad * M + k_broad # (B, K, W, H) + linear_idx_flat = linear_idx.reshape(-1) # (B*K*W*H) + + # (..., K, W, H, ...) -> (..., K, W, H, D) -> (B* K * W * H, D) + grad_source_flat = grad_output + if flatten_trailing_dims: + # must flatten trailing dims before leading dims to avoid axis offset changes + grad_source_flat = grad_source_flat.flatten(axis + 3, -1) + if not has_trailing_dims: + grad_source_flat = grad_source_flat.unsqueeze(-1) + grad_source_flat = grad_source_flat.flatten(0, -2) + + # Compute combined mask and zero out gradients at invalid positions + # combined_mask: (B, K, W, H) - position is valid only if both q and k indices are valid + if idx_q_mask_flat is not None or idx_k_mask_flat is not None: + if idx_q_mask_flat is not None and idx_k_mask_flat is not None: + # Broadcast masks: q_mask (B,K,W) -> (B,K,W,H), k_mask (B,K,H) -> (B,K,W,H) + combined_mask = idx_q_mask_flat.unsqueeze(-1) & idx_k_mask_flat.unsqueeze(-2) + elif idx_q_mask_flat is not None: + combined_mask = idx_q_mask_flat.unsqueeze(-1).expand(-1, -1, -1, H) + else: + combined_mask = idx_k_mask_flat.unsqueeze(-2).expand(-1, -1, W, -1) + combined_mask_flat = combined_mask.reshape(-1, 1).to(grad_source_flat.dtype) + grad_source_flat = grad_source_flat * combined_mask_flat + + D = grad_source_flat.shape[-1] + + # Initialize flat grad_z + grad_z_flat = torch.zeros((B * N * M, D), device=grad_output.device, dtype=grad_output.dtype) + + # Scatter add + grad_z_flat.index_add_(0, linear_idx_flat, grad_source_flat) + + # Unflatten grad_z + # (B, N, M, D) -> (..., N, M, ...) + grad_z = grad_z_flat.reshape(z_shape) + + return grad_z + + +class OuterGather(torch.autograd.Function): + @staticmethod + def forward(ctx, z, idx_q, idx_k, axis=1, idx_q_mask=None, idx_k_mask=None): + """ + Perform outer gather operation: z[b, q, k] for all q in idx_q, k in idx_k. + + Args: + z: (..., N, M, ...) + idx_q: (..., K, W) + idx_k: (..., K, H) + axis: The dimension index of the first N in z. + idx_q_mask: Optional mask for idx_q of shape (..., K, W). True = valid, False = invalid. + idx_k_mask: Optional mask for idx_k of shape (..., K, H). True = valid, False = invalid. + + Returns: + Tensor of shape (..., K, W, H, ...) + """ + # 1. Normalize Axis + ndim_z = z.ndim + + if ndim_z < 2: + raise ValueError(f"z must have at least 2 dimensions but got {ndim_z}") + + if idx_q.ndim < 2: + raise ValueError(f"idx_q must have at least 2 dimensions but got {idx_q.ndim}") + + if idx_k.ndim < 2: + raise ValueError(f"idx_k must have at least 2 dimensions but got {idx_k.ndim}") + + if axis < 0: + axis += ndim_z + + if not (0 <= axis < z.ndim): + raise ValueError(f"Axis must be in range [0, {z.ndim - 1}] but got {axis}") + + # 2. Check shapes (Strict Equality) + # idx_q: (..., K, W) + # idx_k: (..., K, H) + # z: (..., N, M, ...) + + batch_shape_idx = idx_q.shape[:-2] + batch_shape_z = z.shape[:axis] + feature_shape_z = z.shape[axis + 2 :] + + if batch_shape_z != batch_shape_idx: + raise ValueError( + f"Leading dimensions must match exactly but got: z {batch_shape_z} vs idx_q {batch_shape_idx}" + ) + if idx_k.shape[:-1] != idx_q.shape[:-1]: + raise ValueError( + f"All dimensions but the last must match exactly but got: idx_k {idx_k.shape[:-2]} vs idx_q {batch_shape_idx}" + ) + + # Validate masks if provided + if idx_q_mask is not None and idx_q_mask.shape != idx_q.shape: + raise ValueError(f"idx_q_mask shape {idx_q_mask.shape} must match idx_q shape {idx_q.shape}") + if idx_k_mask is not None and idx_k_mask.shape != idx_k.shape: + raise ValueError(f"idx_k_mask shape {idx_k_mask.shape} must match idx_k shape {idx_k.shape}") + if idx_q_mask is not None and idx_q_mask.dtype != torch.bool: + raise TypeError( + f"idx_q_mask must have dtype torch.bool, got {idx_q_mask.dtype}. Use mask.bool() to convert." + ) + if idx_k_mask is not None and idx_k_mask.dtype != torch.bool: + raise TypeError( + f"idx_k_mask must have dtype torch.bool, got {idx_k_mask.dtype}. Use mask.bool() to convert." + ) + + K = idx_q.shape[-2] + W = idx_q.shape[-1] + H = idx_k.shape[-1] + + has_leading_dims = len(batch_shape_z) > 0 + + # Reshape z to (B, N, M, D) + flatten_leading_dims = len(batch_shape_z) >= 2 # w at least two leading axes + flatten_trailing_dims = len(feature_shape_z) >= 2 # w at least two trailing axes + z_flat = z + if flatten_trailing_dims: + # must flatten trailing dims before leading dims to avoid axis offset changes + z_flat = z_flat.flatten(axis + 2, -1) + if flatten_leading_dims: + z_flat = z_flat.flatten(0, axis - 1) + + # Reshape indices to (B, K, W) and (B, K, H) + idx_q_flat = idx_q + idx_k_flat = idx_k + idx_q_mask_flat = idx_q_mask + idx_k_mask_flat = idx_k_mask + if flatten_leading_dims: + idx_q_flat = idx_q_flat.flatten(0, -3) + idx_k_flat = idx_k_flat.flatten(0, -3) + if idx_q_mask_flat is not None: + idx_q_mask_flat = idx_q_mask_flat.flatten(0, -3) + if idx_k_mask_flat is not None: + idx_k_mask_flat = idx_k_mask_flat.flatten(0, -3) + + if not has_leading_dims: + z_flat = z_flat.unsqueeze(0) + idx_q_flat = idx_q_flat.unsqueeze(0) + idx_k_flat = idx_k_flat.unsqueeze(0) + if idx_q_mask_flat is not None: + idx_q_mask_flat = idx_q_mask_flat.unsqueeze(0) + if idx_k_mask_flat is not None: + idx_k_mask_flat = idx_k_mask_flat.unsqueeze(0) + B = z_flat.shape[0] + + # Early exit for empty z tensor + if z_flat.numel() == 0: + # Validate: the combined mask (outer-AND of q and k masks) must be all-False when z is empty. + # A position (w, h) is valid only if BOTH idx_q_mask[w] AND idx_k_mask[h] are True. + # We must check the combined mask, not each mask separately, because: + # - If idx_q_mask is all-False, no output positions are valid regardless of idx_k_mask + # - If idx_k_mask is all-False, no output positions are valid regardless of idx_q_mask + has_valid_output_positions = False + if idx_q_mask_flat is not None or idx_k_mask_flat is not None: + if idx_q_mask_flat is not None and idx_k_mask_flat is not None: + # Joint validation: outer-AND of the two masks + combined_mask = idx_q_mask_flat.unsqueeze(-1) & idx_k_mask_flat.unsqueeze(-2) + has_valid_output_positions = combined_mask.any().item() + elif idx_q_mask_flat is not None: + has_valid_output_positions = idx_q_mask_flat.any().item() + else: + has_valid_output_positions = idx_k_mask_flat.any().item() + else: + # No masks provided means all positions are implicitly valid + has_valid_output_positions = idx_q_flat.numel() > 0 and idx_k_flat.numel() > 0 + + if has_valid_output_positions: + raise ValueError( + "z is empty but combined mask (idx_q_mask & idx_k_mask) contains valid entries. " + "This is a logical error - cannot gather from empty tensor with valid indices." + ) + # z is empty - return zeros of correct shape + out_shape = batch_shape_z + (K, W, H) + feature_shape_z + out = torch.zeros(out_shape, dtype=z.dtype, device=z.device) + # Save context for backward (same as normal path) + tensors_to_save = [idx_q, idx_k] + if idx_q_mask is not None: + tensors_to_save.append(idx_q_mask) + if idx_k_mask is not None: + tensors_to_save.append(idx_k_mask) + ctx.save_for_backward(*tensors_to_save) + ctx.has_q_mask = idx_q_mask is not None + ctx.has_k_mask = idx_k_mask is not None + ctx.z_shape = z.shape + ctx.axis = axis + return out + + # For masked indices, clamp to valid range (0) to avoid index errors + if idx_q_mask_flat is not None: + idx_q_flat = torch.where(idx_q_mask_flat, idx_q_flat, torch.zeros_like(idx_q_flat)) + if idx_k_mask_flat is not None: + idx_k_flat = torch.where(idx_k_mask_flat, idx_k_flat, torch.zeros_like(idx_k_flat)) + + # Create broadcasted indices for gather + # q_broad: (B, K, W, H) + q_broad = idx_q_flat.unsqueeze(-1).expand(B, K, W, H) + # k_broad: (B, K, W, H) + k_broad = idx_k_flat.unsqueeze(-2).expand(B, K, W, H) + # batch_idx: (B, K, W, H) + batch_idx = torch.arange(B, device=z.device).reshape(B, 1, 1, 1).expand(B, K, W, H) + + # Gather: z_flat[b, q, k] + # Result: (B, K, W, H, D) + out_flat = z_flat[batch_idx, q_broad, k_broad] + + # Zero out invalid positions + # combined_mask: (B, K, W, H) - position is valid only if both q and k indices are valid + if idx_q_mask_flat is not None or idx_k_mask_flat is not None: + if idx_q_mask_flat is not None and idx_k_mask_flat is not None: + # Broadcast masks: q_mask (B,K,W) -> (B,K,W,H), k_mask (B,K,H) -> (B,K,W,H) + combined_mask = idx_q_mask_flat.unsqueeze(-1) & idx_k_mask_flat.unsqueeze(-2) + elif idx_q_mask_flat is not None: + combined_mask = idx_q_mask_flat.unsqueeze(-1).expand(-1, -1, -1, H) + else: + combined_mask = idx_k_mask_flat.unsqueeze(-2).expand(-1, -1, W, -1) + # Expand to match out_flat: (B, K, W, H) or (B, K, W, H, D) depending on trailing dims + has_trailing_dims = len(feature_shape_z) > 0 + if has_trailing_dims: + combined_mask = combined_mask.unsqueeze(-1) + out_flat = out_flat * combined_mask.to(out_flat.dtype) + + # Reshape to final output + out = out_flat + if flatten_trailing_dims: + out = out.unflatten(-1, z.shape[axis + 2 :]) + if flatten_leading_dims: + out = out.unflatten(0, batch_shape_z) + if not has_leading_dims: + out = out.squeeze(0) + + # Save for backward + tensors_to_save = [idx_q, idx_k] + if idx_q_mask is not None: + tensors_to_save.append(idx_q_mask) + if idx_k_mask is not None: + tensors_to_save.append(idx_k_mask) + ctx.save_for_backward(*tensors_to_save) + ctx.has_q_mask = idx_q_mask is not None + ctx.has_k_mask = idx_k_mask is not None + ctx.z_shape = z.shape + ctx.axis = axis + + return out + + @staticmethod + def backward(ctx, grad_output): + saved = ctx.saved_tensors + idx_q = saved[0] + idx_k = saved[1] + idx = 2 + idx_q_mask = saved[idx] if ctx.has_q_mask else None + if ctx.has_q_mask: + idx += 1 + idx_k_mask = saved[idx] if ctx.has_k_mask else None + + grad_z = outer_gather_backward(grad_output, ctx.z_shape, idx_q, idx_k, ctx.axis, idx_q_mask, idx_k_mask) + return grad_z, None, None, None, None, None + + +def outer_gather(z, one_hot_q, one_hot_k, axis=1, one_hot_q_mask=None, one_hot_k_mask=None): + """ + Efficient gather-based equivalent to einsum for window batching index selection. + + Args: + z: Input tensor (..., N, M, ...) + one_hot_q: One-hot query indices (..., K, W, N) + one_hot_k: One-hot key indices (..., K, H, M) + axis: The dimension index of the first N in z. + one_hot_q_mask: Optional mask for one_hot_q of shape (..., K, W). True = valid, False = invalid. + one_hot_k_mask: Optional mask for one_hot_k of shape (..., K, H). True = valid, False = invalid. + + Returns: + Tensor of shape (..., K, W, H, ...) + """ + if z.shape[axis] != one_hot_q.shape[-1] or z.shape[axis + 1] != one_hot_k.shape[-1]: + raise ValueError( + f"z.shape[axis] must match one_hot_q.shape[-1] and z.shape[axis + 1] must match one_hot_k.shape[-1] but got " + f"{z.shape[axis : axis + 2]} vs {one_hot_q.shape[-1:]} and {one_hot_k.shape[-1:]}" + ) + # Condense one-hot to indices + idx_q = one_hot_q.argmax(dim=-1) + idx_k = one_hot_k.argmax(dim=-1) + return OuterGather.apply(z, idx_q, idx_k, axis, one_hot_q_mask, one_hot_k_mask) + + +def compute_interval_overlap(intervals_a: torch.Tensor, intervals_b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes the intersection (overlap) of two sets of n-dimensional intervals where + the last axis represents [start, end) coordinates (exclusive end). + + Args: + intervals_a (torch.Tensor): A tensor of shape (..., n_dim, 2). + The last dimension contains [start, end). + intervals_b (torch.Tensor): A tensor of shape (..., n_dim, 2). + Must be broadcastable to intervals_a.shape. + + Returns: + tuple[torch.Tensor, torch.Tensor]: + - intervals_overlap (torch.Tensor): The computed intersection intervals with shape + matching the broadcasted shape of inputs (..., n_dim, 2). + Contains [overlap_start, overlap_end). + - mask (torch.Tensor): A boolean mask of shape (...,) corresponding to the leading + dimensions. It is True if the n-dimensional overlap is valid and non-empty. + A valid overlap is defined by: all(overlap_end > overlap_start) across the n_dim axis. + """ + start_a = intervals_a[..., 0] + end_a = intervals_a[..., 1] + start_b = intervals_b[..., 0] + end_b = intervals_b[..., 1] + + overlap_start = torch.maximum(start_a, start_b) + overlap_end = torch.minimum(end_a, end_b) + + intervals_overlap = torch.stack([overlap_start, overlap_end], dim=-1) + + valid_dims = overlap_end > overlap_start + mask = torch.all(valid_dims, dim=-1) + + return intervals_overlap, mask + + +def get_overlap_from_peers( + rank_peers: torch.Tensor, intervals_a: torch.Tensor, intervals_b: torch.Tensor +) -> List[Dict[str, torch.Tensor | int]]: + """ + Given a rank table (rank_peers) and two interval tensors, compute the overlapping + intervals and associated peer ranks. + + Args: + rank_peers: Tensor of shape matching intervals_a leading dims (...). Contains peer/global ranks + (e.g., typically from device_mesh.mesh), but no device_mesh semantics are assumed here. + intervals_a: Tensor of shape (..., n_dim, 2) where last dim is [start, end). + intervals_b: Tensor broadcastable to intervals_a.shape (..., n_dim, 2). + + Returns: + List[dict]: Each dict contains: + - "peer": int global rank + - "interval": Tensor of shape (n_dim, 2) with [start, end) along each axis. + + Notes: + - Leading dims of intervals_a/intervals_b must match rank_peers.shape so that boolean + masking (rank_peers[valid_mask]) is valid. + - Both interval tensors must use trailing (n_dim, 2) layout with exclusive ends. + """ + if intervals_a.shape[-1] != 2: + raise ValueError(f"intervals_a last dim must be 2 (start,end), got {intervals_a.shape}") + n_dim = intervals_a.shape[-2] + if intervals_b.shape[-1] != 2 or intervals_b.shape[-2] != n_dim: + raise ValueError( + f"intervals_b trailing shape must be ({n_dim}, 2) to match intervals_a; got {intervals_b.shape[-2:]}" + ) + + # Broadcast intervals_b to intervals_a shape + try: + intervals_b_broadcast = torch.broadcast_to(intervals_b, intervals_a.shape) + except RuntimeError as e: + raise ValueError(f"intervals_b is not broadcastable to intervals_a shape {intervals_a.shape}") from e + + # Validate submesh shape matches leading dims + leading_shape = intervals_a.shape[:-2] + if rank_peers.shape != leading_shape: + raise ValueError(f"rank_peers.shape {rank_peers.shape} must match intervals leading shape {leading_shape}") + + intervals_a_view = intervals_a + intervals_b_view = intervals_b_broadcast + + overlap_intervals, valid_mask = compute_interval_overlap(intervals_a_view, intervals_b_view) + if not valid_mask.any().item(): + return [] + + if valid_mask.dim() != rank_peers.dim(): + raise ValueError(f"valid_mask dims {valid_mask.dim()} must match rank_peers.dim {rank_peers.dim()}") + + peer_ranks = rank_peers[valid_mask] + interval_selected = overlap_intervals[valid_mask] + + needed = [] + for idx in range(len(peer_ranks)): + interval = interval_selected[idx] + needed.append({"peer": peer_ranks[idx].item(), "interval": interval}) + return needed + + +class DistributedOuterGather(torch.autograd.Function): + @staticmethod + def forward( + ctx, + z_dtensor: DTensor, + idx_n_dtensor: DTensor, + idx_m_dtensor: DTensor, + axis: int, + are_ids_contiguous: bool, + idx_n_mask: DTensor | None = None, + idx_m_mask: DTensor | None = None, + ) -> DTensor: + """Forward pass for DistributedOuterGather. + + "Outer" refers to gathering a Cartesian product block: for each pair of + index sets (n, m) you gather the rectangular sub-block z[..., n, m, ...], + producing an output of shape (..., K, H, *features) where H is the width + of the m index set and K is the shared leading index count. + + Args: + z_dtensor (DTensor): Tensor of shape ``(*batch, N, M, *features)``. The + ``axis`` argument identifies the start of the ``(N, M)`` block in + ``z_dtensor`` (i.e., ``z_dtensor.shape[axis] == N`` and + ``z_dtensor.shape[axis + 1] == M``). + idx_n_dtensor (DTensor): Tensor of shape ``(*batch, K, W)`` containing + gather indices into the ``N`` dimension. Must be sharded on its + ``-2`` dimension across one of the two mesh dims that shard + ``z_dtensor``'s ``axis``/``axis+1``. + idx_m_dtensor (DTensor): Tensor of shape ``(*batch, K, H)`` containing + gather indices into the ``M`` dimension. Must share device mesh and + placements with ``idx_n_dtensor``. + axis (int): Index in ``z_dtensor`` where the ``(N, M)`` block begins. + are_ids_contiguous (bool, optional): This is a heuristic for selecting the underlying + send/recv strategy for performance purpose. Currently only True is supported, + which means that the idx_n and idx_m tensors map to a contiguous block of z + along (axis, axis + 1) dimensions for all the shards and for all the leading + (batch) dimensions. When True, the underlying strategy will use the min/max + of idx_n and idx_m to compute the needed interval, assuming the resulting + buffer to be communicated across the ranks is fully (or approximately so) + utilized. In the case the data inside idx_n and idx_m doesn't mapped to + contiguous blocks, the result will still be correct by setting + are_ids_contiguous=True but the buffer of z chunks communicated will contain + a lot of unused elements, making the sed/recv inefficient. + idx_n_mask (DTensor, optional): Mask for idx_n_dtensor of shape ``(*batch, K, W)``. + Same device_mesh and placements as idx_n_dtensor. True = valid, False = invalid. + idx_m_mask (DTensor, optional): Mask for idx_m_dtensor of shape ``(*batch, K, H)``. + Same device_mesh and placements as idx_m_dtensor. True = valid, False = invalid. + + + Device mesh / placements: + - ``z_dtensor`` must be sharded along both ``axis`` and ``axis+1``. + - No ``Partial`` placements are allowed. + - ``idx_n_dtensor``/``idx_m_dtensor`` must share device mesh and + placements. Their mesh can be either the same as ``z_dtensor`` or the + flatten of ``z_dtensor``'s mesh over the two sharding axes + ``(mesh_dim_axis, mesh_dim_axis_plus_1)``. + - All co-sharding/co-replication checks enforced in the function apply. + + Returns: + DTensor: Output of shape ``(*batch, K, H, *features)`` with the same + device mesh and placements as ``idx_n_dtensor``/``idx_m_dtensor``. + """ + + if not are_ids_contiguous: + raise NotImplementedError( + "DistributedOuterGather currently only supports are_ids_contiguous=True. " + "The current implementation is not efficient when the ids are not contiguous " + "and especially so if they mapped to distal parts of z, " + "even though the result will still be correct." + ) + + # 1. Validations + if ( + not isinstance(z_dtensor, DTensor) + or not isinstance(idx_n_dtensor, DTensor) + or not isinstance(idx_m_dtensor, DTensor) + ): + raise TypeError("All inputs must be DTensors") + + batch_dims_z = z_dtensor.shape[:axis] + + if batch_dims_z != idx_n_dtensor.shape[:-2]: + raise ValueError( + f"Batch dimensions of z ({batch_dims_z}) must match batch dimensions of idx_n ({idx_n_dtensor.shape[:-2]})" + ) + + if batch_dims_z != idx_m_dtensor.shape[:-2]: + raise ValueError( + f"Batch dimensions of z ({batch_dims_z}) must match batch dimensions of idx_m ({idx_m_dtensor.shape[:-2]})" + ) + + if idx_n_dtensor.shape[-2] != idx_m_dtensor.shape[-2]: + raise ValueError( + f"Number of n indices ({idx_n_dtensor.shape[-2]}) must match number of m indices ({idx_m_dtensor.shape[-2]})" + ) + + mesh = z_dtensor.device_mesh + idx_mesh = idx_n_dtensor.device_mesh + + # idx_n and idx_m must share device_mesh/placements with each other + if idx_m_dtensor.device_mesh != idx_mesh: + raise ValueError("idx_n and idx_m must be on the same DeviceMesh") + + if mesh.ndim < idx_mesh.ndim: + raise ValueError( + f"z_dtensor.device_mesh.ndim must be no less than" + f"idx_n/m_dtensor.device_mesh.ndim but got: z_dtensor: {mesh.ndim}, idx_n: {idx_mesh.ndim}, idx_m: {idx_mesh.ndim}" + ) + + if idx_n_dtensor.placements != idx_m_dtensor.placements: + raise ValueError("idx_n and idx_m must have the same placements") + + ndim_z = z_dtensor.ndim + if axis < 0: + axis += ndim_z + if axis < 0 or axis + 1 >= ndim_z: + raise ValueError(f"axis {axis} must satisfy 0 <= axis and axis+1 < z.ndim (got z.ndim={ndim_z})") + + z_placements = z_dtensor.placements + idx_n_placements = idx_n_dtensor.placements + + # Identify sharding dims + mesh_dim_axis = None + mesh_dim_axis_plus_1 = None + mesh_dim_shard_k = None + for i_mesh_dim, p_z in enumerate(z_placements): + check_idx_placements = i_mesh_dim < len(idx_n_placements) + if isinstance(p_z, Partial) or (check_idx_placements and isinstance(idx_n_placements[i_mesh_dim], Partial)): + raise ValueError("Partial placements are not supported") + elif isinstance(p_z, Shard): + if p_z.dim == axis: + mesh_dim_axis = i_mesh_dim + elif p_z.dim == axis + 1: + mesh_dim_axis_plus_1 = i_mesh_dim + else: + # "co-sharding" requirement: we require the device_mesh axis + # that shards any z_dtensor axis other + # than (axis, axis + 1) to also shard the same tensor axis + # in idx_n_dtensor and idx_m_dtensor because the following p2p + # communication ops are restricted within the sub-device_mesh + # of (mesh_dim_axis, mesh_dim_axis_plus_1) + if check_idx_placements and idx_n_placements[i_mesh_dim].dim != p_z.dim: + raise ValueError( + f"z_dtensor's sharded axis {p_z.dim} is outside of {(axis, axis + 1)} but " + f"the same tensor axis in idx_n_dtensor and idx_m_dtensor is not sharded " + f"by the same device mesh axis {i_mesh_dim}" + ) + # any Shard placement must evenly shard any axis + if z_dtensor.shape[p_z.dim] % mesh.size(i_mesh_dim) != 0: + raise ValueError( + f"z_dtensor's sharded axis {p_z.dim} of size {z_dtensor.shape[p_z.dim]} is not " + f"evenly divisible by mesh dim {i_mesh_dim} (size {mesh.size(i_mesh_dim)})" + ) + elif isinstance(p_z, Replicate): + if check_idx_placements and not isinstance(idx_n_placements[i_mesh_dim], Replicate): + # "co-replicating" requirement: for the same reason as co-sharding, + # we require all orthogonal Replicate placements to be consistent across + # the 3 DTensors to limit the number of P2P communication ops in the + # (mesh_dim_axis, mesh_dim_axis_plus_1) sub-device_mesh + raise ValueError( + f"z_dtensor's replicate placement at mesh axis {i_mesh_dim} is expected to " + f"be Replicate placement in idx_n_dtensor but the latter's placement is {idx_n_placements[i_mesh_dim]}" + ) + if check_idx_placements and isinstance(idx_n_placements[i_mesh_dim], Shard): + i_axis_idx_n = idx_n_placements[i_mesh_dim].dim + i_axis_idx_n = i_axis_idx_n if i_axis_idx_n >= 0 else i_axis_idx_n + idx_n_dtensor.ndim + if i_axis_idx_n == idx_n_dtensor.ndim - 2: + mesh_dim_shard_k = i_mesh_dim + else: + # Due to the same "co-sharding" requirement as above, any other sharding mesh axis must + # shard the same tensor axis between z_dtensor and idx_n_dtensor and idx_m_dtensor + if not ( + isinstance(z_placements[i_mesh_dim], Shard) and z_placements[i_mesh_dim].dim == i_axis_idx_n + ): + raise ValueError( + f"idx_n_dtensor and idx_m_dtensor's sharded axis {i_axis_idx_n} by mesh axis {i_mesh_dim} " + f"is not a co-sharded axis: got z_dtensor's placement at mesh axis {i_mesh_dim} as {z_placements[i_mesh_dim]}" + ) + # any Shard placement must evenly shard any axis + if idx_n_dtensor.shape[i_axis_idx_n] % mesh.size(i_mesh_dim) != 0: + raise ValueError( + f"idx_n_dtensor's sharded axis {i_axis_idx_n} of size {idx_n_dtensor.shape[i_axis_idx_n]} is not " + f"evenly divisible by mesh dim {i_mesh_dim} (size {mesh.size(i_mesh_dim)})" + ) + + if mesh_dim_axis is None or mesh_dim_axis_plus_1 is None: + raise ValueError(f"z must be sharded along axis {axis} and {axis + 1}") + + if mesh_dim_shard_k is None: + raise ValueError("idx_n_dtensor must be sharded along axis -2") + + # Allow idx_n/idx_m to reside on a flattened mesh over (mesh_dim_axis, mesh_dim_axis_plus_1) + mesh_tensor = mesh.mesh + idx_mesh_tensor = idx_mesh.mesh + mesh_flat_tensor = torch.flatten(mesh_tensor, start_dim=mesh_dim_axis, end_dim=mesh_dim_axis_plus_1) + mesh_compatible = torch.equal(idx_mesh_tensor, mesh_tensor) or torch.equal(idx_mesh_tensor, mesh_flat_tensor) + if not mesh_compatible: + raise ValueError( + "idx_n/idx_m device_mesh must match z device_mesh or its flatten over (mesh_dim_axis, mesh_dim_axis_plus_1)" + ) + + if mesh_dim_shard_k not in (mesh_dim_axis, mesh_dim_axis_plus_1): + # one of (mesh_dim_axis, mesh_dim_axis_plus_1) must shard idx_n_dtensor and idx_m_dtensor + # along their axis -2 + raise ValueError( + f"mesh_dim_shard_k {mesh_dim_shard_k} must be one of (mesh_dim_axis, mesh_dim_axis_plus_1) " + f"but got {mesh_dim_shard_k}" + ) + + # Validate idx_n_mask if provided + if idx_n_mask is not None: + if not isinstance(idx_n_mask, DTensor): + raise TypeError("idx_n_mask must be a DTensor") + if idx_n_mask.shape != idx_n_dtensor.shape: + raise ValueError( + f"idx_n_mask shape {idx_n_mask.shape} must match idx_n_dtensor shape {idx_n_dtensor.shape}" + ) + if idx_n_mask.device_mesh != idx_n_dtensor.device_mesh: + raise ValueError("idx_n_mask must have the same device_mesh as idx_n_dtensor") + if idx_n_mask.placements != idx_n_dtensor.placements: + raise ValueError("idx_n_mask must have the same placements as idx_n_dtensor") + if idx_n_mask.dtype != torch.bool: + raise TypeError( + f"idx_n_mask must have dtype torch.bool, got {idx_n_mask.dtype}. Use mask.bool() to convert." + ) + + # Validate idx_m_mask if provided + if idx_m_mask is not None: + if not isinstance(idx_m_mask, DTensor): + raise TypeError("idx_m_mask must be a DTensor") + if idx_m_mask.shape != idx_m_dtensor.shape: + raise ValueError( + f"idx_m_mask shape {idx_m_mask.shape} must match idx_m_dtensor shape {idx_m_dtensor.shape}" + ) + if idx_m_mask.device_mesh != idx_m_dtensor.device_mesh: + raise ValueError("idx_m_mask must have the same device_mesh as idx_m_dtensor") + if idx_m_mask.placements != idx_m_dtensor.placements: + raise ValueError("idx_m_mask must have the same placements as idx_m_dtensor") + if idx_m_mask.dtype != torch.bool: + raise TypeError( + f"idx_m_mask must have dtype torch.bool, got {idx_m_mask.dtype}. Use mask.bool() to convert." + ) + + if z_dtensor.device_mesh.ndim == idx_n_dtensor.device_mesh.ndim: + # When idx_n_dtensor.device_mesh.ndim == z_dtensor.device_mesh.ndim, + # the one of idx_n_dtensor.placements along (mesh_dim_axis, mesh_dim_axis_plus_1) + # must be Replicate() and the other must be Shard(-2) because of the "co-sharding" + # and "co-replicating" requirements and the fact that we exclude Partial placements. + if mesh_dim_shard_k == mesh_dim_axis: + mesh_dim_replicate_idx = mesh_dim_axis_plus_1 + else: + mesh_dim_replicate_idx = mesh_dim_axis + if not isinstance(idx_n_placements[mesh_dim_replicate_idx], Replicate): + raise ValueError( + f"idx_n_dtensor's Replicate placement at mesh axis {mesh_dim_replicate_idx} is expected to " + f"be Replicate but the latter's placement is {idx_n_placements[mesh_dim_replicate_idx]}" + ) + else: + # NOTE: there is however an exception where idx_n_dtensor.device_mesh is a flattened + # mesh of z_dtensor along (mesh_dim_axis, mesh_dim_axis_plus_1) then the would-be Replicate() + # devices participate in sharding along idx_n_dtensor's axis -2 but then the two meshes + # ndim are different. + mesh_dim_replicate_idx = None + + # Assert even sharding as per assumption + if mesh.size(mesh_dim_axis) != mesh.size(mesh_dim_axis_plus_1): + # TODO: the p2p comm doesn't actually require this + raise ValueError("Mesh dimensions for sharding z must have equal size") + + # 2. Local Setup + z_local = z_dtensor.to_local() + idx_n_local = idx_n_dtensor.to_local() + idx_m_local = idx_m_dtensor.to_local() + idx_n_mask_local = idx_n_mask.to_local() if idx_n_mask is not None else None + idx_m_mask_local = idx_m_mask.to_local() if idx_m_mask is not None else None + + device = z_local.device + cpu_device = torch.device("cpu") + + # 3. Compute Local Bounding Box (Needed Ranges) + # TODO: min/max should not take across batch dims + # The current approach assumes the max z buffer size + # to be send/recv across the leading batch dimensions, + # e.g., if: + # idx_n_dtensor[0, :, :] and idx_m_dtensor[0, :, :] + # define/need a z buffer bounded by: + # z[0, min_0:max_0, min_1:max_1, ...], + # while: + # idx_n_dtensor[1, :, :] and idx_m_dtensor[1, :, :] + # need a z buffer bounded by: + # z[1, (min_0-delta0):(max_0+delta0), (min_1-delta1):(max_1+delta1), ...], + # where delta0 > 0 and delta1 > 0 are True, then the actual z buffer + # will be of shape: + # z[:, (min_0-delta0):(max_0+delta0), (min_1-delta1):(max_1+delta1), ...], + # i.e., the maximal z_buffer is assumed across all leading batch dimensions. + # This still makes valid computation because the local OuterGather will still + # correctly figure out the indices from idx_n_local and idx_m_local, but + # we actually send/recv more data than necessary, making communication + # suboptimal in general. + # NOTE: however, in our Boltz application, idx_n_dtensor and idx_m_dtensor are + # exclusively the atom to token indices, which are homogeneous within a sample + # across the multiplicity dimension inside AtomAttnEncoder/Decoder so the above "delta" + # is always 0 as soon as we don't actually have sample batching in the leading dimensions + # NOTE: supporting more efficient communication with minimal z buffer for send/recv + # requires a 3D version of the "get_flattened_range_indices()" from utils.py + # where we need to deal with (B, N, M) indices instead of (D, 2) indices, which + # actually incur higher memory overhead because then we requires 3 * n_elements_in_z_buf + # instead of 2 * n_elements_in_z_buf (because of the extra dimension) + # Build needed interval tensor shape (2, 2): dim0 -> n axis, dim1 -> m axis, last dim is [start, end) + if idx_n_local.numel() > 0: + if idx_n_mask_local is not None: + # Only consider valid indices for interval computation + if idx_n_mask_local.any(): + valid_idx_n = idx_n_local[idx_n_mask_local] + need_interval_n = torch.stack(valid_idx_n.aminmax()) + else: + # All indices are masked out + need_interval_n = torch.tensor([0, -1], device=device, dtype=idx_n_local.dtype) + else: + # inclusive interval of shape (2,) + need_interval_n = torch.stack(idx_n_local.aminmax()) + else: + need_interval_n = torch.tensor([0, -1], device=device, dtype=idx_n_local.dtype) + + if idx_m_local.numel() > 0: + if idx_m_mask_local is not None: + # Only consider valid indices for interval computation + if idx_m_mask_local.any(): + valid_idx_m = idx_m_local[idx_m_mask_local] + need_interval_m = torch.stack(valid_idx_m.aminmax()) + else: + # All indices are masked out + need_interval_m = torch.tensor([0, -1], device=device, dtype=idx_m_local.dtype) + else: + # inclusive interval of shape (2,) + need_interval_m = torch.stack(idx_m_local.aminmax()) + else: + need_interval_m = torch.tensor([0, -1], device=device, dtype=idx_m_local.dtype) + + # inclusive interval of both (axis, axis + 1) of shape (2, 2) + need_interval = torch.stack([need_interval_n, need_interval_m]).to(dtype=torch.long) + # left-inclusive and right-exclusive interval of shape (2, 2) + need_interval[:, -1] += 1 + need_start_vec = need_interval[:, 0] + need_end_vec = need_interval[:, 1] + # need_interval must be on same device as the device_mesh for later all_gather but + # save one cpu copy for local indexing computation + need_start_vec_cpu = need_start_vec.to(cpu_device) + + # 4. Exchange Metadata + + # Compute my owned z range as interval tensor shape (2, 2) + my_coord_vec = torch.tensor( + [mesh.get_local_rank(mesh_dim_axis), mesh.get_local_rank(mesh_dim_axis_plus_1)], device=cpu_device + ) + chunk_sizes = torch.tensor( + [ + z_dtensor.shape[axis] // mesh.size(mesh_dim_axis), + z_dtensor.shape[axis + 1] // mesh.size(mesh_dim_axis_plus_1), + ], + device=cpu_device, + ) + + own_start = my_coord_vec * chunk_sizes + own_end = own_start + chunk_sizes + own_interval = torch.stack([own_start, own_end], dim=-1) # (2, 2) + + # Gather along axis+1 (M) first, then along axis (N) to keep N leading + # TODO: due to the limitation of torch DeviceMesh not being able to retrieve + # group of its submesh, we need to manually all_gather along each mesh axis. + # An alternative is to create the submesh an flatten it then retrieve the group + # but the overhead of creating submesh and new groups will recur upon all invocations + # of this function and the resulting groups and submesh will not be managed, leading to + # waste of distributed resources. + group_m = mesh.get_group(mesh_dim_axis_plus_1) + need_range_m = [torch.zeros_like(need_interval) for _ in range(mesh.size(mesh_dim_axis_plus_1))] + dist.all_gather(need_range_m, need_interval, group=group_m) + need_range_m = torch.stack(need_range_m) # (Grid_M, 2, 2) + + group_n = mesh.get_group(mesh_dim_axis) + need_range_nm = [torch.zeros_like(need_range_m) for _ in range(mesh.size(mesh_dim_axis))] + dist.all_gather(need_range_nm, need_range_m, group=group_n) + need_range_nm = torch.stack(need_range_nm) # (Grid_N, Grid_M, 2, 2) on device + need_range_nm_cpu = need_range_nm.cpu() + + ranks_global_on_mesh = mesh.mesh + + # Get my coords in the mesh + my_coords = mesh.get_coordinate() + + # Slice out the 2D submesh over (mesh_dim_axis, mesh_dim_axis_plus_1) anchored at my_coords on other dims. + index_list_submesh = [] + for dim in range(ranks_global_on_mesh.ndim): + if dim == mesh_dim_axis or dim == mesh_dim_axis_plus_1: + index_list_submesh.append(slice(None)) + else: + index_list_submesh.append(torch.tensor(my_coords[dim], device=cpu_device)) + ranks_global_on_submesh = ranks_global_on_mesh[tuple(index_list_submesh)] # shape (size_group_n, size_group_m) + + # 5. P2P Logic + ops = [] + recv_bufs = {} # key: peer_rank, value: buffer + recv_metadata_for_bwd = [] + + # Mesh dimensions + size_group_n = mesh.size(mesh_dim_axis) + size_group_m = mesh.size(mesh_dim_axis_plus_1) + + # --- RECEIVE PLAN --- + if need_start_vec[0] >= need_end_vec[0] or need_start_vec[1] >= need_end_vec[1]: + needed_chunks = [] + else: + # We need chunks that overlap with need_interval + # Construct bounds for all chunks in the mesh + # chunk_i goes 0..size_group_n-1, chunk_j goes 0..size_group_m-1 + # We use meshgrid to generate indices for all chunks + # Note: This is similar to the Send Plan logic where we flatten the mesh view + + i_ranks_n = torch.arange(size_group_n, device=cpu_device) + i_ranks_m = torch.arange(size_group_m, device=cpu_device) + grid_n, grid_m = torch.meshgrid(i_ranks_n, i_ranks_m, indexing="ij") + grid_coords = torch.stack([grid_n, grid_m], dim=-1) # (size_group_n, size_group_m, 2) + + # Calculate bounds for each chunk + starts = grid_coords * chunk_sizes # (size_group_n, size_group_m, 2) + ends = starts + chunk_sizes # (size_group_n, size_group_m, 2) + peers_own_intervals = torch.stack([starts, ends], dim=-1) # (size_group_n, size_group_m, 2, 2) + + # Needed bounds (broadcasted) + # need_start_vec / need_end_vec are scalars (exclusive end) per axis + + # Compute overlaps between peers' owned intervals and the current rank's needed intervals + need_interval_cpu = need_interval.to(cpu_device) + # needed_chunks: list of dicts, each dict contains: + # - "peer": int global rank + # - "interval": Tensor of shape (2, 2) with [i_start, i_end) in the last axis indicating the + # interval of z along its (axis, axis + 1) dimensions. + needed_chunks = get_overlap_from_peers(ranks_global_on_submesh, peers_own_intervals, need_interval_cpu) + + for item in needed_chunks: + peer = item["peer"] + interval = item["interval"] + starts_global = interval[:, 0] + lens = interval[:, 1] - interval[:, 0] + + shape = list(z_local.shape) + shape[axis : axis + 2] = lens.tolist() + + buf = torch.empty(shape, dtype=z_local.dtype, device=device) + + if peer == dist.get_rank(): + # Self-copy + starts_local = starts_global - own_start + chunk = z_local.narrow(axis, starts_local[0].item(), lens[0].item()).narrow( + axis + 1, starts_local[1].item(), lens[1].item() + ) + buf.copy_(chunk) + recv_bufs[peer] = buf + recv_metadata_for_bwd.append((peer, interval, shape)) + else: + ops.append(dist.P2POp(dist.irecv, buf, peer)) + recv_bufs[peer] = buf + recv_metadata_for_bwd.append((peer, interval, shape)) + + # --- SEND PLAN --- + send_metadata_for_bwd = [] + + # need_range_nm: (Grid_N, Grid_M, 4) -> (size_group_n, size_group_m, 4) + # Assuming need_range_nm[i, j] corresponds to peer at coords (i, j) relative to (mesh_dim_axis, mesh_dim_axis_plus_1) + # Flatten need_range_nm to process all at once + + send_chunks = get_overlap_from_peers( + ranks_global_on_submesh, + need_range_nm_cpu.view(size_group_n, size_group_m, 2, 2), + own_interval.view(1, 1, 2, 2), + ) + + my_rank = dist.get_rank() + + for item in send_chunks: + peer_rank = item["peer"] + if peer_rank == my_rank: + continue + + interval = item["interval"] + starts_global = interval[:, 0] + lens = interval[:, 1] - interval[:, 0] + + starts_local = starts_global - own_start + + chunk = z_local.narrow(axis, starts_local[0].item(), lens[0].item()).narrow( + axis + 1, starts_local[1].item(), lens[1].item() + ) + chunk = chunk.contiguous() + ops.append(dist.P2POp(dist.isend, chunk, peer_rank)) + + send_metadata_for_bwd.append( + ( + peer_rank, + starts_local.to(device=cpu_device, dtype=torch.long), + lens.to(device=cpu_device, dtype=torch.long), + ) + ) + + # Execute P2P + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # 6. Local Computation + if need_start_vec[0] >= need_end_vec[0] or need_start_vec[1] >= need_end_vec[1]: + buffer_shape = list(z_local.shape) + buffer_shape[axis] = 0 + buffer_shape[axis + 1] = 0 + z_buffer = torch.empty(buffer_shape, dtype=z_local.dtype, device=device) + else: + buffer_shape = list(z_local.shape) + buffer_shape[axis] = (need_end_vec[0] - need_start_vec[0]).item() + buffer_shape[axis + 1] = (need_end_vec[1] - need_start_vec[1]).item() + z_buffer = torch.zeros(buffer_shape, dtype=z_local.dtype, device=device) + + for item in needed_chunks: + peer = item["peer"] + interval = item["interval"] + if peer not in recv_bufs: + raise RuntimeError( + f"Peer {peer} not in recv_bufs, which should not happen because " + f"the recv planning should guarantee all peer ranks' contribution " + f"to the recv buffer arrive by now" + ) + buf = recv_bufs[peer] + starts = interval[:, 0] + lens = interval[:, 1] - interval[:, 0] + + starts_local_z_buffer = starts - need_start_vec_cpu + + target = z_buffer.narrow(axis, starts_local_z_buffer[0].item(), lens[0].item()).narrow( + axis + 1, starts_local_z_buffer[1].item(), lens[1].item() + ) + target.copy_(buf) + + local_idx_n = idx_n_local - need_start_vec[0].item() + local_idx_m = idx_m_local - need_start_vec[1].item() + + # Use OuterGather instead of RectangularOuterGather + out_local = OuterGather.apply(z_buffer, local_idx_n, local_idx_m, axis, idx_n_mask_local, idx_m_mask_local) + + # 7. Output DTensor + out_global_shape = list(idx_n_dtensor.shape) + H = idx_m_dtensor.shape[-1] + feature_shape = z_dtensor.shape[axis + 2 :] + + final_global_shape = list(out_global_shape) + [H] + list(feature_shape) + strides_out = update_exhaustive_strides(out_local.shape, out_local.stride(), tuple(final_global_shape)) + + out_dtensor = DTensor.from_local( + out_local, idx_n_dtensor.device_mesh, idx_n_placements, shape=tuple(final_global_shape), stride=strides_out + ) + + tensors_to_save = [idx_n_local, idx_m_local] + if idx_n_mask_local is not None: + tensors_to_save.append(idx_n_mask_local) + if idx_m_mask_local is not None: + tensors_to_save.append(idx_m_mask_local) + ctx.save_for_backward(*tensors_to_save) + ctx.has_n_mask = idx_n_mask_local is not None + ctx.has_m_mask = idx_m_mask_local is not None + ctx.comm_meta = { + "recv_metadata_for_bwd": recv_metadata_for_bwd, + "send_metadata_for_bwd": send_metadata_for_bwd, + "z_local_shape": z_local.shape, + "z_buffer_shape": z_buffer.shape, + "need_interval": need_interval, + "axis": axis, + "device_mesh_z": z_dtensor.device_mesh, + "z_placements": z_placements, + "z_global_shape": z_dtensor.shape, + "output_placements": out_dtensor.placements, + "own_interval": own_interval, + "device_mesh_output": out_dtensor.device_mesh, + "mesh_dim_replicate_idx": mesh_dim_replicate_idx, + } + + return out_dtensor + + @staticmethod + def backward(ctx, grad_output: DTensor) -> Tuple[DTensor, None, None, None, None, None, None]: + """Backward pass for DistributedOuterGather. + + Uses forward-phase send/recv metadata (saved in ctx) to drive the P2P + communication plan that redistributes gradient contributions back to the + owning shards of ``z_dtensor``. + + Args: + grad_output (DTensor): Gradient of the output with shape + ``(*batch, K, H, *features)`` and the same device mesh/placements + as the forward output. + + Returns: + Tuple[DTensor, None, None, None, None, None, None]: Gradient for ``z_dtensor`` and + ``None`` for non-differentiable inputs. + """ + saved = ctx.saved_tensors + idx_n_local = saved[0] + idx_m_local = saved[1] + idx = 2 + idx_n_mask_local = saved[idx] if ctx.has_n_mask else None + if ctx.has_n_mask: + idx += 1 + idx_m_mask_local = saved[idx] if ctx.has_m_mask else None + meta = ctx.comm_meta + recv_meta = meta["recv_metadata_for_bwd"] + send_meta = meta["send_metadata_for_bwd"] + z_local_shape = meta["z_local_shape"] + z_buffer_shape = meta["z_buffer_shape"] + need_interval = meta["need_interval"] + need_start_vec = need_interval[:, 0] + axis = meta["axis"] + device_mesh_z = meta["device_mesh_z"] + z_placements = meta["z_placements"] + output_placements = meta["output_placements"] + z_global_shape = meta["z_global_shape"] + own_interval = meta["own_interval"] + device_mesh_output = meta["device_mesh_output"] + mesh_dim_replicate_idx = meta["mesh_dim_replicate_idx"] + + if device_mesh_output != grad_output.device_mesh: + raise ValueError( + f"grad_output device_mesh mismatch: expected same device_mesh from fwd's output: {device_mesh_output}, " + f"but got {grad_output.device_mesh}" + ) + + # TODO: we can actually support grad_output.placements[mesh_dim_replicate_idx] == Partial("sum") + # where then the latter division of grad_z_local by device_mesh.size(mesh_dim_replicate_idx) is + # not more necessary because the grad_output virtually would perform the all-reduce along that mesh axis + if output_placements != grad_output.placements: + raise ValueError( + f"grad_output placements mismatch: expected same placements from fwd's output: {output_placements}, " + f"but got {grad_output.placements}" + ) + + grad_local = grad_output.to_local() + # Ensure contiguous for backward ops + grad_local = grad_local.contiguous() + + local_idx_n = idx_n_local - need_start_vec[0].item() + local_idx_m = idx_m_local - need_start_vec[1].item() + + # Use OuterGather.backward logic (actually uses outer_gather_backward helper) + + grad_z_buffer = outer_gather_backward( + grad_local, z_buffer_shape, local_idx_n, local_idx_m, axis, idx_n_mask_local, idx_m_mask_local + ) + + ops = [] + grad_z_local = torch.zeros(z_local_shape, dtype=grad_local.dtype, device=grad_local.device) + + # 1. Backward Send (Reverse of Fwd Recv) + for peer, interval, shape in recv_meta: + # grad_z_buffer corresponds to the z_buffer in the forward pass + # so we need to convert the interval to the local indices of grad_z_buffer + # by subtracting the need_start_vec, which is the global indices of the z_buffer + need_start_vec_cpu = need_start_vec.to(interval.device) + starts_local_z_buffer = interval[:, 0] - need_start_vec_cpu + lens = interval[:, 1] - interval[:, 0] + + grad_chunk = grad_z_buffer.narrow(axis, starts_local_z_buffer[0].item(), lens[0].item()).narrow( + axis + 1, starts_local_z_buffer[1].item(), lens[1].item() + ) + + if peer == dist.get_rank(): + # Self-accumulate into grad_z_local + # Here we need to the offset the z_local as in the forward pass input + # to copy upstream adjoints into z_local's gradients + starts_local = interval[:, 0] - own_interval[:, 0] + + target = grad_z_local.narrow(axis, starts_local[0].item(), lens[0].item()).narrow( + axis + 1, starts_local[1].item(), lens[1].item() + ) + target.add_(grad_chunk) + else: + ops.append(dist.P2POp(dist.isend, grad_chunk.contiguous(), peer)) + + # 2. Backward Recv (Reverse of Fwd Send) + bwd_recv_bufs = [] + + for peer, starts_local, lens in send_meta: + # send_meta only contains peers != self + shape = list(z_local_shape) + shape[axis : axis + 2] = lens.tolist() + + buf = torch.empty(shape, dtype=grad_local.dtype, device=grad_local.device) + ops.append(dist.P2POp(dist.irecv, buf, peer)) + bwd_recv_bufs.append((buf, starts_local, lens)) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + for buf, starts_local, lens in bwd_recv_bufs: + target = grad_z_local.narrow(axis, starts_local[0].item(), lens[0].item()).narrow( + axis + 1, starts_local[1].item(), lens[1].item() + ) + target.add_(buf) + + if mesh_dim_replicate_idx is not None: + # When grad_output.placements[mesh_dim_replicate_idx] == Replicate(), + # the grad_output contribution among that mesh dimensions are duplicated + # so we need to divide by the number of devices in that mesh dimension to get the + # correct gradient. + grad_z_local = grad_z_local / device_mesh_output.size(mesh_dim_replicate_idx) + + grad_z_dtensor = DTensor.from_local( + grad_z_local, + device_mesh_z, + z_placements, + shape=z_global_shape, + stride=update_exhaustive_strides(grad_z_local.shape, grad_z_local.stride(), z_global_shape), + ) + + return grad_z_dtensor, None, None, None, None, None, None + + +def distributed_outer_gather( + z_dtensor: DTensor, + idx_n_dtensor: DTensor, + idx_m_dtensor: DTensor, + axis: int = 1, + are_ids_contiguous: bool = False, + idx_n_mask: DTensor | None = None, + idx_m_mask: DTensor | None = None, +) -> DTensor: + """Distributed outer gather convenience wrapper. + + "Outer" means taking the Cartesian product of two index sets: for every pair + of indices drawn from ``idx_n_dtensor`` (along the ``N`` axis) and + ``idx_m_dtensor`` (along the ``M`` axis), the corresponding rectangular + block ``z_dtensor[..., n, m, ...]`` is gathered, producing an output of + shape ``(*batch, K, H, *features)`` where ``K`` is the shared leading index + count and ``H`` is the width of the ``m`` index set. + + Args: + z_dtensor (DTensor): Shape ``(*batch, N, M, *features)``. Must be sharded + on the ``(N, M)`` block along two mesh dims corresponding to + ``axis``/``axis+1``; no ``Partial`` placements allowed. + idx_n_dtensor (DTensor): Shape ``(*batch, K, W)``. Shares device mesh and + placements with ``idx_m_dtensor``. Must be sharded on ``-2`` over one + of the two mesh dims that shard ``z_dtensor``'s ``(N, M)`` block. + Device mesh can be the same as ``z_dtensor`` or the flatten over the + two sharding mesh dims. + idx_m_dtensor (DTensor): Shape ``(*batch, K, H)``, same mesh/placements as + ``idx_n_dtensor``. + axis (int, optional): Start axis of the ``(N, M)`` block in ``z_dtensor``. + Defaults to 1. + are_ids_contiguous (bool, optional): This is a heuristic for selecting the underlying + send/recv strategy for performance purpose. Currently only True is supported, + which means that the idx_n and idx_m tensors map to a contiguous block of z + along (axis, axis + 1) dimensions for all the shards and for all the leading + (batch) dimensions. When True, the underlying strategy will use the min/max + of idx_n and idx_m to compute the needed interval, assuming the resulting + buffer to be communicated across the ranks is fully (or approximately so) + utilized. In the case the data inside idx_n and idx_m doesn't mapped to + contiguous blocks, the result will still be correct by setting + are_ids_contiguous=True but the buffer of z chunks communicated will contain + a lot of unused elements, making the sed/recv inefficient. + Defaults to False as a reminder to the user to understand the performance implications. + idx_n_mask (DTensor, optional): Mask for idx_n_dtensor of shape ``(*batch, K, W)``. + Same device_mesh and placements as idx_n_dtensor. True = valid, False = invalid. + idx_m_mask (DTensor, optional): Mask for idx_m_dtensor of shape ``(*batch, K, H)``. + Same device_mesh and placements as idx_m_dtensor. True = valid, False = invalid. + + Returns: + DTensor: Gathered output with shape ``(*batch, K, H, *features)``, using + the mesh/placements of the index tensors. + """ + return DistributedOuterGather.apply( + z_dtensor, idx_n_dtensor, idx_m_dtensor, axis, are_ids_contiguous, idx_n_mask, idx_m_mask + ) diff --git a/src/boltz/distributed/model/layers/outer_op.py b/src/boltz/distributed/model/layers/outer_op.py new file mode 100644 index 000000000..e8a63bca9 --- /dev/null +++ b/src/boltz/distributed/model/layers/outer_op.py @@ -0,0 +1,704 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from enum import Enum, auto +from typing import Optional + +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.layers.redistribute_transpose_without_dtensor import transpose_then_redistribute +from boltz.distributed.utils import LayoutRightMap + + +class OuterOp(Enum): + SUM = auto() + SUBTRACT = auto() + EQUAL = auto() + BITAND = auto() + PROD = auto() + CDIST = auto() # Special case for pairwise distance computation + + +class DistributedOuterOp(torch.autograd.Function): + """Custom autograd function for distributed outer-[add, subtract, equal, bitand] operations. + + The outer operation assumes the input tensor is sharded along axis 0 of the + transpose_comm's 2d group grid and also replicated along axis 1 of the same grid: + + input_expanded + input_expanded.transpose(axis, axis + 1) + return binary_op(input_expanded, input_expanded.transpose(axis, axis + 1)) + + Currently, only add, subtract, equal and bitand operations are supported. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + input_t: torch.Tensor | None, + op: OuterOp, + axis: int, + transpose_comm: TransposeComm, + group_replicate: torch.distributed.ProcessGroup, + ) -> torch.Tensor: + """Forward pass for DistributedOuterOp. + + Args: + ctx: Context object to save information for backward pass + input: Input tensor for outer operation. This is assumed sharded + along axis 0 of the transpose_comm's 2d group grid and also + replicated along axis 1 of the same grid. + input_t: Second input tensor for outer operation to be transposed across 2d group grid. Will use `input` if not provided with None. + op: the binary operation to perform + axis: Axis along which to perform the outer op + transpose_comm: Communication object for distributed operations + group_replicate: Process group for input's replication across ranks + + Returns: + Tensor with outer op computed + + Raises: + ValueError: If ranks are inconsistent with the expected process group configuration + """ + rank_replicate = torch.distributed.get_rank(group_replicate) + rank_global = transpose_comm.global_rank + if rank_replicate < 0: + raise ValueError( + f"global rank {rank_global} doesn't belong to group_replicate as " + f"get_rank(group_replicate) returned {rank_replicate}" + ) + if rank_replicate != transpose_comm.rank_coords[1]: + raise ValueError( + f"global rank {rank_global} is not along the input tensor replicating axis, " + f"which is assumed axis 1 of the transpose_comm's 2d grid, as its rank in the " + f"grid is {transpose_comm.rank_coords[1]} but the group_replicate's rank is {rank_replicate}" + ) + ctx.transpose_comm = transpose_comm + ctx.group_replicate = group_replicate + ctx.axis = axis + ctx.op = op + ctx.is_symmetric = input_t is None + + if input_t is None: + input_t = input + + input_expanded = input.unsqueeze(axis + 1) + input_expanded_t = input_t.unsqueeze(axis + 1) + transposed = transpose_then_redistribute(input_expanded_t, axis, axis + 1, transpose_comm) + if op == OuterOp.SUM: + output = input_expanded + transposed + elif op == OuterOp.SUBTRACT: + output = input_expanded - transposed + elif op == OuterOp.EQUAL: + # boolean output can't be backpropagated but + # if we were to output a float equivalent, we + # save the output as a mask to be used in + # the backward pass + output = input_expanded == transposed + ctx.mark_non_differentiable(output) + elif op == OuterOp.BITAND: + if input_expanded.dtype.is_floating_point or transposed.dtype.is_floating_point: + raise ValueError("input_expanded and transposed must be boolean tensors") + # bitwise AND operation can't be backpropagated + output = input_expanded & transposed + ctx.mark_non_differentiable(output) + else: + raise ValueError(f"Unsupported operation: {op}") + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None, None, None]: + """Backward pass for DistributedOuterOp. + + Args: + ctx: Context object with saved information from forward pass + grad_output: Gradient tensor from downstream layers + + Returns: + Tuple containing the gradient for input tensor and None for other parameters + """ + transpose_comm = ctx.transpose_comm + group_replicate = ctx.group_replicate + axis = ctx.axis + op = ctx.op + if op == OuterOp.EQUAL or op == OuterOp.BITAND: + # If EQUAL op and the forward were to output float instead of bool mask + # then we can use the saved output as mask applied on grad_output here + # e.g., + # if op == OuterOp.EQUAL: + # mask = ctx.saved_tensors + # grad_output = grad_output * mask + # BITAND also produces non-differentiable output + return None, None, None, None, None, None + + # grad on right summand + grad_transposed = grad_output.sum(dim=axis, keepdim=True).transpose(axis, axis + 1).contiguous() + if op == OuterOp.SUBTRACT: + grad_transposed = -grad_transposed + grad_transposed_recv = transpose_comm.enqueue_to_dispatch(grad_transposed) + # grad on left summand, which always retain the positive sign + grad_input_expanded = grad_output.sum(dim=axis + 1, keepdim=True) + transpose_comm.wait_until_finished() + + # perform allreduce to get the row- and column-wise contributions + if ctx.is_symmetric: + grad_input = (grad_input_expanded + grad_transposed_recv).squeeze(dim=axis + 1) + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + return grad_input, None, None, None, None, None + else: + torch.distributed.all_reduce(grad_input_expanded, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + torch.distributed.all_reduce(grad_transposed_recv, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + grad_transposed_recv = grad_transposed_recv.squeeze(dim=axis + 1) + grad_input_expanded = grad_input_expanded.squeeze(dim=axis + 1) + return grad_input_expanded, grad_transposed_recv, None, None, None, None + + +class DistributedCDist(torch.autograd.Function): + """Custom CP autograd function for torch.cdist. + + Currently supports the default computation, which is the L2 norm. + Currently supports self-distance. + """ + + @staticmethod + def forward( + ctx, + input_array: torch.Tensor, + transpose_comm: TransposeComm, + group_replicate: torch.distributed.ProcessGroup, + ) -> torch.Tensor: + """Forward pass for DistributedCDist. + + Args: + ctx: Context object to save information for backward pass + input_array: Input tensor for outer operation. This is assumed sharded + along axis 0 of the transpose_comm's 2d group grid and also + replicated along axis 1 of the same grid. + The input tensors sharding dimension is expected to be (-2). + transpose_comm: Communication object for distributed operations + group_replicate: Process group for input's replication across ranks + + Returns: + Tensor with outer op computed + + Raises: + ValueError: If ranks are inconsistent with the expected process group configuration + """ + rank_replicate = torch.distributed.get_rank(group_replicate) + rank_global = transpose_comm.global_rank + if rank_replicate < 0: + raise ValueError( + f"global rank {rank_global} doesn't belong to group_replicate as " + f"get_rank(group_replicate) returned {rank_replicate}" + ) + ctx.transpose_comm = transpose_comm + ctx.group_replicate = group_replicate + axis = len(input_array.shape) - 2 + ctx.axis = axis + input_expanded = input_array.unsqueeze(axis + 1) + transposed = transpose_then_redistribute(input_expanded, axis, axis + 1, transpose_comm) + # Pairwise distance calculation + # diff is size (B, N, N, D) + diff = transposed - input_expanded + diff_sq = diff * diff + diff_sum = diff_sq.sum(dim=-1) + output = diff_sum.sqrt() + if input_array.requires_grad: + ctx.save_for_backward(diff) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, None, None]: + """Backward pass for DistributedOuterOp. + + Args: + ctx: Context object with saved information from forward pass + grad_output: Gradient tensor from downstream layers + + Returns: + Tuple containing the gradient for input tensor and None for other parameters + """ + if ctx.needs_input_grad[0] is False: + return None, None, None + transpose_comm = ctx.transpose_comm + group_replicate = ctx.group_replicate + (diff,) = ctx.saved_tensors + # Dists is recomputed to save memory. + diff_sq = diff * diff + diff_sum = diff_sq.sum(dim=-1) + dists = diff_sum.sqrt() + # grad is (B, N, N) + # grad transposed is (B, N, N) + grad_transposed = transpose_then_redistribute(grad_output, ctx.axis, ctx.axis + 1, transpose_comm) + dists = dists + 1e-8 + # (B, N, N, D) + diff_over_dists = diff / dists.unsqueeze(-1) + grad_term = diff_over_dists * (grad_output + grad_transposed).unsqueeze(-1) + # (B, N, N, D) to (B, N, D) + grad_input = -grad_term.sum(dim=2) + # Sum reduction across ranks - this is an extension of the above sum. + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + return grad_input, None, None + + +def distributed_cdist( + input_array: torch.Tensor, + transpose_comm: TransposeComm | None = None, + group_replicate: torch.distributed.ProcessGroup | None = None, +): + """Performs cdist operation, sharded if communication objects are passed, serial if not. + + Args: + input_array: Input tensor for outer operation. This is assumed sharded + along axis 0 of the transpose_comm's 2d group grid and also + replicated along axis 1 of the same grid. + The input tensors sharding dimension is expected to be (-2). + transpose_comm: Communication object for distributed operations + group_replicate: Process group for input's replication across ranks + """ + if (transpose_comm is None) != (group_replicate is None): + raise ValueError("transpose_comm and group_replicate must both be provided or both be None") + if transpose_comm is None: + return torch.cdist(input_array, input_array) + return DistributedCDist.apply(input_array, transpose_comm, group_replicate) + + +class _ReplicateToShardOuterOp(torch.autograd.Function): + """DTensor version of DistributedOuterOp for outer-[add, subtract, equal, bitand, cdist] operations with R2S transformation. + + The outer operation assumes the input DTensors have placements of (Shard(0), Shard(1), Replicate()) + and produces output DTensors with placements of (Shard(0), Shard(1), Shard(2)). + + The Replicate() placement in the input is used as an extra buffer for the result's Shard(2) placement, + enabling the R2S (Replicate -> Shard) transformation during the outer operation. + + For most operations: + input_expanded + input_expanded.transpose(axis, axis + 1) + return binary_op(input_expanded, input_expanded.transpose(axis, axis + 1)) + + For CDIST operation: + diff = input_expanded.transpose(axis, axis + 1) - input_expanded + return (diff * diff).sum(dim=-1).sqrt() + + Currently, add, subtract, equal, bitand, and cdist operations are supported. + """ + + @staticmethod + def forward( + ctx, + input: DTensor, + input_t: DTensor | None, + op: OuterOp, + axis: int, + transpose_comm: TransposeComm, + ) -> DTensor: + """Forward pass for _ReplicateToShardOuterOp. + + Args: + ctx: Context object to save information for backward pass + input: Input DTensor for outer operation. Must have placements (Shard(0), Shard(1), Replicate()) + input_t: Second input DTensor for outer operation to be transposed across 2d group grid. + Will use `input` if not provided with None. + op: the binary operation to perform + axis: Axis along which to perform the outer op. Must be one of the Shard placement dimensions (0 or 1). + transpose_comm: Communication object for distributed operations + + Returns: + DTensor with outer op computed + + Raises: + TypeError: If inputs are not DTensors + ValueError: If device meshes don't match, placements are incorrect, axis is invalid, or uneven sharding detected + """ + # Type checking + if not isinstance(input, DTensor): + raise TypeError(f"input must be DTensor, got {type(input)}") + + if input_t is not None and not isinstance(input_t, DTensor): + raise TypeError(f"input_t must be DTensor or None, got {type(input_t)}") + + # Check required placements for input: (Shard(0), Shard(1), Replicate()) + input_placements = (Shard(0), Shard(1), Replicate()) + + if input.placements != input_placements: + raise ValueError(f"input must have placements {input_placements}. Got {input.placements}") + + # Get device mesh from input + device_mesh = input.device_mesh + + # Check that axis is one of the Shard placement dimensions + shard_dims = [0, 1] # From Shard(0) and Shard(1) + if axis not in shard_dims: + raise ValueError(f"axis must be one of the Shard placement dimensions {shard_dims}. Got {axis}") + + # Check for uneven sharding in input tensor + if input.shape[0] % device_mesh.shape[0] != 0: + raise ValueError( + f"Uneven sharding detected: input tensor dimension 0 of size {input.shape[0]} " + f"is not evenly divisible by device mesh dimension 0 of size {device_mesh.shape[0]}" + ) + + if input.shape[1] % device_mesh.shape[1] != 0: + raise ValueError( + f"Uneven sharding detected: input tensor dimension 1 of size {input.shape[1]} " + f"is not evenly divisible by device mesh dimension 1 of size {device_mesh.shape[1]}" + ) + + # Set input_t to input if None (symmetric case) + if input_t is None: + input_t = input + is_symmetric = True + else: + is_symmetric = False + # Check device mesh compatibility + if input_t.device_mesh != input.device_mesh: + raise ValueError( + f"input_t device_mesh mismatch: expected {input.device_mesh}, got {input_t.device_mesh}" + ) + if input_t.placements != input_placements: + raise ValueError(f"input_t must have placements {input_placements}. Got {input_t.placements}") + if input_t.shape != input.shape: + raise ValueError(f"input and input_t must have the same shape. Got {input.shape} and {input_t.shape}") + + # Infer group_replicate from device mesh (axis 2 corresponds to Replicate()) + group_replicate = device_mesh.get_group(2) + + # Validate ranks (adapted from original DistributedOuterOp) + rank_replicate = torch.distributed.get_rank(group_replicate) + rank_global = transpose_comm.global_rank + if rank_replicate < 0: + raise ValueError( + f"global rank {rank_global} doesn't belong to group_replicate as " + f"get_rank(group_replicate) returned {rank_replicate}" + ) + if rank_replicate != transpose_comm.rank_coords[1]: + raise ValueError( + f"global rank {rank_global} is not along the input tensor replicating axis, " + f"which is assumed axis 1 of the transpose_comm's 2d grid, as its rank in the " + f"grid is {transpose_comm.rank_coords[1]} but the group_replicate's rank is {rank_replicate}" + ) + + # Define output placements: (Shard(0), Shard(1), Shard(2)) - R2S transformation + output_placements = (Shard(0), Shard(1), Shard(2)) + + # Save context for backward pass + ctx.device_mesh = device_mesh + ctx.input_placements = input_placements + ctx.output_placements = output_placements + ctx.transpose_comm = transpose_comm + ctx.group_replicate = group_replicate + ctx.axis = axis + ctx.op = op + ctx.is_symmetric = is_symmetric + ctx.input_shape = input.shape + ctx.input_stride = input.stride() + ctx.input_t_shape = input_t.shape + ctx.input_t_stride = input_t.stride() + + # Extract local tensors + input_local = input.to_local() + input_t_local = input_t.to_local() + + # Perform the outer operation computation (adapted from original) + input_expanded = input_local.unsqueeze(axis + 1) + input_expanded_t = input_t_local.unsqueeze(axis + 1) + transposed = transpose_then_redistribute(input_expanded_t, axis, axis + 1, transpose_comm) + + if op == OuterOp.SUM: + output_local = input_expanded + transposed + elif op == OuterOp.SUBTRACT: + output_local = input_expanded - transposed + elif op == OuterOp.EQUAL: + # boolean output can't be backpropagated but + # if we were to output a float equivalent, we + # save the output as a mask to be used in + # the backward pass + output_local = input_expanded == transposed + ctx.mark_non_differentiable(output_local) + elif op == OuterOp.BITAND: + # bitwise AND operation can't be backpropagated + output_local = input_expanded & transposed + ctx.mark_non_differentiable(output_local) + elif op == OuterOp.PROD: + output_local = input_expanded * transposed + # Save operands for backward: d(a*b)/da = b, d(a*b)/db = a + if input.requires_grad: + ctx.save_for_backward(input_expanded, transposed) + elif op == OuterOp.CDIST: + # Pairwise distance calculation: L2 norm of difference + diff_local = input_expanded - transposed + output_local = (diff_local * diff_local).sum(dim=-1).sqrt() + # Save diff for backward pass + if input.requires_grad: + # this is the gradient of the output with respect to the difference + # as a prefactor to be multiplied with the downstream gradient + d_output_d_diff_local = diff_local / (output_local.unsqueeze(-1) + torch.finfo(output_local.dtype).tiny) + ctx.save_for_backward(d_output_d_diff_local) + + # Compute output shape and stride + if op == OuterOp.CDIST: + # CDIST output has reduced the last dimension + shape_output = input.shape[:2] + (input.shape[1],) + else: + # Outer operation adds a dimension: (B, N, D, ...) -> (B, N, N, D, ...) for most ops + shape_output = input.shape[:2] + (input.shape[1],) + input.shape[2:] + + # Use LayoutRightMap for the output shape + layout_right = LayoutRightMap(shape_output) + strides_output = layout_right.strides + + # Convert back to DTensor with output placements (R2S transformation: Replicate -> Shard) + output = DTensor.from_local( + output_local, device_mesh, output_placements, shape=shape_output, stride=strides_output + ) + + return output + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor | None, DTensor | None, None, None, None]: + """Backward pass for _ReplicateToShardOuterOp. + + Args: + ctx: Context object with saved information from forward pass + grad_output: Gradient DTensor from downstream layers + + Returns: + Tuple containing the gradient DTensors for input and input_t, and None for other parameters + """ + # Sanity check grad_output DTensor metadata + if not isinstance(grad_output, DTensor): + raise TypeError(f"grad_output must be DTensor, got {type(grad_output)}") + + if grad_output.device_mesh != ctx.device_mesh: + raise ValueError( + f"grad_output device_mesh mismatch: expected {ctx.device_mesh}, got {grad_output.device_mesh}" + ) + + if grad_output.placements != ctx.output_placements: + raise ValueError( + f"grad_output placements mismatch: expected {ctx.output_placements}, got {grad_output.placements}" + ) + + # Extract local gradient + grad_output_local = grad_output.to_local() + + # Backward pass logic (adapted from original DistributedOuterOp) + transpose_comm = ctx.transpose_comm + group_replicate = ctx.group_replicate + axis = ctx.axis + op = ctx.op + + if op == OuterOp.EQUAL or op == OuterOp.BITAND: + # If EQUAL op and the forward were to output float instead of bool mask + # then we can use the saved output as mask applied on grad_output here + # e.g., + # if op == OuterOp.EQUAL: + # mask = ctx.saved_tensors + # grad_output = grad_output * mask + # BITAND also produces non-differentiable output + return None, None, None, None, None + elif op == OuterOp.PROD: + # d(a*b)/da = b, d(a*b)/db = a — each gradient term uses a different multiplier, + # so we handle the full reduction here rather than falling through to the common path. + (input_expanded_local, transposed_local) = ctx.saved_tensors + # local grad for left operand (input): multiply by transposed, reduce over broadcast dim + grad_input_expanded = (grad_output_local * transposed_local).sum(dim=axis + 1, keepdim=True) + # local grad for right operand (input_t): multiply by input_expanded, reduce then transpose back + grad_transposed = ( + (grad_output_local * input_expanded_local) + .sum(dim=axis, keepdim=True) + .transpose(axis, axis + 1) + .contiguous() + ) + grad_transposed_recv = transpose_comm.enqueue_to_dispatch(grad_transposed) + transpose_comm.wait_until_finished() + + if ctx.is_symmetric: + grad_input_local = (grad_input_expanded + grad_transposed_recv).squeeze(dim=axis + 1) + torch.distributed.all_reduce(grad_input_local, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + grad_input = DTensor.from_local( + grad_input_local, + ctx.device_mesh, + ctx.input_placements, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + return grad_input, None, None, None, None + else: + torch.distributed.all_reduce( + grad_input_expanded, op=torch.distributed.ReduceOp.SUM, group=group_replicate + ) + torch.distributed.all_reduce( + grad_transposed_recv, op=torch.distributed.ReduceOp.SUM, group=group_replicate + ) + grad_input_expanded = grad_input_expanded.squeeze(dim=axis + 1) + grad_transposed_recv = grad_transposed_recv.squeeze(dim=axis + 1) + grad_input = DTensor.from_local( + grad_input_expanded, + ctx.device_mesh, + ctx.input_placements, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + grad_input_t = DTensor.from_local( + grad_transposed_recv, + ctx.device_mesh, + ctx.input_placements, + shape=ctx.input_t_shape, + stride=ctx.input_t_stride, + ) + return grad_input, grad_input_t, None, None, None + elif op == OuterOp.CDIST: + # we multiply the d_output_d_diff_local by the upstream adjoint so that + # the rest of the backward pass is the same as the other symmetric ops + (d_output_d_diff_local,) = ctx.saved_tensors + # grad_output_local is now d_loss_d_diff as in the upstream adjoint of OuterOp.SUBTRACT + grad_output_local = grad_output_local.unsqueeze(-1) * d_output_d_diff_local + + # grad on right summand + grad_transposed = grad_output_local.sum(dim=axis, keepdim=True).transpose(axis, axis + 1).contiguous() + if op == OuterOp.SUBTRACT or op == OuterOp.CDIST: + grad_transposed = -grad_transposed + grad_transposed_recv = transpose_comm.enqueue_to_dispatch(grad_transposed) + # grad on left summand, which always retain the positive sign + grad_input_expanded = grad_output_local.sum(dim=axis + 1, keepdim=True) + transpose_comm.wait_until_finished() + + # perform allreduce to get the row- and column-wise contributions + if ctx.is_symmetric: + grad_input_local = (grad_input_expanded + grad_transposed_recv).squeeze(dim=axis + 1) + torch.distributed.all_reduce(grad_input_local, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + + # Convert gradients back to DTensors + grad_input = DTensor.from_local( + grad_input_local, ctx.device_mesh, ctx.input_placements, shape=ctx.input_shape, stride=ctx.input_stride + ) + return grad_input, None, None, None, None + else: + torch.distributed.all_reduce(grad_input_expanded, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + torch.distributed.all_reduce(grad_transposed_recv, op=torch.distributed.ReduceOp.SUM, group=group_replicate) + grad_transposed_recv = grad_transposed_recv.squeeze(dim=axis + 1) + grad_input_expanded = grad_input_expanded.squeeze(dim=axis + 1) + + # Convert gradients back to DTensors + grad_input = DTensor.from_local( + grad_input_expanded, + ctx.device_mesh, + ctx.input_placements, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + grad_input_t = DTensor.from_local( + grad_transposed_recv, + ctx.device_mesh, + ctx.input_placements, + shape=ctx.input_t_shape, + stride=ctx.input_t_stride, + ) + return grad_input, grad_input_t, None, None, None + + +def distributed_outer_op( + input: torch.Tensor, + op: OuterOp, + axis: int, + input_t: Optional[torch.Tensor] = None, + transpose_comm: Optional[TransposeComm] = None, + group_replicate: Optional[torch.distributed.ProcessGroup] = None, +) -> torch.Tensor: + """Perform an outer op operation with optional distribution across processes. + + This function computes the outer op of a tensor along a specified axis. When + transpose_comm and group_replicate are provided, the operation is performed in + a distributed manner across multiple processes. + + Args: + input: Input tensor for outer op operation. This is assumed sharded + along axis 0 of the transpose_comm's 2d group grid and also + replicated along axis 1 of the same grid. + op: the binary operation to perform + axis: Axis along which to perform the outer op + transpose_comm: Optional communication object for distributed operations + group_replicate: Optional process group for replication across ranks + input_t: Optional second input tensor for outer op operation. Will use `input` if not provided with None. + + Returns: + Tensor with outer op computed + + Raises: + ValueError: If only one of transpose_comm or group_replicate is provided + """ + if (transpose_comm is None) != (group_replicate is None): + raise ValueError("transpose_comm and group_replicate must both be provided or both be None") + if transpose_comm is None and group_replicate is None: + if input_t is None: + input_t = input + input_expanded = input.unsqueeze(axis + 1) + input_expanded_t = input_t.unsqueeze(axis + 1) + if op == OuterOp.SUM: + return input_expanded + input_expanded_t.transpose(axis, axis + 1) + elif op == OuterOp.SUBTRACT: + return input_expanded - input_expanded_t.transpose(axis, axis + 1) + elif op == OuterOp.EQUAL: + return input_expanded == input_expanded_t.transpose(axis, axis + 1) + elif op == OuterOp.BITAND: + return input_expanded & input_expanded_t.transpose(axis, axis + 1) + else: + raise ValueError(f"Unsupported operation: {op}") + else: + return DistributedOuterOp.apply(input, input_t, op, axis, transpose_comm, group_replicate) + + +def replicate_to_shard_outer_op( + input: DTensor, + op: OuterOp, + axis: int, + transpose_comm: TransposeComm, + input_t: Optional[DTensor] = None, +) -> DTensor: + """Perform an outer op operation on DTensors with distributed computation and R2S transformation. + + This function computes the outer op of DTensors along a specified axis with + distributed processing across multiple devices. It performs a Replicate->Shard (R2S) + transformation where the input's Replicate() placement becomes Shard(2) in the output. + + Args: + input: Input DTensor for outer op operation. Must have placements (Shard(0), Shard(1), Replicate()) + op: the binary operation to perform (SUM, SUBTRACT, EQUAL, BITAND, or CDIST) + axis: Axis along which to perform the outer op. Must be 0 or 1 (corresponding to Shard dimensions). + transpose_comm: communication handle for distributed operations + input_t: Optional second input DTensor for outer op operation. Will use `input` if not provided. + Note: For CDIST operation, input_t is ignored and input is used for both operands. + + Returns: + DTensor with outer op computed and placements (Shard(0), Shard(1), Shard(2)) + For CDIST operation, the output shape is (B, N, N) instead of (B, N, N, D) + + Raises: + TypeError: If inputs are not DTensors + ValueError: If placements are incorrect or transpose_comm is None when DTensors are provided + """ + return _ReplicateToShardOuterOp.apply(input, input_t, op, axis, transpose_comm) diff --git a/src/boltz/distributed/model/layers/outer_product_mean.py b/src/boltz/distributed/model/layers/outer_product_mean.py new file mode 100644 index 000000000..a5262bc90 --- /dev/null +++ b/src/boltz/distributed/model/layers/outer_product_mean.py @@ -0,0 +1,597 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.tensor import DTensor, Shard +from torch.distributed.tensor.device_mesh import DeviceMesh + +from boltz.distributed.comm import Ring2DComm +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.utils import LayoutRightMap +from boltz.model.layers.outer_product_mean import OuterProductMean as SerialOuterProductMean + + +class _OuterProductMeanImpl(torch.autograd.Function): + """Distributed implementation of outer product mean using ring communication. + + This autograd function implements a memory-efficient distributed outer product mean + operation across a 2D process grid. The computation is parallelized using ring + communication patterns to reduce memory usage and communication overhead. + + The outer product mean computes: + z[i,j] = mean_s(a[s,i] ⊗ b[s,j]) + + where ⊗ denotes outer product, and the mean is taken over the sequence dimension s. + + Key features: + - Distributed across a 2D grid with sharding on sequence (dim 1) and token (dim 2) dimensions + - Uses ring communication to rotate data chunks during computation + - Memory-efficient implementation that avoids materializing full tensors + - Supports gradient computation through custom backward pass + + Notes + ----- + Input tensors must be DTensors with: + - Shape: (B, N_seq, N_token, c_hidden) for tensors a and b + - Shape: (B, N_seq, N_token, 1) for mask tensor + - Sharding on dimensions 1 and 2 (Shard(1) and Shard(2) placements) + - Identical device mesh and placements across all inputs + + The algorithm uses a ring-based communication pattern where: + - Tensor a is transposed and rotated by row + - Tensor b is rotated by column + - Each process computes partial outer products and accumulates results + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, a: DTensor, b: DTensor, mask: DTensor, ring_comm: Ring2DComm): + """Forward pass of distributed outer product mean computation. + + Computes the outer product mean of input tensors a and b using distributed + ring communication to minimize memory usage and communication overhead. + + a/b sharding is done along the s and N dimensions (dim 1 and 2). + For an i, j rank, the output is the sum of outer products over a[..., i] and b[..., j]. + Example initial layout for a 2x2 grid:: + [[a00, a01], and [[b00, b01], + [a10, a11]] and [b10, b11]] + a is transposed and rotated by row, b is NOT transposed, and is rotated by column. + After transpose, the layout is: + [[a00, a10], + [a01, a11]] + An initial offset is added for each. For a, row i is rotated by i elements, and likewise for the columns of b. + After offset, + [[a00, a10], and [[b00, b11], + [a11, a10]] and [b10, b10]] + Note that (for example) grid element (1, 0) has a[11] and b[10], corresponding to a[.., i] and b[..., ,j], + with matching secondary index of 1. + After 1 rotation + [[a10, a00], and [[b10, b00], + [a01, a11]] and [b00, b11]], + the same (1, 0) index has a01 and b00, with the same i,j match. + + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + a : DTensor + First input tensor with shape (B, N_seq, N_token, c_hidden). + Must be sharded on dimensions 1 and 2. + b : DTensor + Second input tensor with shape (B, N_seq, N_token, c_hidden). + Must have identical shape, device mesh, and placements as tensor a. + mask : DTensor + Mask tensor with shape (B, N_seq, N_token). + Must have same device mesh and placements as input tensors. + ring_comm : Ring2DComm + Ring communication object configured for the distributed computation. + + Returns + ------- + DTensor + Output tensor with shape (B, N_token, N_token, c_hidden*c_hidden). + Contains the distributed outer product mean result. + + Raises + ------ + TypeError + If inputs are not DTensor type. + ValueError + If tensor shapes, device meshes, or placements are incompatible, + or if ring communication setup is inconsistent. + """ + # Check if inputs a and b are of type DTensor + if not isinstance(a, DTensor) or not isinstance(b, DTensor) or not isinstance(mask, DTensor): + raise TypeError( + f"Inputs 'a', 'b', and 'mask' must be of type DTensor. Got types {type(a)}, {type(b)}, and {type(mask)}." + ) + + # Check if inputs a and b have identical device mesh + device_mesh_input = a.device_mesh + if device_mesh_input != b.device_mesh: + raise ValueError( + f"Input tensors 'a' and 'b' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {b.device_mesh}." + ) + if device_mesh_input != mask.device_mesh: + raise ValueError( + f"Input tensor 'mask' must have the same device mesh as the input tensors 'a' and 'b'. " + f"Got device meshes {mask.device_mesh} and {device_mesh_input}." + ) + + # Check if inputs a and b have identical placements + placements_input = a.placements + if placements_input != b.placements: + raise ValueError( + f"Input tensors 'a' and 'b' must have identical placements. " + f"Got placements {placements_input} and {b.placements}." + ) + if placements_input != mask.placements: + # TODO: in the future, if a and b are sharded along the hidden dimension, we need to + # skip the check on the corresponding grid axes + raise ValueError( + f"Input tensor 'mask' must have the same placements as the input tensors 'a' and 'b'. " + f"Got placements {mask.placements} and {placements_input}." + ) + if placements_input != (Shard(0), Shard(1), Shard(2)): + # For debugging, we requires the placements to be (Shard(0), Shard(1), Shard(2)) + # TODO: remove this to only use the previous check + raise ValueError( + f"Input tensors 'a' and 'b''s placements are not (Shard(0), Shard(1), Shard(2)). " + f"Got placements {placements_input}." + ) + + # Check if inputs a and b have the same shape + if a.shape != b.shape: + raise ValueError(f"Input tensors 'a' and 'b' must have the same shape. Got shapes {a.shape} and {b.shape}.") + + if mask.shape[:3] != a.shape[:3]: + raise ValueError( + f"Input tensor 'mask' doesn't have the same size in the 3 leading dimensions " + f"as the input tensors 'a' and 'b': Got shape mask's shape: {mask.shape} vs. a.shape: {a.shape}" + ) + + # to stay off potential DTensor issue, let's not use negative dim axis when we can avoid it + # but this requires we assume the semantics of the axes in order to check the placements: + # a.shape [B, N_seq, N_token, c_hidden] + if a.ndim != 4: + raise ValueError( + f"Input tensors 'a' and 'b' must have 4 dimensions. " + f"Got {a.ndim} dimensions for tensor 'a' and {b.ndim} dimensions for tensor 'b'." + ) + + # Perform consistency check between the ring_comm and the device_mesh_input + # TODO: we could also check the device_mesh_input.mesh tensor, but it would be too expensive as the check + # will likely go through all elements on the mesh + # NOTE: leading batch dimensions may or may not be sharded, as this algorithm operates orthogonally to them. + # 1. Check if Shard(1) and Shard(2) exist in placements_input + i_tensor_dim_to_i_grid_axis = [-1] * a.ndim + for i_grid_axis, placement in enumerate(placements_input): + if isinstance(placement, Shard): + i_tensor_dim_to_i_grid_axis[placement.dim] = i_grid_axis + if i_tensor_dim_to_i_grid_axis[1] == -1 or i_tensor_dim_to_i_grid_axis[2] == -1: + raise ValueError( + f"Input tensors 'a', 'b' and 'mask's dimensions 1 and 2 must be sharded. Got placements {placements_input}." + ) + # 2. Check if ring_comm.group_col match the device_mesh_input's group + # NOTE: ring_comm.group_col is the group sharding the input tensors' axis 1 + if ring_comm.group_col != device_mesh_input.get_group(i_tensor_dim_to_i_grid_axis[1]): + raise ValueError( + "Input ring_comm's group_col process group is not the same as the group sharding the input tensors' axis 1" + ) + # 3. Check if the rank coordinates are consistent + coord_device_mesh_input = device_mesh_input.get_coordinate() + if coord_device_mesh_input is None: + raise ValueError( + f"ring_comm.coord_2d {ring_comm.coord_2d} is not on device_mesh_input {device_mesh_input}." + ) + if ring_comm.coord_2d != ( + coord_device_mesh_input[i_tensor_dim_to_i_grid_axis[1]], + coord_device_mesh_input[i_tensor_dim_to_i_grid_axis[2]], + ): + raise ValueError( + f"Input ring_comm's coord_2d {ring_comm.coord_2d} does not match the " + f"device mesh's rank coordinates {coord_device_mesh_input} for the sharded dimensions " + f"{i_tensor_dim_to_i_grid_axis[1]} and {i_tensor_dim_to_i_grid_axis[2]}." + ) + + ctx.mark_non_differentiable(mask) + mask_local = mask.to_local().unsqueeze(-1) + # DTensor.to_local() returns a view to the shard so we need to clone it to avoid modifying the original DTensor. + a_local = a.to_local() * mask_local + b_local = b.to_local() * mask_local + a_local_copy = a_local.detach().clone() + b_local_copy = b_local.detach().clone() + + # Sum mask count along columns to get divisor for mean. + + B, _, N, c_hidden = a_local.shape + + # send off A transpose + row init, b column init + a_recv = ring_comm.comm_transpose_row_init.enqueue_to_dispatch(a_local) + b_recv = ring_comm.comm_col_init.enqueue_to_dispatch(b_local) + + z_local = torch.zeros((B, N, N, c_hidden, c_hidden), dtype=a_local.dtype, device=a_local.device) + + ring_comm.comm_col_init.wait_until_finished() + ring_comm.comm_transpose_row_init.wait_until_finished() + + num_mask_local = mask_local[:, :, 0].sum(1)[:, None, None] + num_mask_work = dist.all_reduce(num_mask_local, group=ring_comm.group_col, async_op=True) + a_buffer = [a_recv, a_local] + b_buffer = [b_recv, b_local] + i_ready = 0 + i_recv = i_ready ^ 1 + for k_step in range(ring_comm.group_layout.shape[1]): + a_ready = a_buffer[i_ready] + b_ready = b_buffer[i_ready] + if k_step < ring_comm.group_layout.shape[1] - 1: + a_buffer[i_recv] = ring_comm.comm_row.enqueue_to_dispatch(a_ready, a_buffer[i_recv]) + b_buffer[i_recv] = ring_comm.comm_col.enqueue_to_dispatch(b_ready, b_buffer[i_recv]) + z_local = z_local + torch.einsum("bsic,bsjd->bijcd", a_ready, b_ready) + if k_step < ring_comm.group_layout.shape[1] - 1: + ring_comm.comm_row.wait_until_finished() + ring_comm.comm_col.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + num_mask_work.wait() + num_mask_local_clamped = num_mask_local.clamp(min=1) + z_local = z_local.flatten(start_dim=-2) / num_mask_local_clamped + + # Compute output shape and stride + shape_output = (a.shape[0], a.shape[2], a.shape[2], z_local.shape[-1]) # (B, N, N, c_hidden * c_hidden) + + # Use LayoutRightMap for the output shape + layout_right = LayoutRightMap(shape_output) + strides_output = layout_right.strides + + if a.requires_grad or b.requires_grad: + ctx.save_for_backward(a_local_copy, b_local_copy, mask_local.detach().clone(), num_mask_local_clamped) + ctx.ring_comm = ring_comm + ctx.placements_input = placements_input + ctx.device_mesh_input = device_mesh_input + ctx.input_shape_a = a.shape + ctx.input_stride_a = a.stride() + ctx.input_shape_b = b.shape + ctx.input_stride_b = b.stride() + + z = DTensor.from_local( + z_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=shape_output, + stride=strides_output, + ) + return z + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_z: DTensor) -> tuple[DTensor, DTensor, None, None]: + """Backward pass of distributed outer product mean computation. + + Computes gradients with respect to input tensors a and b using the same + ring communication pattern as the forward pass but with transposed operations. + + The gradient computation follows: + - grad_a[s,i] = sum_j(grad_z[i,j] ⊗ b[s,j]) + - grad_b[s,j] = sum_i(grad_z[i,j] ⊗ a[s,i]) + + The sharding of a and b is described above in the forward pass. + z is similarly sharded, but the two sharding dimensions are both (N, N). + For an i, j rank: + Gradient of A is the sum of outer products over grad_z[j, ...] and b[j, ...]. + Gradient of B is the sum of outer products over grad_z[..., j] and a[i, ...]. + For the gradient of a, the approach is identical to that of the forward pass, this time with transposition of + grad_z, rotation of b by row, and grad_z by column. View the above schematic for details. + For the gradient of b, the approach is similar, but with no transposition of grad_z. + Starting with: + [[g00, g01], and [[a00, a01], + [g10, g11]] and [a10, a11]] + After initial rotation: + [[g00, g11], and [[a00, a01], + [g10, g01]] and [a11, a10]] + where index i,j =(0, 1) has g11 and a01 corresponding to grad_z[..., j] and a[i, ...], with matching secondary + index of 1. + The next rotation of grad_z up and b left yields: + [[g10, g01], and [[a01, a00], + [g00, g11]] and [a11, a10]] + where index i,j = (0, 1) has g01 and a00, corresponding to grad_z[..., j] and a[i, ...], with matching + secondary index of 0. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and communication setup. + grad_z : DTensor + Gradient tensor from upstream with shape (B, N_token, N_token, c_hidden*c_hidden). + Must have same device mesh and placements as forward pass inputs. + + Returns + ------- + tuple[DTensor, DTensor, None, None] + Gradients with respect to inputs: + - grad_a: Gradient for tensor a with shape (B, N_seq, N_token, c_hidden) + - grad_b: Gradient for tensor b with shape (B, N_seq, N_token, c_hidden) + - None: No gradient for mask (marked non-differentiable) + - None: No gradient for ring_comm + + Raises + ------ + TypeError + If grad_z is not a DTensor. + ValueError + If grad_z has incompatible shape, device mesh, or placements. + """ + + if not isinstance(grad_z, DTensor): + raise TypeError(f"Input 'grad_z' must be of type DTensor. Got type {type(grad_z)}.") + + if grad_z.ndim != 4: + raise ValueError(f"Input 'grad_z' must have 4 dimensions but got {grad_z.ndim} dimensions") + + if grad_z.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_z' must have the same device mesh as the input tensors 'a' and 'b'. " + f"Got device meshes {grad_z.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_z.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_z' must have the same placements as the input tensors 'a' and 'b'. " + f"Got placements {grad_z.placements} and {ctx.placements_input}." + ) + + a_local, b_local, mask_local, num_mask_local_clamped = ctx.saved_tensors + + # reshape grad_z's last axis to perform GEMV + c_hidden = a_local.shape[-1] + # apply the mask and clone + grad_z_local = grad_z.to_local().unflatten(-1, (c_hidden, c_hidden)) / num_mask_local_clamped.unsqueeze(-1) + + if grad_z_local.shape[:3] != ( + a_local.shape[0], + a_local.shape[2], + a_local.shape[2], + ): + raise ValueError( + f"grad_z shard shape {grad_z_local.shape} does not match expected shape " + f"({a_local.shape[0]}, {a_local.shape[2]}, {a_local.shape[2]}) for outer product mean." + ) + + # Save for reset for B grad compute. + grad_z_save = grad_z_local.clone() + + ring_comm: Ring2DComm = ctx.ring_comm + + # Initialize gradient buffers + grad_a_local = torch.zeros_like(a_local) + + # For grad_a computation: Transpose z grad, rotate z by column and b by row. + grad_z_recv = ring_comm.comm_transpose_col_init.enqueue_to_dispatch(grad_z_local) + b_recv = ring_comm.comm_row_init.enqueue_to_dispatch(b_local) + ring_comm.comm_transpose_col_init.wait_until_finished() + ring_comm.comm_row_init.wait_until_finished() + + b_buffer = [b_recv, b_local] + grad_z_buffer = [grad_z_recv, grad_z_local] + i_ready = 0 + i_recv = i_ready ^ 1 + + # Compute grad_a + for k_step in range(ring_comm.group_layout.shape[1]): + b_ready = b_buffer[i_ready] + grad_z_ready = grad_z_buffer[i_ready] + + if k_step < ring_comm.group_layout.shape[1] - 1: + b_buffer[i_recv] = ring_comm.comm_row.enqueue_to_dispatch(b_ready, b_buffer[i_recv]) + grad_z_buffer[i_recv] = ring_comm.comm_col.enqueue_to_dispatch(grad_z_ready, grad_z_buffer[i_recv]) + + grad_a_local = grad_a_local + torch.einsum("bijcd,bsjd->bsic", grad_z_ready, b_ready) + + if k_step < ring_comm.group_layout.shape[1] - 1: + ring_comm.comm_row.wait_until_finished() + ring_comm.comm_col.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + grad_a_local = grad_a_local * mask_local + grad_a = DTensor.from_local( + grad_a_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.input_shape_a, + stride=ctx.input_stride_a, + ) + + # For grad_b computation: no transpose, rotate z by column, a by row. + # NOTE: we're using the stored original grad_output. + grad_z_recv = ring_comm.comm_col_init.enqueue_to_dispatch(grad_z_save) + a_recv = ring_comm.comm_row_init.enqueue_to_dispatch(a_local) + ring_comm.comm_row_init.wait_until_finished() + ring_comm.comm_col_init.wait_until_finished() + + a_buffer = [a_recv, a_local] + grad_z_buffer = [grad_z_recv, grad_z_local] + i_ready = 0 + i_recv = i_ready ^ 1 + + # Reuse b for grad_b, since we're done with it. + grad_b_local = b_local.view(b_local.shape) + grad_b_local *= 0.0 + # Compute grad_b + for k_step in range(ring_comm.group_layout.shape[1]): + a_ready = a_buffer[i_ready] + grad_z_ready = grad_z_buffer[i_ready] + + if k_step < ring_comm.group_layout.shape[1] - 1: + a_buffer[i_recv] = ring_comm.comm_row.enqueue_to_dispatch(a_ready, a_buffer[i_recv]) + grad_z_buffer[i_recv] = ring_comm.comm_col.enqueue_to_dispatch(grad_z_ready, grad_z_buffer[i_recv]) + + grad_b_local = grad_b_local + torch.einsum("bijcd,bsic->bsjd", grad_z_ready, a_ready) + + if k_step < ring_comm.group_layout.shape[1] - 1: + ring_comm.comm_row.wait_until_finished() + ring_comm.comm_col.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + grad_b_local = grad_b_local * mask_local + grad_b = DTensor.from_local( + grad_b_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.input_shape_b, + stride=ctx.input_stride_b, + ) + + return grad_a, grad_b, None, None + + +class OuterProductMean(nn.Module): + """Distributed outer product mean layer for sequence-to-pair transformations. + + This layer implements a distributed version of the outer product mean operation, + which transforms sequence representations into pairwise representations. It's + commonly used in protein structure prediction and other tasks requiring + sequence-to-pair information propagation. + + The layer performs the following operations: + 1. Layer normalization of input sequences + 2. Linear projections to create two representation streams (a and b) + 3. Distributed outer product mean computation using ring communication + 4. Final linear projection to output dimension + + The outer product mean operation computes: + z[i,j] = mean_s(proj_a(norm(m))[s,i] ⊗ proj_b(norm(m))[s,j]) + + where the mean is taken over the sequence dimension s, masked by the input mask. + + Parameters + ---------- + layer : SerialOuterProductMean + The serial outer product mean layer to convert to distributed version. + Used to initialize projection weights and normalization parameters. + The weights and biases of the input layer will be replicated across the device mesh. + device_mesh : DeviceMesh + The device mesh for distributed computation across multiple GPUs. + comm : Ring2DComm + Ring communication object for efficient distributed outer product computation. + + Attributes + ---------- + c_hidden : int + Hidden dimension size from the projection layers. + c_in : int + Input dimension size. + norm : LayerNormParamsReplicated + Distributed layer normalization. + proj_a : LinearParamsReplicated + First projection layer (input -> hidden). + proj_b : LinearParamsReplicated + Second projection layer (input -> hidden). + proj_o : LinearParamsReplicated + Output projection layer (hidden*hidden -> output). + + Notes + ----- + This implementation requires input tensors to be DTensors with appropriate + sharding patterns. The layer is designed for large-scale distributed training + where memory efficiency is critical. + """ + + def __init__(self, layer: SerialOuterProductMean, device_mesh: DeviceMesh, comm: Ring2DComm) -> None: + """Initialize the distributed outer product mean layer. + + Parameters + ---------- + layer : SerialOuterProductMean + The serial outer product mean layer containing weights to be distributed. + device_mesh : DeviceMesh + Device mesh defining the distributed computation topology. + comm : Ring2DComm + Ring communication handler for distributed outer product operations. + + """ + super().__init__() + self.device_mesh = device_mesh + self.ring_comm = comm + self.c_hidden = layer.c_hidden + self.c_in = layer.proj_a.in_features + self.norm = LayerNormParamsReplicated(layer.norm, self.device_mesh) + self.proj_a = LinearParamsReplicated(layer.proj_a, self.device_mesh) + self.proj_b = LinearParamsReplicated(layer.proj_b, self.device_mesh) + self.proj_o = LinearParamsReplicated(layer.proj_o, self.device_mesh) + + def forward(self, m: DTensor, mask: DTensor) -> DTensor: + """Forward pass of the distributed outer product mean layer. + + Transforms sequence representations into pairwise representations using + the distributed outer product mean operation with masking support. + + Parameters + ---------- + m : DTensor + Input sequence tensor with shape (B, S, N, c_in). + - B: batch size + - S: sequence length + - N: number of tokens + - c_in: input feature dimension + It's expected that the sequence and token dimensions are both sharded + the device mesh, which is further required to be the same as the one + the weights of the layer are placed on + mask : DTensor + Mask tensor with shape (B, S, N) indicating valid positions. + Values should be 1.0 for valid positions and 0.0 for masked positions. + It's expected that the sequence and token dimensions are both sharded + the device mesh, which is further required to be the same as the one + the weights of the layer are placed on + + Returns + ------- + DTensor + Output pairwise tensor with shape (B, N, N, c_out). + Contains pairwise representations between all token pairs. + + Notes + ----- + The computation pipeline is: + 1. Apply layer normalization to input sequences + 2. Project to two hidden representations (a and b) and apply masking + 3. Compute distributed outer product mean: z[i,j] = mean_s(a[s,i] ⊗ b[s,j]) + 4. Apply final projection to get output dimension + """ + # No need to expand mask here because it's handled in _OuterProductMeanImpl + + # Compute projections + m = self.norm(m) + a = self.proj_a(m) + b = self.proj_b(m) + + z = _OuterProductMeanImpl.apply(a, b, mask, self.ring_comm) + z = self.proj_o(z) + return z diff --git a/src/boltz/distributed/model/layers/pair_averaging.py b/src/boltz/distributed/model/layers/pair_averaging.py new file mode 100644 index 000000000..6095d4da0 --- /dev/null +++ b/src/boltz/distributed/model/layers/pair_averaging.py @@ -0,0 +1,726 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from copy import deepcopy + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor, Shard +from torch.distributed.tensor.device_mesh import DeviceMesh + +from boltz.distributed.comm import One2OneComm, TransposeComm, ternary_parity +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.utils import LayoutMap, get_group_rank_from_axial_shift, tiled_softmax_attention_update +from boltz.model.layers.pair_averaging import PairWeightedAveraging as SerialPairWeightedAveraging + + +class Ring2DCommPairAveraging: + """ + Implements communication primitives for distributed operations on a 2D grid of devices. + + This class provides general-purpose ring communication patterns for operations like + TriangleMultiplication and OuterProductMean across a 2D grid of devices. Unlike + Ring2DCommTriAttn which is specialized for triangular attention, this class provides + more general ring communication patterns. + + The communication patterns implemented include: + 1. Transpose communication for matrix operations + 2. Row-wise ring communication (left shifts) + 3. Column-wise ring communication (up shifts) + + Parameters + ---------- + group_2d : dist.ProcessGroup + The process group representing the 2D grid of devices. This should include + all processes in the distributed computation. + group_col : dist.ProcessGroup + A subprocess group that provides communication between ranks in the same column. + group_layout : LayoutMap + A mapping from the 2D grid index to the flattened index of the devices on the 2D grid. + Must represent a square grid (same dimensions in both axes). + + Notes + ----- + The class implements various communication patterns needed for distributed matrix + operations, including initial communication (with different shift patterns based on + coordinates) and subsequent iterations (with fixed shifts). + + Communication ordering is carefully managed to prevent deadlocks by using + ternary_parity to determine consistent send/receive ordering across different ranks. + """ + + def __init__( + self, + group_2d: dist.ProcessGroup, + group_col: dist.ProcessGroup, + group_layout: LayoutMap, + ): + """ + Ring comm over a 2d grid of devices with comm happening along both axes + Arguments: + group_2d: Group torch process group that provides communication + across the full cross-device + group_col: Subprocess group that provides communication + between ranks in the same column + group_layout: mapping from the 2d grid index to the flatten index + of the devices on the 2d grid + """ + # TODO: consolidate the ring 2d comm groups with other modules e,g. triangle attn + self.group_2d = group_2d + self.group_col = group_col + self.group_layout = group_layout + ranks_group_2d = set(dist.get_process_group_ranks(self.group_2d)) + ranks_group_col = set(dist.get_process_group_ranks(self.group_col)) + + if not ranks_group_col.issubset(ranks_group_2d): + raise ValueError("The col ranks are not a subset of ranks_group_2d") + + self.size_2d = dist.get_world_size(self.group_2d) + + if self.size_2d != self.group_layout.numel: + raise ValueError( + f"size of group_2d {self.size_2d} differs from the number of elements in group_layout {self.group_layout.numel}" + ) + + if self.group_layout.shape[0] != self.group_layout.shape[1]: + raise ValueError(f"group_layout.shape {self.group_layout.shape} is not square") + + self.rank_2d = dist.get_rank(self.group_2d) + self.coord_2d = self.group_layout.unravel(self.rank_2d) + + # all the send/recv ranks must be global in order to use isend/irecv + # only need transpose at the beginning of the batched GEMM for b or a + self.comm_2d_trans = TransposeComm(self.group_2d, self.group_layout) + + # always do left shift per row + # for iteration 0, i'th row left shift by i column + self.send_rank_row_init = get_group_rank_from_axial_shift( + self.coord_2d, 1, -self.coord_2d[0], self.group_layout + ) + self.recv_rank_row_init = get_group_rank_from_axial_shift(self.coord_2d, 1, self.coord_2d[0], self.group_layout) + + self.comm_row_init = One2OneComm( + self.group_2d, + self.send_rank_row_init, + self.recv_rank_row_init, + parity=ternary_parity(self.rank_2d, self.send_rank_row_init, self.recv_rank_row_init), + ) + # for other iterations left shift by 1 col + self.send_rank_row = get_group_rank_from_axial_shift(self.coord_2d, 1, -1, self.group_layout) + self.recv_rank_row = get_group_rank_from_axial_shift(self.coord_2d, 1, 1, self.group_layout) + + self.comm_row = One2OneComm( + self.group_2d, + self.send_rank_row, + self.recv_rank_row, + parity=ternary_parity(self.rank_2d, self.send_rank_row, self.recv_rank_row), + ) + + # always do up shift per col + # for iteration 0, j'th col up shift by j row + self.send_rank_col_init = get_group_rank_from_axial_shift( + self.coord_2d, 0, -self.coord_2d[1], self.group_layout + ) + self.recv_rank_col_init = get_group_rank_from_axial_shift(self.coord_2d, 0, self.coord_2d[1], self.group_layout) + self.comm_col_init = One2OneComm( + self.group_2d, + self.send_rank_col_init, + self.recv_rank_col_init, + parity=ternary_parity(self.rank_2d, self.send_rank_col_init, self.recv_rank_col_init), + ) + # for other iterations, up shift by 1 row + self.send_rank_col = get_group_rank_from_axial_shift(self.coord_2d, 0, -1, self.group_layout) + self.recv_rank_col = get_group_rank_from_axial_shift(self.coord_2d, 0, 1, self.group_layout) + self.comm_col = One2OneComm( + self.group_2d, + self.send_rank_col, + self.recv_rank_col, + parity=ternary_parity(self.rank_2d, self.send_rank_col, self.recv_rank_col), + ) + + self.comm_d_init = deepcopy(self.comm_row_init) + self.comm_d = deepcopy(self.comm_row) + + self.comm_db = deepcopy(self.comm_col) + + # down shift j'th col by j row to reset db's data ownership as the input b + self.send_rank_db_final = get_group_rank_from_axial_shift(self.coord_2d, 0, self.coord_2d[1], self.group_layout) + self.recv_rank_db_final = get_group_rank_from_axial_shift( + self.coord_2d, 0, -self.coord_2d[1], self.group_layout + ) + self.comm_db_final = One2OneComm( + self.group_2d, + self.send_rank_db_final, + self.recv_rank_db_final, + parity=ternary_parity(self.rank_2d, self.send_rank_db_final, self.recv_rank_db_final), + ) + + self.comm_2d_trans_lse_m = deepcopy(self.comm_2d_trans) + self.comm_2d_trans_amax = deepcopy(self.comm_2d_trans) + + +class _PairWeightedAveragingImpl(torch.autograd.Function): + """Distributed implementation of pair weighted averaging using ring communication. + + This autograd function implements a memory-efficient distributed pair weighted averaging + operation across a 2D process grid. The computation is parallelized using ring + communication patterns to reduce memory usage and communication overhead. + + The pair weighted averaging operation computes: + o[s,i] = sum_j(softmax(b[i,j]) * v[s,j]) + + where the softmax is taken over the last dimension j, and is masked by the input mask. + + Key features: + - Distributed across a 2D grid with sharding on sequence (dim 1) and token (dim 2) dimensions + - Uses ring communication to rotate data chunks during computation + - Memory-efficient implementation that avoids materializing full tensors + - Supports gradient computation through custom backward pass + + Notes + ----- + Input tensors must be DTensors with: + - Shape: (B, H, S, N, c_h) for tensor v (value) + - Shape: (B, H, N, N) for tensor b (bias/attention weights) + - Shape: (B, N, N) for mask tensor + - Sharding on dimensions 2 and 3 (Shard(2) and Shard(3) placements) + - Identical device mesh and placements across all inputs + + The algorithm uses a ring-based communication pattern where: + - Tensor v is rotated by row + - Tensor b is rotated by column + - Each process computes partial weighted averages and accumulates results + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, v: DTensor, b: DTensor, mask: DTensor, g: DTensor, comm: Ring2DCommPairAveraging, n_heads: int, inf: float + ) -> DTensor: + """Forward pass of distributed pair weighted averaging computation. + + Computes the pair weighted averaging of input tensors v and b using distributed + ring communication to minimize memory usage and communication overhead. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + v : DTensor + Value tensor with shape (B, S, N, h_c_h). + Must be sharded on dimensions 1 and 2. + b : DTensor + Bias/attention weights tensor with shape (B, N, N, n_heads). + Must have compatible device mesh and placements and sharded on dimensions 1 and 2. + mask : DTensor + Mask tensor with shape (B, N, N). + Must have same device mesh and placements as input tensors and sharded on dimensions 1 and 2. + g : DTensor + pre-sigmoid gate tensor with shape (B, S, N, h_c_h). + Must have same device mesh and placements as input tensors and sharded on dimensions 1 and 2. + comm : Ring2DCommPairAveraging + Ring communication object configured for the distributed computation. + n_heads : int + Number of heads. The input tensor v's last dimension is h_c_h = n_heads * c_h where + c_h will be derived by h_c_h // n_heads + + Returns + ------- + DTensor + gated output tensor with shape (B, S, N, h_c_h). + Contains the distributed pair weighted averaging result ready for the output projection. + + Raises + ------ + TypeError + If inputs are not DTensor type. + ValueError + If tensor shapes, device meshes, or placements are incompatible, + or if ring communication setup is inconsistent. + """ + + if not isinstance(comm, Ring2DCommPairAveraging): + raise ValueError(f"Input comm must be of type Ring2DCommPairAveraging. Got type {type(comm)}.") + + # Check if inputs are of type DTensor + if ( + not isinstance(v, DTensor) + or not isinstance(b, DTensor) + or not isinstance(mask, DTensor) + or not isinstance(g, DTensor) + ): + raise TypeError( + f"Inputs 'v', 'b', 'mask' and 'g' must be of type DTensor. Got types {type(v)}, {type(b)}, {type(mask)}, and {type(g)}." + ) + + # Check if inputs have identical device mesh + device_mesh_input = v.device_mesh + if device_mesh_input != b.device_mesh: + raise ValueError( + f"Input tensors 'v' and 'b' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {b.device_mesh}." + ) + if device_mesh_input != mask.device_mesh: + raise ValueError( + f"Input tensor 'mask' must have the same device mesh as the input tensors 'v' and 'b'. " + f"Got device meshes {mask.device_mesh} and {device_mesh_input}." + ) + if device_mesh_input != g.device_mesh: + raise ValueError( + f"Input tensor 'g' must have the same device mesh as the input tensors 'v' and 'b'. " + f"Got device meshes {g.device_mesh} and {device_mesh_input}." + ) + + # Check if inputs have compatible placements + placements_input = v.placements + if placements_input != b.placements: + raise ValueError( + f"Input tensors 'v' and 'b' must have identical placements. " + f"Got placements {placements_input} and {b.placements}." + ) + if placements_input != mask.placements: + raise ValueError( + f"Input tensor 'mask' must have the same placements as the input tensors 'v' and 'b'. " + f"Got placements {mask.placements} and {placements_input}." + ) + if placements_input != g.placements: + raise ValueError( + f"Input tensor 'g' must have the same placements as the input tensors 'v' and 'b'. " + f"Got placements {g.placements} and {placements_input}." + ) + if placements_input != (Shard(0), Shard(1), Shard(2)): + # For debugging, we requires the placements to be (Shard(0), Shard(1), Shard(2)) + # TODO: remove this to only use the previous check + raise ValueError( + f"Input tensors 'v', 'b' and 'mask's placements are not (Shard(0), Shard(1), Shard(2)). " + f"Got placements {placements_input}." + ) + + # Check tensor dimensions + if v.ndim != 4: + raise ValueError(f"Input tensor 'v' must have 4 dimensions. Got {v.ndim} dimensions.") + if b.ndim != 4: + raise ValueError(f"Input tensor 'b' must have 4 dimensions. Got {b.ndim} dimensions.") + if mask.ndim != 3: + raise ValueError(f"Input tensor 'mask' must have 3 dimensions. Got {mask.ndim} dimensions.") + if g.ndim != 4: + raise ValueError(f"Input tensor 'g' must have 4 dimensions. Got {g.ndim} dimensions.") + + v_local = v.to_local() + b_local = b.to_local() + mask_local = mask.to_local() + g_local = g.to_local() + + # Check shape compatibility + B, S, N, h_c_h = v_local.shape + if h_c_h % n_heads != 0: + raise ValueError( + f"Input tensor 'v' must have a last dimension divisible by n_heads {n_heads}. Got {h_c_h}." + ) + c_h = h_c_h // n_heads + if b_local.shape != (B, N, N, n_heads): + raise ValueError( + f"Input tensor 'b' must have shape ({B}, {N}, {N}, {n_heads}). Got shape {b.to_local().shape}." + ) + if mask_local.shape != (B, N, N): + raise ValueError(f"Input tensor 'mask' must have shape ({B}, {N}, {N}). Got shape {mask.to_local().shape}.") + if g_local.shape != (B, S, N, h_c_h): + raise ValueError( + f"Input tensor 'g' must have shape ({B}, {S}, {N}, {h_c_h}). Got shape {g.to_local().shape}." + ) + + # Perform consistency check between the comm and the device_mesh_input + # Check if Shard(1) and Shard(2) exist in placements_input (for S and N dimensions) + i_tensor_dim_to_i_grid_axis = [-1] * v.ndim + for i_grid_axis, placement in enumerate(placements_input): + if isinstance(placement, Shard): + i_tensor_dim_to_i_grid_axis[placement.dim] = i_grid_axis + if i_tensor_dim_to_i_grid_axis[1] == -1 or i_tensor_dim_to_i_grid_axis[2] == -1: + raise ValueError( + f"Input tensors 'v', 'b' and 'mask's dimensions 1 and 2 must be sharded." + f"Got placements {placements_input}." + ) + + # Check if ring_comm.group_col match the device_mesh_input's group + if comm.group_col != device_mesh_input.get_group(i_tensor_dim_to_i_grid_axis[1]): + raise ValueError( + "Input comm's group_col process group is not the same as the group sharding the input tensors' axis 1" + ) + + # Check if the rank coordinates are consistent + coord_device_mesh_input = device_mesh_input.get_coordinate() + if coord_device_mesh_input is None: + raise ValueError(f"comm.coord_2d {comm.coord_2d} is not on device_mesh_input {device_mesh_input}.") + if comm.coord_2d != ( + coord_device_mesh_input[i_tensor_dim_to_i_grid_axis[1]], + coord_device_mesh_input[i_tensor_dim_to_i_grid_axis[2]], + ): + raise ValueError( + f"Input comm's coord_2d {comm.coord_2d} does not match the " + f"device mesh's rank coordinates {coord_device_mesh_input} for the sharded dimensions " + f"{i_tensor_dim_to_i_grid_axis[1]} and {i_tensor_dim_to_i_grid_axis[2]}." + ) + + requires_grad = v.requires_grad or b.requires_grad + ctx.mark_non_differentiable(mask) + + # # Save device mesh and placements for backward pass + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.input_shape = v.shape + ctx.input_stride = v.stride() + ctx.g_shape = g.shape + ctx.g_stride = g.stride() + ctx.b_shape = b.shape + ctx.b_stride = b.stride() + ctx.comm = comm + ctx.n_heads = n_heads + ctx.c_h = c_h + + # reshape v_local from (B, S, N, n_heads * c_h) to (B, n_heads, S, N, c_h) + # TODO: reshape v_local to (B, n_heads, c_h, S, N) to be more GEMM friendly + v_local = ( + v_local.unflatten(-1, (n_heads, c_h)).permute(0, 3, 1, 2, 4).clone(memory_format=torch.contiguous_format) + ) + + if requires_grad: + v_local_copy = v_local.detach().clone() + else: + v_local_copy = None + + v_local_recv = comm.comm_row_init.enqueue_to_dispatch(v_local) + + # convert mask_local to bias and reshape it to (B, N0, N1, 1) + mask_bias_local = 1 - mask_local + mask_bias_local *= -inf + mask_bias_local = mask_bias_local.unsqueeze(-1) + + # apply mask_bias to b and reshape b_local from (B, N0, N1, n_heads) to (B, n_heads, N1, N0) + b_local = b_local + mask_bias_local + bT_local = b_local.permute(0, 3, 2, 1).contiguous() + + bT_local_recv = comm.comm_2d_trans.enqueue_to_dispatch(bT_local) + + g_local = g_local.sigmoid() + + comm.comm_2d_trans.wait_until_finished() + + bT_local = comm.comm_col_init.enqueue_to_dispatch(bT_local_recv, bT_local) + + # cumulative amax until the current block + amax = None + # cumulative lse_m + lse_m = None + + # bT_local is ready + comm.comm_col_init.wait_until_finished() + + # v_local_recv is ready + comm.comm_row_init.wait_until_finished() + + bT_local_buffer = [bT_local, bT_local_recv] + v_local_buffer = [v_local_recv, v_local] + i_ready = 0 + i_recv = i_ready ^ 1 + + # receive other tensor blocks + num_steps = comm.group_layout.shape[0] + o_local = None + for step in range(num_steps): + there_is_another_step = (step + 1) < num_steps + if there_is_another_step: + v_local_buffer[i_recv] = comm.comm_row.enqueue_to_dispatch( + v_local_buffer[i_ready], v_local_buffer[i_recv] + ) + bT_local_buffer[i_recv] = comm.comm_col.enqueue_to_dispatch( + bT_local_buffer[i_ready], bT_local_buffer[i_recv] + ) + + # (B, n_heads, 1, N) + amax_block = bT_local_buffer[i_ready].amax(dim=-2, keepdim=True) + lse_m_block = torch.logsumexp(bT_local_buffer[i_ready] - amax_block, dim=-2, keepdim=True) + + p = bT_local_buffer[i_ready].softmax(dim=-2) + + o_block = torch.einsum("bhsjd,bhji->bhsid", v_local_buffer[i_ready], p) + + # reshape o_block from (B, n_heads, S, N, c_h) to (B, n_heads, N, S * c_h) + # reshape lse_m_block from (B, n_heads, 1, N) to (B, n_heads, N, 1) + # reshape amax_block from (B, n_heads, 1, N) to (B, n_heads, N, 1) + o_local, lse_m, amax = tiled_softmax_attention_update( + o_block.transpose(-3, -2).flatten(start_dim=-2), + lse_m_block.transpose(-2, -1), + amax_block.transpose(-2, -1), + o_local, + lse_m, + amax, + ) + + if there_is_another_step: + comm.comm_row.wait_until_finished() + comm.comm_col.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + # reshape o_local from (B, n_heads, N, S *c_h) to (B, S, N, n_heads * c_h) + o_local = o_local.unflatten(-1, (S, c_h)).permute(0, 3, 2, 1, 4).flatten(start_dim=-2) + + o_local_copy = o_local.detach().clone(memory_format=torch.contiguous_format) + + if requires_grad: + # transpose lse_m and amax + amax_recv = comm.comm_2d_trans_amax.enqueue_to_dispatch(amax) + # normalize b_local to be post-softmax matrix + # b_local is of shape: (B, n_heads, N0, N1) + b_local = b_local.permute(0, 3, 1, 2).contiguous() + lse_m_recv = comm.comm_2d_trans_lse_m.enqueue_to_dispatch(lse_m) + # subtract amax from b_local first + # as amax tends to store extreme values from b_local masked + # lse_m and amax of shape: (B, n_heads, N0, 1) + # transpose lse across the grid to match b_local's placements + # b_local = torch.exp(b_local - amax - lse_m) + comm.comm_2d_trans_amax.wait_until_finished() + b_local -= amax_recv + comm.comm_2d_trans_lse_m.wait_until_finished() + b_local -= lse_m_recv + b_local.exp_() + ctx.save_for_backward(v_local_copy, b_local, g_local, o_local_copy) + + o_local *= g_local + + o = DTensor.from_local( + o_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + + return o + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, do: DTensor) -> tuple[DTensor, DTensor, None, None, None, None, None]: + if not isinstance(do, DTensor): + raise TypeError(f"Input 'do' must be of type DTensor. Got type {type(do)}.") + + if do.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'do' must have the same device mesh as the input tensors. " + f"Got device meshes {do.device_mesh} and {ctx.device_mesh_input}." + ) + + if do.placements != ctx.placements_input: + raise ValueError( + f"Input 'do' must have the same placements as the input tensors. " + f"Got placements {do.placements} and {ctx.placements_input}." + ) + + v_local, p_local, g_local, o_local = ctx.saved_tensors + comm = ctx.comm + num_steps = comm.group_layout.shape[0] + + S = do.to_local().shape[1] + + p_local_recv = comm.comm_col_init.enqueue_to_dispatch(p_local) + + # do.to_local() is of shape (B, S, N, n_heads * c_h) + # g_local is of shape (B, S, N, n_heads * c_h) + # cast do to the same dtype as g_local to avoid type promotion to FP32 if do is FP32 + # which can cause NCCL hang due to size mismatch in P2P communication + do_local = do.to_local().to(dtype=g_local.dtype) * g_local + dsigmoid = 1 - g_local + + # input o_local is of shape (B, S, N, n_heads * c_h) + dsigmoid *= do_local + dsigmoid *= o_local + + dg = DTensor.from_local( + dsigmoid, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.g_shape, + stride=ctx.g_stride, + ) + + # reshape do from (B, S, N, n_heads * c_h) to (B, n_heads, S * c_h, N) + do_local = ( + do_local.unflatten(-1, (ctx.n_heads, ctx.c_h)) # (B, S, N, n_heads, c_h) + .permute(0, 3, 1, 4, 2) # (B, n_heads, S, c_h, N) + .flatten(start_dim=-3, end_dim=-2) # (B, n_heads, S * c_h, N) + .contiguous() + ) + + do_local_recv = comm.comm_row_init.enqueue_to_dispatch(do_local) + + # reshape o_local from (B, S, N, n_heads * c_h) to (B, n_heads, S * c_h, N) + o_local = ( + o_local.unflatten(-1, (ctx.n_heads, ctx.c_h)) # (B, S, N, n_heads, c_h) + .permute(0, 3, 1, 4, 2) # (B, n_heads, S, c_h, N) + .flatten(start_dim=-3, end_dim=-2) # (B, n_heads, S * c_h, N) + .contiguous() + ) + + d = torch.einsum("bhti,bhti->bhi", do_local, o_local).contiguous() + + d_recv = comm.comm_d_init.enqueue_to_dispatch(d) + + # reshape v_local from (B, n_heads, S, N, c_h) to (B, n_heads, S * c_h, N) + v_local = v_local.transpose(-2, -1).flatten(start_dim=-3, end_dim=-2).contiguous() + + comm.comm_col_init.wait_until_finished() # p_local_recv is ready + comm.comm_row_init.wait_until_finished() # do_local_recv is ready + comm.comm_d_init.wait_until_finished() # d_recv is ready + + i_ready = 0 + i_recv = i_ready ^ 1 + + p_local_buffer = [p_local_recv, p_local] + do_local_buffer = [do_local_recv, do_local] + d_buffer = [d_recv, d] + db_local_buffer = [torch.zeros_like(p_local), torch.zeros_like(p_local)] + + dv_local = torch.zeros_like(o_local) + + for step in range(num_steps): + there_is_another_step = step < num_steps - 1 + if there_is_another_step: + p_local_buffer[i_recv] = comm.comm_col.enqueue_to_dispatch( + p_local_buffer[i_ready], p_local_buffer[i_recv] + ) + do_local_buffer[i_recv] = comm.comm_row.enqueue_to_dispatch( + do_local_buffer[i_ready], do_local_buffer[i_recv] + ) + d_buffer[i_recv] = comm.comm_d.enqueue_to_dispatch(d_buffer[i_ready], d_buffer[i_recv]) + + dv_block = torch.einsum("bhti,bhij->bhtj", do_local_buffer[i_ready], p_local_buffer[i_ready]) + dv_local += dv_block + + # dp + db_local_block = torch.einsum("bhti,bhtj->bhij", do_local_buffer[i_ready], v_local).contiguous() + # dp - d + db_local_block -= d_buffer[i_ready].unsqueeze(-1) + # p * (dp - d) + db_local_block *= p_local_buffer[i_ready] + # virtual all-reduce db_local + if step > 0: + comm.comm_db.wait_until_finished() + db_local_buffer[i_ready] += db_local_block + # db send/recv will carry through num_steps + # explicitly cast to the recv buffer's dtype to prevent NCCL hang due to potential dtype mismatch + # if db_local_buffer[i_ready] was promoted to FP32 during accumulation + db_local_buffer[i_recv] = comm.comm_db.enqueue_to_dispatch( + db_local_buffer[i_ready].to(dtype=db_local_buffer[i_recv].dtype), db_local_buffer[i_recv] + ) + + if there_is_another_step: + comm.comm_col.wait_until_finished() + comm.comm_row.wait_until_finished() + comm.comm_d.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + # reshape dv_local from (B, n_heads, S * c_h, N) to (B, S, N, n_heads * c_h) + dv_local = dv_local.unflatten(-2, (S, ctx.c_h)).permute(0, 2, 4, 1, 3).flatten(start_dim=-2) + + dv = DTensor.from_local( + dv_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + + # reshape db_local from (B, n_heads, N0, N1) to (B, N0, N1, n_heads) + # the last comm of db is ready + comm.comm_db.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + # revert the comm_col_init to get the db data ownership as the input b + db_local_buffer[i_recv] = comm.comm_db_final.enqueue_to_dispatch( + db_local_buffer[i_ready], db_local_buffer[i_recv] + ) + + comm.comm_db_final.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + # reshape db_local from (B, n_heads, N0, N1) to (B, N0, N1, n_heads) + db_local = db_local_buffer[i_ready].permute(0, 2, 3, 1) + + db = DTensor.from_local( + db_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.b_shape, + stride=ctx.b_stride, + ) + + return dv, db, None, dg, None, None, None + + +class PairWeightedAveraging(torch.nn.Module): + """Pair weighted averaging layer.""" + + def __init__( + self, + layer: SerialPairWeightedAveraging, + device_mesh: DeviceMesh, + comm: Ring2DCommPairAveraging, + ) -> None: + super().__init__() + self.comm = comm + self.c_m = layer.c_m + self.c_z = layer.c_z + self.c_h = layer.c_h + self.num_heads = layer.num_heads + self.inf = layer.inf + + self.device_mesh = device_mesh + self.comm = comm + + self.norm_m = LayerNormParamsReplicated(layer.norm_m, self.device_mesh) + self.norm_z = LayerNormParamsReplicated(layer.norm_z, self.device_mesh) + + self.proj_m = LinearParamsReplicated(layer.proj_m, self.device_mesh) + self.proj_g = LinearParamsReplicated(layer.proj_g, self.device_mesh) + self.proj_z = LinearParamsReplicated(layer.proj_z, self.device_mesh) + self.proj_o = LinearParamsReplicated(layer.proj_o, self.device_mesh) + + def forward(self, m: DTensor, z: DTensor, mask: DTensor) -> DTensor: + # Compute layer norms + m = self.norm_m(m) + z = self.norm_z(z) + + g = self.proj_g(m) + + # TODO: fuse the m -> {v, g} projection in one kernel + v = self.proj_m(m) + + b = self.proj_z(z) + + o = _PairWeightedAveragingImpl.apply(v, b, mask, g, self.comm, self.num_heads, self.inf) + + o = self.proj_o(o) + return o diff --git a/src/boltz/distributed/model/layers/pairformer.py b/src/boltz/distributed/model/layers/pairformer.py new file mode 100644 index 000000000..3083beb66 --- /dev/null +++ b/src/boltz/distributed/model/layers/pairformer.py @@ -0,0 +1,350 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.distributed.tensor import DTensor +from torch.utils.checkpoint import checkpoint + +from boltz.distributed.comm import AttentionPairBiasComm, Ring2DComm, Ring2DCommTriAttn +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.attention import AttentionPairBias +from boltz.distributed.model.layers.dropout import apply_dropout_mask_msa_or_pair +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.transition import Transition +from boltz.distributed.model.layers.triangular_attention import ( + TriangleAttentionEndingNode, + TriangleAttentionStartingNode, +) +from boltz.distributed.model.layers.triangular_mult import ( + TriangleMultiplicationIncoming, + TriangleMultiplicationOutgoing, +) +from boltz.distributed.model.modules.utils import TriAttnBackend, get_cpu_offload_context +from boltz.model.layers.pairformer import ( + PairformerLayer as SerialPairformerLayer, +) +from boltz.model.layers.pairformer import ( + PairformerModule as SerialPairformerModule, +) +from boltz.model.layers.pairformer import ( + PairformerNoSeqLayer as SerialPairformerNoSeqLayer, +) +from boltz.model.layers.pairformer import ( + PairformerNoSeqModule as SerialPairformerNoSeqModule, +) + + +class PairformerLayer(nn.Module): + """Distributed PairformerLayer using DTensor (V2: attention with k_in). + When self.no_seq=True, only the pairwise stack is applied (no sequence track). + """ + + def __init__( + self, + layer: Union[SerialPairformerLayer, SerialPairformerNoSeqLayer], + dist_manager: DistributedManager, + ) -> None: + """Initialize the distributed PairformerLayer module. + + Parameters + ---------- + layer : SerialPairformerLayer or SerialPairformerNoSeqLayer + The serial Pairformer layer to be distributed. Passing a + ``SerialPairformerNoSeqLayer`` activates pair-only mode automatically. + dist_manager : DistributedManager + Distributed manager defining the distributed computation topology and groups. + """ + if not isinstance(layer, (SerialPairformerLayer, SerialPairformerNoSeqLayer)): + raise TypeError( + f"layer must be SerialPairformerLayer or SerialPairformerNoSeqLayer, got {type(layer).__name__}" + ) + super().__init__() + self.dist_manager = dist_manager + self.device_mesh = dist_manager.device_mesh_subgroups + self.no_seq = isinstance(layer, SerialPairformerNoSeqLayer) + + self.token_z = layer.token_z + self.dropout = layer.dropout + self.post_layer_norm = layer.post_layer_norm + + # Mutable backend selection for triangle attention layers. + # Default is REFERENCE (no fused kernels). To switch backend for the + # entire model, use ``model.apply(SetTriAttnBackend(backend))`` + # (see boltz.distributed.model.modules.utils.SetTriAttnBackend). + self.triattn_backend = TriAttnBackend.REFERENCE + + # Ring comms for triangular layers + ring_comm_2d_trimul_outgoing = Ring2DComm( + self.dist_manager.group["cp"], + self.dist_manager.subgroups["cp"][0], + self.dist_manager.layout_subgroups["cp"], + ) + ring_comm_2d_trimul_incoming = Ring2DComm( + self.dist_manager.group["cp"], + self.dist_manager.subgroups["cp"][0], + self.dist_manager.layout_subgroups["cp"], + ) + ring_comm_tri_attn_start = Ring2DCommTriAttn( + self.dist_manager.group["cp"], + self.dist_manager.layout_subgroups["cp"], + 1, + ) + ring_comm_tri_attn_end = Ring2DCommTriAttn( + self.dist_manager.group["cp"], + self.dist_manager.layout_subgroups["cp"], + 0, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + layer.tri_mul_out, self.device_mesh, ring_comm_2d_trimul_outgoing + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + layer.tri_mul_in, self.device_mesh, ring_comm_2d_trimul_incoming + ) + self.tri_att_start = TriangleAttentionStartingNode( + layer.tri_att_start, self.device_mesh, ring_comm_tri_attn_start + ) + self.tri_att_end = TriangleAttentionEndingNode(layer.tri_att_end, self.device_mesh, ring_comm_tri_attn_end) + self.transition_z = Transition(layer.transition_z, self.device_mesh) + + if not self.no_seq: + self.num_heads = layer.num_heads + # Detect V1 vs V2 from the serial attention module. + # V1 AttentionPairBias uses initial_norm=True which creates norm_s + # (see src/boltz/model/layers/attention.py AttentionPairBias.__init__). + # V2 does not have norm_s. Prefer the explicit .v2 attribute when + # set by the serial PairformerLayer; otherwise infer from norm_s. + if hasattr(layer, "v2"): + self.v2 = layer.v2 + else: + self.v2 = not hasattr(layer.attention, "norm_s") + self.pre_norm_s = LayerNormParamsReplicated(layer.pre_norm_s, self.device_mesh) + attention_pair_bias_comm = AttentionPairBiasComm( + self.dist_manager.group["cp"], + self.dist_manager.layout_subgroups["cp"], + self.dist_manager.subgroups["cp"][0], + self.dist_manager.subgroups["cp"][1], + ) + if self.v2: + self.attention = AttentionPairBias( + layer.attention, + self.device_mesh, + attention_pair_bias_comm, + apply_initial_norm=False, + compute_pair_bias=True, + use_model_cache=False, + ) + else: + self.attention = AttentionPairBias( + layer.attention, + self.device_mesh, + attention_pair_bias_comm, + apply_initial_norm=True, + compute_pair_bias=True, + use_model_cache=True, + ) + self.transition_s = Transition(layer.transition_s, self.device_mesh) + if self.post_layer_norm: + self.s_post_norm = LayerNormParamsReplicated(layer.s_post_norm, self.device_mesh) + else: + self.s_post_norm = None + + def forward( + self, + s: Optional[DTensor] = None, + z: Optional[DTensor] = None, + mask: Optional[DTensor] = None, + pair_mask: Optional[DTensor] = None, + ) -> Union[Tuple[DTensor, DTensor], DTensor]: + """Forward pass. Pass s= and mask= for the full pairformer; omit them for pair-only.""" + assert z is not None and pair_mask is not None + z = elementwise_op( + z, + apply_dropout_mask_msa_or_pair( + self.tri_mul_out(z, mask=pair_mask), + self.dropout, + self.training, + ), + ElementwiseOp.SUM, + ) + z = elementwise_op( + z, + apply_dropout_mask_msa_or_pair( + self.tri_mul_in(z, mask=pair_mask), + self.dropout, + self.training, + ), + ElementwiseOp.SUM, + ) + z = elementwise_op( + z, + apply_dropout_mask_msa_or_pair( + self.tri_att_start(z, mask=pair_mask, triattn_backend=self.triattn_backend), + self.dropout, + self.training, + ), + ElementwiseOp.SUM, + ) + z = elementwise_op( + z, + apply_dropout_mask_msa_or_pair( + self.tri_att_end(z, mask=pair_mask, triattn_backend=self.triattn_backend), + self.dropout, + self.training, + columnwise=True, + ), + ElementwiseOp.SUM, + ) + z = elementwise_op(z, self.transition_z(z), ElementwiseOp.SUM) + if s is None: + return z + assert mask is not None + with torch.autocast("cuda", enabled=False): + safe_dtype = torch.promote_types(s.dtype, torch.float32) + s = s.to(dtype=safe_dtype) + s_normed = self.pre_norm_s(s) + s = elementwise_op( + s, + self.attention(s=s_normed, z=z.to(dtype=safe_dtype), mask=mask.to(dtype=safe_dtype), k_in=s_normed), + ElementwiseOp.SUM, + ) + s = elementwise_op(s, self.transition_s(s), ElementwiseOp.SUM) + s = self.s_post_norm(s) if self.s_post_norm is not None else s + return s, z + + +class PairformerModule(nn.Module): + """Distributed PairformerModule using DTensor (V2). + + Handles both full pairformer (sequence + pairwise stacks) and pair-only + mode. The mode is inferred from the serial module type: passing a + ``SerialPairformerNoSeqModule`` activates pair-only mode automatically. + """ + + def __init__( + self, + module: Union[SerialPairformerModule, SerialPairformerNoSeqModule], + dist_manager: DistributedManager, + cpu_offloading: bool = False, + ) -> None: + """Initialize the distributed PairformerModule module. + + Parameters + ---------- + module : SerialPairformerModule or SerialPairformerNoSeqModule + The serial Pairformer module containing weights and configuration to be distributed. + dist_manager : DistributedManager + Distributed manager defining the distributed computation topology and groups. + cpu_offloading : bool, optional + Whether to offload checkpoint-boundary activations to CPU when + activation checkpointing is enabled. This is a distributed-only + option (the serial Boltz-2 MSAModule does not support it). + Defaults to False. + """ + super().__init__() + self.dist_manager = dist_manager + self.device_mesh = dist_manager.device_mesh_subgroups + + no_seq = isinstance(module, SerialPairformerNoSeqModule) + self.no_seq = no_seq + + self.token_z = module.token_z + self.num_blocks = module.num_blocks + self.dropout = module.dropout + self.post_layer_norm = module.post_layer_norm + self.activation_checkpointing = module.activation_checkpointing + self.cpu_offloading = cpu_offloading + + if not no_seq: + self.num_heads = module.num_heads + + self.layers = nn.ModuleList() + for serial_layer in module.layers: + self.layers.append(PairformerLayer(serial_layer, dist_manager)) + + def forward( + self, + s: Optional[DTensor] = None, + z: Optional[DTensor] = None, + mask: Optional[DTensor] = None, + pair_mask: Optional[DTensor] = None, + ) -> Union[Tuple[DTensor, DTensor], DTensor]: + """Forward pass. Pass s= and mask= for the full pairformer; omit them for pair-only.""" + if self.activation_checkpointing and self.training: + if self.cpu_offloading: + with get_cpu_offload_context(optimized=True): + for layer in self.layers: + result = checkpoint( + layer, + s, + z, + mask, + pair_mask, + use_reentrant=False, + ) + if self.no_seq: + z = result + else: + s, z = result + else: + for layer in self.layers: + result = checkpoint( + layer, + s, + z, + mask, + pair_mask, + use_reentrant=False, + ) + if self.no_seq: + z = result + else: + s, z = result + else: + for layer in self.layers: + result = layer( + s=s, + z=z, + mask=mask, + pair_mask=pair_mask, + ) + if self.no_seq: + z = result + else: + s, z = result + return z if self.no_seq else (s, z) + + +class PairformerNoSeqLayer(PairformerLayer): + """Distributed PairformerNoSeqLayer (pairwise stack only, no sequence track).""" + + pass + + +class PairformerNoSeqModule(PairformerModule): + """Distributed PairformerNoSeqModule (pairwise stack only).""" + + pass diff --git a/src/boltz/distributed/model/layers/redistribute_transpose.py b/src/boltz/distributed/model/layers/redistribute_transpose.py new file mode 100755 index 000000000..1221e2e13 --- /dev/null +++ b/src/boltz/distributed/model/layers/redistribute_transpose.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import Optional + +import torch +from torch import Tensor +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Placement, Shard + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.layers.dtensor_metadata_tools import ( + raise_if_incorrect_dtensor_metadata_args, +) + + +def redistribute_transpose( + input: DTensor, + transpose_comm: Optional[TransposeComm], + output_placements: Optional[tuple[Placement, ...]], + dim0: Optional[int] = None, + dim1: Optional[int] = None, +) -> DTensor: + """Transpose a DTensor across device mesh (and locally). + + Use cases in Boltz: + (1) boltz.model.modules.trunk.py: DistogramModule.forward [impl'd] + - redistribute_transpose(z, self.distogram_comm, (Shard(0), Shard(1), Shard(2)), 1, 2), + + (2) boltz.model.modules.trunk.py: MSAModule.forward [impl'd] + - redistribute_transpose(emb, self.comm_transpose, (Shard(0), Shard(1), Shard(2)), 1, 2) + + (3) boltz.model.loss.distogram.py: distogram_loss + - redistribute_transpose(mask, comm, (Shard(0), Shard(1), Shard(2)), 1, 2) + + (4) boltz.model.modules.encoders.py: AtomAttentionEncoder.forward + - redistribute_transpose(c, self.transpose_comm_c, (Shard(0), Replicate(), Shard(1)), None, None) + + (5) boltz.model.model.py: Boltz1.forward + - redistribute_transpose(s_inputs, self.transpose_comm, (Shard(0), Replicate(), Shard(1)), None, None) + + Parameters + ---------- + input : DTensor + Input tensor to transpose + transpose_comm : Optional[TransposeComm] + Communication object for distributed operations + output_placements : Optional[tuple[Placement, ...]] + Output placements for the DTensor. + dim0 : Optional[int] + First dimension to transpose locally. + dim1 : Optional[int] + Second dimension to transpose locally. + + Returns + ------- + DTensor + DTensor transposed across device mesh (and locally). + """ + return _RedistributeTransposeImpl.apply(input, output_placements, transpose_comm, dim0, dim1) + + +class _RedistributeTransposeImpl(torch.autograd.Function): + """Custom autograd function to transpose a DTensor across device mesh (and locally).""" + + @staticmethod + def forward( + ctx: FunctionCtx, + input: DTensor, + output_placements: Optional[tuple[Placement, ...]], + transpose_comm: Optional[TransposeComm], + dim0: Optional[int] = None, + dim1: Optional[int] = None, + ) -> DTensor: + """Forward pass for _RedistributeTransposeImpl custom autograd function. + + Parameters + ---------- + ctx : FunctionCtx + Context object to save information for backward pass + input : DTensor + Input tensor to transpose and redistribute + output_placements : tuple[Placement, ...] + Output placements for the DTensor. + transpose_comm : Union[TransposeComm, None] + Communication object for distributed operations + dim0 : int + First tensor dimension to transpose locally. + dim1 : int + Second tensor dimension to transpose locally. + + Returns + ------- + DTensor + DTensor transposed across device mesh (and locally). + """ + # check input options + if (dim0 is None) != (dim1 is None): + raise ValueError( + " When using redistribute_transpose, either both dim0 and dim1 must be None if no local transposition, or both must be not None for local transposition" + ) + + ctx.is_local_transpose = dim1 is not None + ctx.is_device_mesh_transpose = transpose_comm is not None + + # Short circuit if no local or device mesh transpose is performed + if not ctx.is_local_transpose and not ctx.is_device_mesh_transpose: + return input + + if (transpose_comm is None) != (output_placements is None): + raise ValueError( + "transpose_comm and output_placements must be either both None or both not None for device mesh transpose" + ) + + if ctx.is_local_transpose and ctx.is_device_mesh_transpose and (output_placements != input.placements): + raise ValueError( + "Simultaneous redistribute and local transpose is only supported when the two transposing axes are the sharding axes involved in the said redistribute. For other usage cases, consider decompose the operation in a redistribute-only followed by a local transpose-only operations by calling redistribute_transpose() twice with different arguments" + ) + + axis_mesh_shard_dim0 = None + axis_mesh_shard_dim1 = None + + if ctx.is_device_mesh_transpose: + axes_mesh_transpose = [] + else: + axes_mesh_transpose = None + + for i_dim_device_mesh, placement in enumerate(input.placements): + # Check if partial placements + if isinstance(placement, Partial): + raise ValueError( + f"Partial placements are not supported for redistribute_transpose but {input.placements} is given" + ) + + # Check if sharding is even + if ( + isinstance(placement, Shard) + and input.shape[placement.dim] % input.device_mesh.shape[i_dim_device_mesh] != 0 + ): + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {input.device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + + if ctx.is_local_transpose and isinstance(placement, Shard): + if placement.dim == dim0: + axis_mesh_shard_dim0 = i_dim_device_mesh + if placement.dim == dim1: + axis_mesh_shard_dim1 = i_dim_device_mesh + + if axes_mesh_transpose is not None: + if placement != output_placements[i_dim_device_mesh]: + axes_mesh_transpose.append(i_dim_device_mesh) + + # Check if locally transposed dimensions are sharded if both local and device mesh transposes are performed + if (ctx.is_local_transpose and ctx.is_device_mesh_transpose) and ( + axis_mesh_shard_dim0 is None or axis_mesh_shard_dim1 is None + ): + raise ValueError( + f"Both dim0 and dim1 must be sharded when doing both local and device mesh transposes " + f"but dim0={dim0} and dim1={dim1} are given with placements={input.placements}" + ) + + # Check if locally transposed dimensions are sharded if only local transpose is performed + if (ctx.is_local_transpose and not ctx.is_device_mesh_transpose) and ( + axis_mesh_shard_dim0 is not None or axis_mesh_shard_dim1 is not None + ): + raise NotImplementedError( + "Local transpose on sharded dimensions is not supported when only local transpose is performed" + ) + + if ctx.is_device_mesh_transpose: + device_mesh_coords = input.device_mesh.get_coordinate() + if len(axes_mesh_transpose) == 0: + # this implies dim{0, 1} are sharded and output_placements == input_placements + # but the underlying device mesh transpose will be performed by the transpose_comm + # along the two Sharding placement axes + assert axis_mesh_shard_dim0 is not None and axis_mesh_shard_dim1 is not None + axes_mesh_transpose = [axis_mesh_shard_dim0, axis_mesh_shard_dim1] + else: + # assert output placements is strictly a permutation of input placements + if not ( + input.placements[axes_mesh_transpose[0]] == output_placements[axes_mesh_transpose[1]] + and input.placements[axes_mesh_transpose[1]] == output_placements[axes_mesh_transpose[0]] + ): + raise ValueError( + "Input and output placements are not strictly a permutation of each other along mesh transpose axes:" + f"input.placements={input.placements} vs. output_placements={output_placements}" + ) + + # assert the correspondence of transpose_comm's underlying group to the device mesh axes + if ( + device_mesh_coords[axes_mesh_transpose[0]], + device_mesh_coords[axes_mesh_transpose[1]], + ) != transpose_comm.rank_coords: + raise ValueError( + f"Inconsistent device mesh coordinate {device_mesh_coords} along mesh transpose axes {axes_mesh_transpose} " + f"compared to transpose_comm rank_coords {transpose_comm.rank_coords}" + ) + + # transpose input tensor + output_local: Tensor = input.to_local() + + if ctx.is_device_mesh_transpose: + output_local_ = transpose_comm.enqueue_to_dispatch(output_local.contiguous()) + transpose_comm.wait_until_finished() + output_local = output_local_ + if ctx.is_local_transpose: + output_local = output_local.transpose(dim0, dim1) + + output_shape = torch.Size( + _swap_tuple_elements(input.shape, dim0, dim1) if ctx.is_local_transpose else input.shape + ) + output_stride = _swap_tuple_elements(input.stride(), dim0, dim1) if ctx.is_local_transpose else input.stride() + + if input.requires_grad: + ctx.input_shape = input.shape + ctx.output_shape = output_shape + ctx.input_stride = input.stride() + ctx.output_stride = output_stride + ctx.input_placements = input.placements + ctx.output_placements = input.placements if output_placements is None else output_placements + ctx.device_mesh = input.device_mesh + ctx.dim0 = dim0 + ctx.dim1 = dim1 + ctx.transpose_comm = transpose_comm + + # Create a new DTensor called output + output: DTensor = DTensor.from_local( + local_tensor=output_local, + shape=output_shape, + stride=output_stride, + device_mesh=input.device_mesh, + placements=input.placements if output_placements is None else output_placements, + ) + + return output + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, None, None, None, None]: + """Backward pass for _RedistributeTranspose custom autograd.Function + + Parameters + ---------- + ctx : FunctionCtx + Context object with saved information from forward pass + grad_output : DTensor + Gradient tensor from downstream layers + + Returns + ------- + tuple[DTensor, None, None, None] + Tuple containing the gradient for input tensor and None for other parameters + """ + # Short circuit if no local or device mesh transpose is performed + if not ctx.is_local_transpose and not ctx.is_device_mesh_transpose: + return grad_output, None, None, None, None + + # metadata check on grad_output + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="grad_output", + expected_shape=ctx.output_shape, + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.output_placements, + check_for_partial_placements=False, + ) + + # transpose gradient tensor + grad_input_local = grad_output.to_local() + + if ctx.is_device_mesh_transpose: + grad_input_local_ = ctx.transpose_comm.enqueue_to_dispatch(grad_input_local.contiguous()) + ctx.transpose_comm.wait_until_finished() + grad_input_local = grad_input_local_ + if ctx.is_local_transpose: + grad_input_local = grad_input_local.transpose(ctx.dim0, ctx.dim1) + + # Create a new DTensor called output + grad_input: DTensor = DTensor.from_local( + grad_input_local, + shape=ctx.input_shape, + stride=ctx.input_stride, + device_mesh=ctx.device_mesh, + placements=ctx.input_placements, + ) + return grad_input, None, None, None, None + + +def _swap_tuple_elements(x: tuple[int, ...], i: int, j: int) -> tuple[int, ...]: + """Swap two elements of a tuple. + + Parameters + ---------- + x : tuple[int, ...] + Tuple to swap elements of + i : int + Index of first element to swap + j : int + """ + y = list(x) + y[i], y[j] = y[j], y[i] + return tuple(y) diff --git a/src/boltz/distributed/model/layers/redistribute_transpose_without_dtensor.py b/src/boltz/distributed/model/layers/redistribute_transpose_without_dtensor.py new file mode 100644 index 000000000..16e7a0cff --- /dev/null +++ b/src/boltz/distributed/model/layers/redistribute_transpose_without_dtensor.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Optional + +import torch + +from boltz.distributed.comm import TransposeComm + + +def transpose_then_redistribute( + input: torch.Tensor, dim0: int, dim1: int, transpose_comm: TransposeComm +) -> torch.Tensor: + """Transpose a tensor and redistribute it across processes. + + This function first performs a transpose operation on the input tensor + and then redistributes the result using the provided communication object. + + Parameters + ---------- + input : torch.Tensor + Input tensor to transpose + dim0 : int + First dimension to transpose + dim1 : int + Second dimension to transpose + transpose_comm: TransposeComm + Communication object for distributed operations + + Returns + ------- + torch.Tensor + Transposed and redistributed tensor + + """ + inputT = input.transpose(dim0, dim1).contiguous() + inputT_recv = transpose_comm.enqueue_to_dispatch(inputT) + transpose_comm.wait_until_finished() + return inputT_recv + + +class RedistributeTranspose(torch.autograd.Function): + """Custom autograd function to perform transpose with redistribution + + This operation performs a tensor transpose across a grid of processes + encapsulated by the input TransposeComm object. + """ + + @staticmethod + def forward(ctx, input: torch.Tensor, dim0: int, dim1: int, transpose_comm: TransposeComm) -> torch.Tensor: + """Forward pass for RedistributeTranspose. + + Args: + ctx: Context object to save information for backward pass + input: Input tensor to transpose and redistribute + dim0: First dimension to transpose + dim1: Second dimension to transpose + transpose_comm: Communication object for distributed operations + + Returns: + Transposed and redistributed tensor + """ + ctx.dim0 = dim0 + ctx.dim1 = dim1 + ctx.transpose_comm = transpose_comm + return transpose_then_redistribute(input, dim0, dim1, transpose_comm) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None]: + """Backward pass for RedistributeTranspose. + + Args: + ctx: Context object with saved information from forward pass + grad_output: Gradient tensor from downstream layers + + Returns: + Tuple containing the gradient for input tensor and None for other parameters + """ + dim0 = ctx.dim0 + dim1 = ctx.dim1 + transpose_comm = ctx.transpose_comm + grad_input = transpose_then_redistribute(grad_output, dim1, dim0, transpose_comm) + return grad_input, None, None, None + + +def redistribute_transpose( + input: torch.Tensor, dim0: int, dim1: int, transpose_comm: Optional[TransposeComm] = None +) -> torch.Tensor: + """Transpose a tensor with optional redistribution for distributed training. + + When the input TransposeComm is not None, the input tensor is redistributed + across the grid of processes encapsulated by the TransposeComm object. This implies + the return tensor's memory contiguity (layout right be default). By design, this + intention is to be consistent with the equivalent operation: + 1) inputT_global = input.transpose(dim0, dim1) + 2) scatter(result, [inputT_global_chunk_0.contiguous(), inputT_global_chunk_1.contiguous(), ...], + src=0) # scatter from root process (0) to all other processes + where the the scatter op requires contiguous source tensors. Similarly for the the + backward pass. + + Args: + input: Input tensor to transpose + dim0: First dimension to transpose + dim1: Second dimension to transpose + transpose_comm: Optional communication object for distributed operations. + If None, performs a regular transpose without redistribution. + + Returns: + Transposed tensor, potentially redistributed across processes + """ + if transpose_comm is None: + return input.transpose(dim0, dim1) + else: + return RedistributeTranspose.apply(input, dim0, dim1, transpose_comm) diff --git a/src/boltz/distributed/model/layers/repeat_interleave.py b/src/boltz/distributed/model/layers/repeat_interleave.py new file mode 100644 index 000000000..15553bdf2 --- /dev/null +++ b/src/boltz/distributed/model/layers/repeat_interleave.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +from torch import Tensor +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Shard + +from boltz.distributed.model.layers.dtensor_metadata_tools import ( + raise_if_incorrect_dtensor_metadata_args, +) +from boltz.distributed.utils import update_exhaustive_strides + + +class _ShardwiseRepeatInterleaveImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + repeats: int, + dim: int, + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor. + repeats : int + Number of repetitions for each element. + dim : int + Dimension to repeat_interleave along. + + Returns + ------- + DTensor + DTensor after repeat_interleave operation. + """ + # Type checking + if not isinstance(x, DTensor): + raise TypeError(f"Expected DTensor, got {type(x)}") + if not isinstance(repeats, int): + raise TypeError(f"Expected int for repeats, got {type(repeats)}") + if not isinstance(dim, int): + raise TypeError(f"Expected int for dim, got {type(dim)}") + + dim_normalized = dim if dim >= 0 else dim + x.ndim + + # Check placements and handle sharded dimensions + for i_dim_device_mesh, placement in enumerate(x.placements): + if isinstance(placement, Shard): + # Check that sharded dimensions are evenly divided + if x.shape[placement.dim] % x.device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {x.device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + elif isinstance(placement, Partial): + raise ValueError(f"Placements of type {Partial} are not supported") + + x_local = x.to_local() + + # Perform operation on local tensors + output_local: Tensor = torch.repeat_interleave(x_local, repeats=repeats, dim=dim) + + # Compute output shape and stride + shape_output = list(x.shape) + shape_output[dim_normalized] = x.shape[dim_normalized] * repeats + shape_output = tuple(shape_output) + + # Use update_exhaustive_strides to compute new strides + strides_output = update_exhaustive_strides(output_local.shape, output_local.stride(), shape_output) + + # Create output DTensor using input tensor's device mesh and placements + result: DTensor = DTensor.from_local( + output_local, + device_mesh=x.device_mesh, + placements=x.placements, + shape=shape_output, + stride=strides_output, + ) + + # Save information for backward pass + ctx.repeats = repeats + ctx.dim_normalized = dim_normalized + ctx.input_device_mesh = x.device_mesh + ctx.input_placements = x.placements + ctx.input_shape = x.shape + ctx.input_stride = x.stride() + ctx.output_shape = result.shape + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, None, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None, None] + Gradient with respect to input, None for repeats and dim parameters. + """ + # Check that grad_output has the expected shape, device_mesh and placements + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="_ShardwiseRepeatInterleaveImpl.backward grad_output", + expected_shape=ctx.output_shape, + expected_device_mesh=ctx.input_device_mesh, + expected_placements=ctx.input_placements, + ) + + # Perform backward pass on local tensors + grad_output_local = grad_output.to_local() + + # Reshape and sum to reverse the repeat_interleave operation + # Get the original size along the dimension that was repeated + original_size = grad_output_local.shape[ctx.dim_normalized] // ctx.repeats + + # Unflatten the repeated dimension and sum along the repeats dimension + grad_unflattened = torch.unflatten(grad_output_local, ctx.dim_normalized, (original_size, ctx.repeats)) + grad_input_local = grad_unflattened.sum(dim=ctx.dim_normalized + 1) + + # Create output DTensor using the saved metadata + grad_input = DTensor.from_local( + grad_input_local, + device_mesh=ctx.input_device_mesh, + placements=ctx.input_placements, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + + return grad_input, None, None + + +def shardwise_repeat_interleave(x: DTensor, repeats: int, dim: int) -> DTensor: + """Repeat elements of a DTensor along a specified dimension. + + This function repeats elements of a DTensor along the specified dimension. + Each element along the specified dimension is repeated `repeats` times. + + Parameters + ---------- + x : DTensor + Input DTensor to repeat_interleave. + repeats : int + Number of repetitions for each element. + dim : int + Dimension to repeat_interleave along. + + Returns + ------- + DTensor + DTensor after repeat_interleave operation. + + Raises + ------ + TypeError + If inputs are not of correct type. + ValueError + If validation errors occur. + """ + return _ShardwiseRepeatInterleaveImpl.apply(x, repeats, dim) diff --git a/src/boltz/distributed/model/layers/replicate_op.py b/src/boltz/distributed/model/layers/replicate_op.py new file mode 100644 index 000000000..68a068fec --- /dev/null +++ b/src/boltz/distributed/model/layers/replicate_op.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Replicate op: lhs op rhs.unsqueeze(dim) with lhs sharded and rhs replicated on that dim.""" + +from enum import Enum, auto + +import torch +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard + + +class ReplicateOp(Enum): + """Supported operations for replicate_op.""" + + ADD = auto() + SUB = auto() + PROD = auto() + DIV = auto() + + +class _ReplicateOpImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + lhs: DTensor, + rhs: DTensor, + dim_to_unsqueeze_rhs: int, + op: ReplicateOp, + ) -> DTensor: + if not isinstance(lhs, DTensor): + raise TypeError(f"Input 'lhs' must be of type DTensor. Got type {type(lhs)}.") + if not isinstance(rhs, DTensor): + raise TypeError(f"Input 'rhs' must be of type DTensor. Got type {type(rhs)}.") + if not isinstance(dim_to_unsqueeze_rhs, int): + raise TypeError(f"Input 'dim_to_unsqueeze_rhs' must be of type int. Got type {type(dim_to_unsqueeze_rhs)}.") + if not isinstance(op, ReplicateOp): + raise TypeError(f"Input 'op' must be of type ReplicateOp. Got type {type(op)}.") + + if op not in (ReplicateOp.ADD, ReplicateOp.SUB, ReplicateOp.PROD, ReplicateOp.DIV): + raise ValueError(f"Unsupported operation: {op}. Only ADD, SUB, PROD, and DIV are supported.") + + dim_to_unsqueeze_rhs = ( + dim_to_unsqueeze_rhs if dim_to_unsqueeze_rhs >= 0 else dim_to_unsqueeze_rhs + rhs.ndim + 1 + ) + shape_lhs_expected = ( + rhs.shape[:dim_to_unsqueeze_rhs] + (lhs.shape[dim_to_unsqueeze_rhs],) + rhs.shape[dim_to_unsqueeze_rhs:] + ) + + if lhs.shape != shape_lhs_expected: + raise ValueError( + f"Shape mismatch: lhs is expected to have the same shape as rhs except for " + f"the unsqueezed dimension {shape_lhs_expected} " + f"but got {lhs.shape}" + ) + + if lhs.device_mesh != rhs.device_mesh: + raise ValueError( + f"Device mesh mismatch: lhs.device_mesh={lhs.device_mesh} != rhs.device_mesh={rhs.device_mesh}" + ) + + placements_pair_expected = (Shard(dim_to_unsqueeze_rhs), Replicate()) + dim_device_mesh_reduce = None + for i_dim_device_mesh, (p_lhs, p_rhs) in enumerate(zip(lhs.placements, rhs.placements)): + if p_lhs == placements_pair_expected[0] and p_rhs == placements_pair_expected[1]: + if dim_device_mesh_reduce is not None: + raise ValueError( + f"Duplicate placements pair {placements_pair_expected} found " + f"in lhs.placements {lhs.placements} and rhs.placements {rhs.placements}" + ) + dim_device_mesh_reduce = i_dim_device_mesh + continue + if isinstance(p_lhs, Partial) or isinstance(p_rhs, Partial): + raise ValueError("Partial placements are not supported") + if isinstance(p_lhs, Shard) and isinstance(p_rhs, Shard): + if p_lhs.dim < dim_to_unsqueeze_rhs: + p_rhs_expected = Shard(p_lhs.dim) + else: + p_rhs_expected = Shard(p_lhs.dim - 1) + if p_rhs != p_rhs_expected: + raise ValueError( + f"rhs.placements[{i_dim_device_mesh}] is expected to be {p_rhs_expected} but got {p_rhs}" + ) + elif p_lhs != p_rhs: + raise ValueError( + f"lhs.placements[{i_dim_device_mesh}] is expected to be the same as rhs.placements[{i_dim_device_mesh}] " + f"but got {p_lhs} and {p_rhs}" + ) + if dim_device_mesh_reduce is None: + raise ValueError( + f"lhs.placements is expected to contain Shard({dim_to_unsqueeze_rhs}) and " + f"rhs.placements is expected to contain Replicate() along the same device mesh axis " + f"but got {lhs.placements} and {rhs.placements}" + ) + + if lhs.requires_grad or rhs.requires_grad: + ctx.placements_output = lhs.placements + ctx.placements_dLHS = lhs.placements + ctx.placements_dRHS = rhs.placements + ctx.device_mesh = lhs.device_mesh + ctx.dim_to_squeeze_dRHS = dim_to_unsqueeze_rhs + ctx.group_all_reduce_dRHS = ctx.device_mesh.get_group(dim_device_mesh_reduce) + ctx.op = op + ctx.shape_dLHS = lhs.shape + ctx.stride_dLHS = lhs.stride() + ctx.shape_dRHS = rhs.shape + ctx.stride_dRHS = rhs.stride() + + lhs_local = lhs.to_local() + rhs_local = rhs.to_local() + rhs_unsqueezed_local = rhs_local.unsqueeze(dim_to_unsqueeze_rhs) + + if op == ReplicateOp.ADD: + output_local = lhs_local + rhs_unsqueezed_local + elif op == ReplicateOp.SUB: + output_local = lhs_local - rhs_unsqueezed_local + elif op == ReplicateOp.PROD: + output_local = lhs_local * rhs_unsqueezed_local + ctx.save_for_backward( + lhs_local.detach().clone() if rhs.requires_grad else None, + rhs_local.detach().clone() if lhs.requires_grad else None, + ) + elif op == ReplicateOp.DIV: + output_local = lhs_local / rhs_unsqueezed_local + ctx.save_for_backward( + lhs_local.detach().clone() if rhs.requires_grad else None, + rhs_local.detach().clone(), + ) + else: + raise ValueError(f"Unsupported operation: {op}") + + output = DTensor.from_local( + output_local, placements=lhs.placements, device_mesh=lhs.device_mesh, shape=lhs.shape, stride=lhs.stride() + ) + return output + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: DTensor) -> tuple: + grad_lhs = None + grad_rhs = None + + grad_output_local = grad_output.to_local() + + if ctx.op == ReplicateOp.ADD: + if ctx.needs_input_grad[0]: + grad_lhs = DTensor.from_local( + grad_output_local.clone(), + placements=ctx.placements_dLHS, + device_mesh=ctx.device_mesh, + shape=ctx.shape_dLHS, + stride=ctx.stride_dLHS, + ) + if ctx.needs_input_grad[1]: + grad_rhs_reduced_local = grad_output_local.sum(dim=ctx.dim_to_squeeze_dRHS) + elif ctx.op == ReplicateOp.SUB: + if ctx.needs_input_grad[0]: + grad_lhs = DTensor.from_local( + grad_output_local.clone(), + placements=ctx.placements_dLHS, + device_mesh=ctx.device_mesh, + shape=ctx.shape_dLHS, + stride=ctx.stride_dLHS, + ) + if ctx.needs_input_grad[1]: + grad_rhs_reduced_local = -grad_output_local.sum(dim=ctx.dim_to_squeeze_dRHS) + elif ctx.op == ReplicateOp.PROD: + lhs_local, rhs_local = ctx.saved_tensors + if ctx.needs_input_grad[0]: + rhs_unsqueezed_local = rhs_local.unsqueeze(ctx.dim_to_squeeze_dRHS) + grad_lhs = DTensor.from_local( + grad_output_local * rhs_unsqueezed_local, + placements=ctx.placements_dLHS, + device_mesh=ctx.device_mesh, + shape=ctx.shape_dLHS, + stride=ctx.stride_dLHS, + ) + if ctx.needs_input_grad[1]: + grad_rhs_reduced_local = (grad_output_local * lhs_local).sum(dim=ctx.dim_to_squeeze_dRHS) + elif ctx.op == ReplicateOp.DIV: + lhs_local, rhs_local = ctx.saved_tensors + rhs_unsqueezed_local = rhs_local.unsqueeze(ctx.dim_to_squeeze_dRHS) + grad_lhs_local = grad_output_local / rhs_unsqueezed_local + if ctx.needs_input_grad[0]: + grad_lhs = DTensor.from_local( + grad_lhs_local, + placements=ctx.placements_dLHS, + device_mesh=ctx.device_mesh, + shape=ctx.shape_dLHS, + stride=ctx.stride_dLHS, + ) + if ctx.needs_input_grad[1]: + grad_rhs_reduced_local = -(grad_lhs_local * lhs_local / rhs_unsqueezed_local).sum( + dim=ctx.dim_to_squeeze_dRHS + ) + + if ctx.needs_input_grad[1]: + torch.distributed.all_reduce( + grad_rhs_reduced_local, + op=torch.distributed.ReduceOp.SUM, + group=ctx.group_all_reduce_dRHS, + async_op=False, + ) + grad_rhs = DTensor.from_local( + grad_rhs_reduced_local, + placements=ctx.placements_dRHS, + device_mesh=ctx.device_mesh, + shape=ctx.shape_dRHS, + stride=ctx.stride_dRHS, + ) + + return grad_lhs, grad_rhs, None, None + + +def replicate_op(lhs: DTensor, rhs: DTensor, dim_to_unsqueeze_rhs: int, op: ReplicateOp) -> DTensor: + """lhs op rhs.unsqueeze(dim_to_unsqueeze_rhs); lhs sharded on that dim, rhs replicated on it.""" + return _ReplicateOpImpl.apply(lhs, rhs, dim_to_unsqueeze_rhs, op) diff --git a/src/boltz/distributed/model/layers/scatter.py b/src/boltz/distributed/model/layers/scatter.py new file mode 100644 index 000000000..79b6030e7 --- /dev/null +++ b/src/boltz/distributed/model/layers/scatter.py @@ -0,0 +1,706 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Shard +from torch.distributed.tensor import Partial + +from boltz.distributed.model.layers.outer_gather import get_overlap_from_peers +from boltz.distributed.utils import update_exhaustive_strides + + +class DistributedScatterReduce(torch.autograd.Function): + @staticmethod + def forward( + ctx, + output_size_per_rank: int, + axis: int, + idx_dtensor: DTensor, + src_dtensor: DTensor, + reduce: str, + idx_mask: DTensor | None, + are_ids_contiguous: bool, + ) -> DTensor: + """Distributed scatter reduce. + + Scatters values from src into an output tensor at positions specified by idx, + applying a reduction operation for duplicate indices. The output is initialized + to zeros (not from a dst tensor). + + Args: + output_size_per_rank: Size of the output's scatter axis per rank. The global + output shape will be ``(*batch, output_size_per_rank * num_shards, *features)``. + axis: Axis corresponding to the scatter dimension in output and src. + idx_dtensor: DTensor with shape ``(*batch, N_src)`` that provides scatter + indices into the output's scatter dimension. Values must be in range + [0, output_size_per_rank * num_shards). + src_dtensor: DTensor with shape ``(*batch, N_src, *features)`` - source data + to scatter into output. Must have same placements as idx_dtensor. + reduce: Reduction operation - "sum" or "mean". + idx_mask: Optional DTensor with shape ``(*batch, N_src)`` and same device_mesh + and placements as ``idx_dtensor``. Elements with True indicate valid indices, + elements with False indicate invalid indices that should be ignored. + are_ids_contiguous: This is a heuristic for selecting the underlying + send/recv strategy for performance purpose. Currently only True is supported, + which means that the idx_dtensor maps to a contiguous block of dst along (axis) + dimensions for all the shards and for all the leading (batch) dimensions. + When True, the underlying strategy will use the min/max of idx_dtensor to + compute the needed interval, assuming the resulting buffer to be communicated + across the ranks is fully (or approximately so) utilized. + + Returns: + DTensor with shape ``(*batch, output_size_per_rank * num_shards, *features)`` + containing the scattered and reduced values. + """ + if not are_ids_contiguous: + raise NotImplementedError("DistributedScatterReduce currently only supports are_ids_contiguous=True") + + if reduce not in ("sum", "mean"): + raise ValueError( + f"reduce must be 'sum' or 'mean', got '{reduce}'. " + "Other reductions (amax, amin, prod) are not supported due to complex backward requirements." + ) + + if not isinstance(idx_dtensor, DTensor) or not isinstance(src_dtensor, DTensor): + raise TypeError("idx_dtensor and src_dtensor must be DTensors") + + # Validate shapes + # idx: (*batch, N_src) - indices into output dimension + # src: (*batch, N_src, *features) - values to scatter + # output: (*batch, output_size_per_rank * num_shards, *features) + batch_dims = idx_dtensor.shape[:axis] + feature_shape = src_dtensor.shape[axis + 1 :] + + # idx should be (*batch, N_src) + if idx_dtensor.ndim != axis + 1: + raise ValueError(f"idx_dtensor should have ndim={axis + 1} for axis={axis}, got ndim={idx_dtensor.ndim}") + + N_src = idx_dtensor.shape[axis] + + # src should be (*batch, N_src, *features) + if src_dtensor.shape[:axis] != batch_dims: + raise ValueError(f"Batch dimensions must match: idx {batch_dims} vs src {src_dtensor.shape[:axis]}") + if src_dtensor.shape[axis] != N_src: + raise ValueError(f"src axis {axis} size {src_dtensor.shape[axis]} must match idx axis size {N_src}") + + # Validate device_mesh and placements + mesh = idx_dtensor.device_mesh + if src_dtensor.device_mesh != mesh: + raise ValueError("idx and src must be on the same DeviceMesh") + if src_dtensor.placements != idx_dtensor.placements: + raise ValueError("idx and src must have identical placements") + + # Validate idx_mask if provided + if idx_mask is not None: + if not isinstance(idx_mask, DTensor): + raise TypeError("idx_mask must be a DTensor") + if idx_mask.shape != idx_dtensor.shape: + raise ValueError(f"idx_mask shape {idx_mask.shape} must match idx_dtensor shape {idx_dtensor.shape}") + if idx_mask.device_mesh != idx_dtensor.device_mesh: + raise ValueError("idx_mask must have the same device_mesh as idx_dtensor") + if idx_mask.placements != idx_dtensor.placements: + raise ValueError("idx_mask must have the same placements as idx_dtensor") + if idx_mask.dtype != torch.bool: + raise TypeError( + f"idx_mask must have dtype torch.bool, got {idx_mask.dtype}. Use mask.bool() to convert." + ) + + placements = idx_dtensor.placements + + ndim_src = src_dtensor.ndim + if axis < 0: + axis += ndim_src + if axis < 0 or axis >= ndim_src: + raise ValueError(f"axis {axis} out of range for src.ndim={ndim_src}") + + # Identify shard axis on mesh for the scatter dimension + mesh_dim_axis = None + for i_mesh_dim, p in enumerate(placements): + if isinstance(p, Partial): + raise ValueError("Partial placements are not supported") + if isinstance(p, Shard): + if p.dim == axis: + mesh_dim_axis = i_mesh_dim + # Enforce even sharding for idx and src + if idx_dtensor.shape[p.dim] % mesh.size(i_mesh_dim) != 0: + raise ValueError( + f"idx_dtensor axis {p.dim} size {idx_dtensor.shape[p.dim]} " + f"not evenly divisible by mesh dim {i_mesh_dim}" + ) + if src_dtensor.shape[p.dim] % mesh.size(i_mesh_dim) != 0: + raise ValueError( + f"src_dtensor axis {p.dim} size {src_dtensor.shape[p.dim]} " + f"not evenly divisible by mesh dim {i_mesh_dim}" + ) + if idx_mask is not None and idx_mask.shape[p.dim] % mesh.size(i_mesh_dim) != 0: + raise ValueError( + f"idx_mask axis {p.dim} size {idx_mask.shape[p.dim]} " + f"not evenly divisible by mesh dim {i_mesh_dim}" + ) + + if mesh_dim_axis is None: + raise ValueError(f"Tensors must be sharded along axis {axis}") + + # Compute output shapes + size_group = mesh.size(mesh_dim_axis) + output_global_size = output_size_per_rank * size_group + output_global_shape = batch_dims + (output_global_size,) + feature_shape + + idx_local = idx_dtensor.to_local() + src_local = src_dtensor.to_local() + idx_mask_local = idx_mask.to_local() if idx_mask is not None else None + device = idx_local.device + cpu_device = torch.device("cpu") + + # Local output shape: use local batch dims (from src_local) + output_size_per_rank + local features + batch_dims_local = src_local.shape[:axis] + feature_shape_local = src_local.shape[axis + 1 :] + output_local_shape = batch_dims_local + (output_size_per_rank,) + feature_shape_local + + # Compute write interval (where our local idx values want to write to) + if idx_local.numel() > 0: + if idx_mask_local is not None: + if idx_mask_local.any(): + valid_idx = idx_local[idx_mask_local] + write_interval = torch.stack(valid_idx.aminmax()).to(dtype=torch.long).unsqueeze(0) # (1,2) + write_interval[:, -1] += 1 + else: + write_interval = torch.tensor([[0, 0]], device=device, dtype=torch.long) + else: + write_interval = torch.stack(idx_local.aminmax()).to(dtype=torch.long).unsqueeze(0) + write_interval[:, -1] += 1 + else: + write_interval = torch.tensor([[0, 0]], device=device, dtype=torch.long) + + write_start = write_interval[0, 0] + write_end = write_interval[0, 1] + + # Owned chunk interval (this rank's portion of output) + coord_axis = mesh.get_local_rank(mesh_dim_axis) + own_start = torch.tensor(coord_axis * output_size_per_rank, device=cpu_device, dtype=torch.long) + own_end = own_start + output_size_per_rank + own_interval = torch.stack([own_start, own_end]).unsqueeze(0) # (1,2) + + # All-gather write intervals along sharded mesh dim (metadata only) + group_axis = mesh.get_group(mesh_dim_axis) + write_range = [torch.zeros_like(write_interval) for _ in range(size_group)] + dist.all_gather(write_range, write_interval, group=group_axis) + write_range = torch.stack(write_range) # (size_group, 1, 2) + write_range_cpu = write_range.cpu() + + ranks_global_on_mesh = mesh.mesh + my_coords = mesh.get_coordinate() + index_list_submesh = [] + for dim in range(ranks_global_on_mesh.ndim): + if dim == mesh_dim_axis: + index_list_submesh.append(slice(None)) + else: + index_list_submesh.append(torch.tensor(my_coords[dim], device=cpu_device)) + ranks_global_on_submesh = ranks_global_on_mesh[tuple(index_list_submesh)] # (size_group,) + + my_rank = mesh.mesh[tuple(my_coords)].item() + + # Compute owned intervals for all peers + start_peers_own = torch.arange(size_group, device=cpu_device, dtype=torch.long) * output_size_per_rank + end_peers_own = start_peers_own + output_size_per_rank + interval_peers_own = torch.stack([start_peers_own, end_peers_own], dim=-1).unsqueeze(1) # (size_group, 1, 2) + + # Shape info (batch dims are same for output/src/idx due to same placements) + N_src_local = idx_local.shape[axis] # local N_src dimension size + shape_leading_flat = torch.Size(batch_dims_local).numel() if batch_dims_local else 1 + shape_trailing_flat = torch.Size(feature_shape_local).numel() if feature_shape_local else 1 + + # Flatten for easier processing + # src: (*batch, N_src, *features) -> (B, N_src, F) + # idx: (*batch, N_src) -> (B, N_src) + src_flat = src_local.reshape(shape_leading_flat, N_src_local, shape_trailing_flat) + idx_flat = idx_local.reshape(shape_leading_flat, N_src_local) + if idx_mask_local is not None: + mask_flat = idx_mask_local.reshape(shape_leading_flat, N_src_local) + else: + mask_flat = None + + # SEND PLAN: determine which peers we need to send data to + # For each peer whose owned_interval overlaps with our write_interval + if write_start >= write_end: + send_plan: List[Dict] = [] + else: + send_plan = get_overlap_from_peers( + ranks_global_on_submesh, interval_peers_own, write_interval.to(cpu_device) + ) + + # RECV PLAN: determine which peers will send data to us + # For each peer whose write_interval overlaps with our owned_interval + recv_plan = get_overlap_from_peers( + ranks_global_on_submesh, + write_range_cpu.view(size_group, 1, 2), + own_interval.view(1, 1, 2), + ) + + # Batch indices tensor for indexing into flattened (B, N_src) tensors + # batch_indices[b, n] = b, used to track which batch element each element belongs to + batch_indices = torch.arange(shape_leading_flat, device=device).unsqueeze(1).expand(-1, N_src_local) + + # Phase 1: Prepare data to send and count elements + # send_counts[i] = number of (idx, src) pairs this rank will send to peer i + # Used for all-gather so each rank knows how much data to expect from others + send_counts = torch.zeros(size_group, device=device, dtype=torch.long) + # send_data_per_peer[peer_global_rank] = (batch_idx, dst_local_idx, src_values) + # - batch_idx: (count,) which batch element each pair belongs to + # - dst_local_idx: (count,) destination index relative to peer's owned chunk start + # - src_values: (count, F) flattened source values to scatter + send_data_per_peer: Dict[int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} + + for item in send_plan: + peer = item["peer"] + interval = item["interval"] + peer_start = interval[0, 0].item() + peer_end = interval[0, 1].item() + + # Create mask for idx values that fall in peer's interval + in_peer_interval = (idx_flat >= peer_start) & (idx_flat < peer_end) + if mask_flat is not None: + in_peer_interval = in_peer_interval & mask_flat + + count = in_peer_interval.sum().item() + # Get peer's rank within the group (peer is global rank, need group-local rank) + peer_coord = dist.get_group_rank(group_axis, peer) + send_counts[peer_coord] = count + + if count > 0: + peer_batch_idx = batch_indices[in_peer_interval] # (count,) + + # idx values adjusted to be relative to peer's owned chunk start + # Note: peer_start/peer_end are the overlap interval used for filtering, + # but we need peer's actual owned start for computing local indices + peer_owned_start = peer_coord * output_size_per_rank + peer_idx = idx_flat[in_peer_interval] - peer_owned_start # (count,) + + # src values + peer_src = src_flat[in_peer_interval] # (count, F) + + send_data_per_peer[peer] = (peer_batch_idx, peer_idx, peer_src) + + # All-gather counts from all ranks + all_counts = [torch.zeros_like(send_counts) for _ in range(size_group)] + dist.all_gather(all_counts, send_counts, group=group_axis) + all_counts = torch.stack(all_counts) # (size_group, size_group) + + # all_counts[i, j] = count of elements rank i sends to rank j + my_coord_in_group = coord_axis + recv_counts = all_counts[:, my_coord_in_group] # counts I receive from each rank + + # Phase 2: Exchange actual data + ops = [] + recv_bufs: Dict[int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} + + # Prepare receives (including self-data from send_data_per_peer) + for item in recv_plan: + peer = item["peer"] + if peer == my_rank: + # Self-data: directly use local data from send_data_per_peer instead of P2P + # This unifies self-data and received-data processing in the scatter loop below + if my_rank in send_data_per_peer: + recv_bufs[my_rank] = send_data_per_peer[my_rank] + continue + # Get peer's rank within the group (peer is global rank, need group-local rank) + peer_coord = dist.get_group_rank(group_axis, peer) + count = recv_counts[peer_coord].item() + if count > 0: + recv_batch_buf = torch.empty(count, device=device, dtype=torch.long) + recv_idx_buf = torch.empty(count, device=device, dtype=idx_local.dtype) + recv_src_buf = torch.empty(count, shape_trailing_flat, device=device, dtype=src_local.dtype) + ops.append(dist.P2POp(dist.irecv, recv_batch_buf, peer)) + ops.append(dist.P2POp(dist.irecv, recv_idx_buf, peer)) + ops.append(dist.P2POp(dist.irecv, recv_src_buf, peer)) + recv_bufs[peer] = (recv_batch_buf, recv_idx_buf, recv_src_buf) + + # Prepare sends (skip self since we handled it above) + for item in send_plan: + peer = item["peer"] + if peer == my_rank: + continue + if peer in send_data_per_peer: + peer_batch_idx, peer_idx, peer_src = send_data_per_peer[peer] + ops.append(dist.P2POp(dist.isend, peer_batch_idx.contiguous(), peer)) + ops.append(dist.P2POp(dist.isend, peer_idx.contiguous(), peer)) + ops.append(dist.P2POp(dist.isend, peer_src.contiguous(), peer)) + + # Execute P2P + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # Local scatter_reduce operation + # Initialize output with zeros (no dst tensor to clone from) + out_local = torch.zeros(output_local_shape, dtype=src_local.dtype, device=device) + + # Linearize output for scatter_reduce: (B, L, F) -> (B*L, F) + # This allows us to use a single scatter_reduce_ call instead of looping over batches. + # The key insight: scatter_reduce_ only indexes along one dimension, but we need to + # index by (batch, position) pairs. By linearizing to (B*L, F), we can compute + # linear_idx = batch_idx * L + position_idx to get a single index into the B*L dimension. + out_linear = out_local.reshape(shape_leading_flat * output_size_per_rank, shape_trailing_flat) + + # For "mean" reduction, we need to track counts for each output position + # count[j] = number of src values scattered to position j + # This is needed for both the forward (to compute mean) and backward (grad = grad_out / count) + if reduce == "mean": + # Initialize counts to 0 (no include_self since there's no dst) + count_linear = torch.zeros( + shape_leading_flat * output_size_per_rank, 1, dtype=out_local.dtype, device=device + ) + + # Process all data uniformly (self-data + received data from peers) + # We always use scatter_add_ since we're accumulating values and dividing for mean at the end + # Both are stored in recv_bufs as (batch_idx, local_dst_idx, src_values) + for peer, (recv_batch_idx, recv_idx, recv_src) in recv_bufs.items(): + recv_idx_long = recv_idx.to(torch.long) + + # Compute linearized index: maps (batch, position) -> single index in B*L + linear_idx = recv_batch_idx * output_size_per_rank + recv_idx_long + linear_idx_expanded = linear_idx.unsqueeze(-1).expand(-1, shape_trailing_flat) + + out_linear.scatter_add_(0, linear_idx_expanded, recv_src) + + # Accumulate counts for "mean" + if reduce == "mean": + ones = torch.ones(linear_idx.shape[0], 1, dtype=out_local.dtype, device=device) + count_linear.scatter_add_(0, linear_idx.unsqueeze(-1), ones) + + # For "mean", divide the accumulated sum by counts + # Positions with count=0 are already 0 from initialization, so 0/1=0 naturally + count_local = None + if reduce == "mean": + out_linear = out_linear / count_linear.clamp(min=1).expand(-1, shape_trailing_flat) + count_local = count_linear.reshape(shape_leading_flat, output_size_per_rank, 1) + count_local = count_local.expand(-1, -1, shape_trailing_flat).reshape(output_local_shape) + + # Reshape back to original structure + out_local = out_linear.reshape(output_local_shape) + + strides_out = update_exhaustive_strides(out_local.shape, out_local.stride(), output_global_shape) + + out_dtensor = DTensor.from_local(out_local, mesh, placements, shape=output_global_shape, stride=strides_out) + + # Save for backward + # For "mean", we also save count_local to correctly scale gradients + tensors_to_save = [idx_local, src_local] + if idx_mask_local is not None: + tensors_to_save.append(idx_mask_local) + if count_local is not None: + tensors_to_save.append(count_local) + ctx.save_for_backward(*tensors_to_save) + ctx.has_mask = idx_mask_local is not None + ctx.has_count = count_local is not None + ctx.reduce = reduce + ctx.comm_meta = { + "axis": axis, + "placements": placements, + "output_global_shape": output_global_shape, + "output_size_per_rank": output_size_per_rank, + "src_global_shape": src_dtensor.shape, + "output_global_stride": strides_out, + "src_global_stride": src_dtensor.stride(), + "own_interval": own_interval, + "device_mesh": mesh, + "mesh_dim_axis": mesh_dim_axis, + } + + return out_dtensor + + @staticmethod + def backward(ctx, grad_output: DTensor) -> Tuple[None, None, None, DTensor, None, None, None]: + """Backward pass for DistributedScatterReduce. + + The backward of scatter_reduce is essentially a gather operation: + - For "sum": grad_src = gather(grad_output, idx) + - For "mean": grad_src = gather(grad_output / count, idx) + + Note on reusing distributed_gather: + While the backward is semantically a gather operation, we don't directly call + distributed_gather here because: + 1. Shape mismatch: distributed_gather expects idx shape (*batch, K, W), but our + idx has shape (*batch, N_src) - a 1D index per batch element. + 2. For "mean", we need to gather from (grad_output / count), not raw grad_output. + The count tensor is local to each rank's output shard, so we need to apply the + division before sending chunks to other ranks during the gather. + + Instead, we implement the gather communication pattern directly, which also allows + us to handle the count scaling for "mean" reduction efficiently by using the + count-scaled gradient (grad_gather_source = grad_local / count_local) as the source. + + Returns gradients for: (output_size_per_rank, axis, idx, src, reduce, idx_mask, are_ids_contiguous) + Only src requires a gradient; others return None. + """ + # Unpack saved tensors based on what was saved + saved = ctx.saved_tensors + idx_local = saved[0] + src_local = saved[1] + tensor_idx = 2 + + idx_mask_local = None + if ctx.has_mask: + idx_mask_local = saved[tensor_idx] + tensor_idx += 1 + + count_local = None + if ctx.has_count: + count_local = saved[tensor_idx] + + reduce = ctx.reduce + meta = ctx.comm_meta + axis = meta["axis"] + placements = meta["placements"] + output_size_per_rank = meta["output_size_per_rank"] + src_global_shape = meta["src_global_shape"] + src_global_stride = meta["src_global_stride"] + own_interval = meta["own_interval"] + mesh = meta["device_mesh"] + mesh_dim_axis = meta["mesh_dim_axis"] + + grad_local = grad_output.to_local().contiguous() + device = grad_local.device + cpu_device = torch.device("cpu") + + # For "mean", we need to scale by 1/count because: + # output[j] = sum of src[i] where idx[i]==j / count[j] + # grad_src[i] = grad_output[idx[i]] * d(output[idx[i]])/d(src[i]) = grad_output[idx[i]] / count[idx[i]] + # We use the count-scaled gradients so that gathered values are automatically divided + # by count[idx] from the owning rank + if reduce == "mean" and count_local is not None: + # Clamp count to avoid division by zero for positions that received no values + count_clamped = count_local.clamp(min=1) + grad_gather_source = grad_local / count_clamped + else: + grad_gather_source = grad_local + + # grad_src requires gathering from grad_gather_source at idx positions + # This is essentially the same communication pattern as DistributedGather + + # Compute need interval (what we need from grad_output for our grad_src) + if idx_local.numel() > 0: + if idx_mask_local is not None: + if idx_mask_local.any(): + valid_idx = idx_local[idx_mask_local] + need_interval = torch.stack(valid_idx.aminmax()).to(dtype=torch.long).unsqueeze(0) + need_interval[:, -1] += 1 + else: + need_interval = torch.tensor([[0, 0]], device=device, dtype=torch.long) + else: + need_interval = torch.stack(idx_local.aminmax()).to(dtype=torch.long).unsqueeze(0) + need_interval[:, -1] += 1 + else: + need_interval = torch.tensor([[0, 0]], device=device, dtype=torch.long) + + need_start = need_interval[0, 0] + need_end = need_interval[0, 1] + need_start_cpu = need_start.to(cpu_device) + + own_start = own_interval[0, 0] + + # All-gather need intervals + group_axis = mesh.get_group(mesh_dim_axis) + size_group = mesh.size(mesh_dim_axis) + need_range = [torch.zeros_like(need_interval) for _ in range(size_group)] + dist.all_gather(need_range, need_interval, group=group_axis) + need_range = torch.stack(need_range) + need_range_cpu = need_range.cpu() + + ranks_global_on_mesh = mesh.mesh + my_coords = mesh.get_coordinate() + index_list_submesh = [] + for dim in range(ranks_global_on_mesh.ndim): + if dim == mesh_dim_axis: + index_list_submesh.append(slice(None)) + else: + index_list_submesh.append(torch.tensor(my_coords[dim], device=cpu_device)) + ranks_global_on_submesh = ranks_global_on_mesh[tuple(index_list_submesh)] + + my_rank = mesh.mesh[tuple(my_coords)].item() + + # RECEIVE PLAN for backward (gather pattern) + if need_start >= need_end: + needed_chunks: List[Dict] = [] + else: + start_peers_own = torch.arange(size_group, device=cpu_device, dtype=torch.long) * output_size_per_rank + end_peers_own = start_peers_own + output_size_per_rank + interval_peers_own = torch.stack([start_peers_own, end_peers_own], dim=-1).unsqueeze(1) + + needed_chunks = get_overlap_from_peers( + ranks_global_on_submesh, interval_peers_own, need_interval.to(cpu_device) + ) + + ops = [] + recv_bufs = {} + + for item in needed_chunks: + peer = item["peer"] + interval = item["interval"] + start_global = interval[0, 0] + length = interval[0, 1] - interval[0, 0] + + shape = list(grad_gather_source.shape) + shape[axis] = length.item() + buf = torch.empty(shape, dtype=grad_gather_source.dtype, device=device) + + if peer == my_rank: + start_local = start_global - own_start + buf.copy_(grad_gather_source.narrow(axis, start_local.item(), length.item())) + recv_bufs[peer] = buf + else: + ops.append(dist.P2POp(dist.irecv, buf, peer)) + recv_bufs[peer] = buf + + # SEND PLAN for backward + send_chunks = get_overlap_from_peers( + ranks_global_on_submesh, + need_range_cpu.view(size_group, 1, 2), + own_interval.view(1, 1, 2), + ) + + for item in send_chunks: + peer = item["peer"] + if peer == my_rank: + continue + interval = item["interval"] + start_global = interval[0, 0] + length = interval[0, 1] - interval[0, 0] + start_local = start_global - own_start + chunk = grad_gather_source.narrow(axis, start_local.item(), length.item()).contiguous() + ops.append(dist.P2POp(dist.isend, chunk, peer)) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # Assemble gradient buffer + if need_start >= need_end: + buffer_shape = list(grad_gather_source.shape) + buffer_shape[axis] = 0 + grad_buffer = torch.empty(buffer_shape, dtype=grad_gather_source.dtype, device=device) + else: + buffer_shape = list(grad_gather_source.shape) + buffer_shape[axis] = (need_end - need_start).item() + grad_buffer = torch.zeros(buffer_shape, dtype=grad_gather_source.dtype, device=device) + + for item in needed_chunks: + peer = item["peer"] + interval = item["interval"] + buf = recv_bufs.get(peer) + if buf is None: + raise RuntimeError(f"Missing recv buffer for peer {peer}") + + start_global = interval[0, 0] + length = interval[0, 1] - interval[0, 0] + start_local_buf = start_global - need_start_cpu + target = grad_buffer.narrow(axis, start_local_buf.item(), length.item()) + target.copy_(buf) + + # Gather grad_src from grad_buffer + idx_local_buffer = idx_local - need_start.item() + + shape_trailing = grad_buffer.shape[axis + 1 :] + shape_trailing_flat = torch.Size(shape_trailing).numel() if shape_trailing else 1 + + shape_leading = grad_buffer.shape[:axis] + shape_leading_flat = torch.Size(shape_leading).numel() + L = grad_buffer.shape[axis] + N_src_local = idx_local.shape[axis] + + if L == 0: + grad_src_local = torch.zeros_like(src_local) + else: + if idx_mask_local is not None: + idx_local_buffer = torch.where(idx_mask_local, idx_local_buffer, torch.zeros_like(idx_local_buffer)) + + grad_flat = grad_buffer.reshape(shape_leading_flat, L, shape_trailing_flat) + idx_flat = idx_local_buffer.reshape(shape_leading_flat, N_src_local) + + batch_idx = torch.arange(shape_leading_flat, device=device).reshape(shape_leading_flat, 1) + grad_src_flat = grad_flat[batch_idx, idx_flat, :] # (B, N_src, F) + + if idx_mask_local is not None: + mask_flat = idx_mask_local.reshape(shape_leading_flat, N_src_local, 1).to(grad_src_flat.dtype) + grad_src_flat = grad_src_flat * mask_flat + + if shape_trailing: + grad_src_local = grad_src_flat.reshape(*shape_leading, N_src_local, *shape_trailing) + else: + grad_src_local = grad_src_flat.reshape(*shape_leading, N_src_local) + + # Construct grad_src DTensor using saved strides from forward pass + grad_src_dtensor = DTensor.from_local( + grad_src_local.contiguous(), + grad_output.device_mesh, + placements, + shape=src_global_shape, + stride=src_global_stride, + ) + + # Return: (output_size_per_rank, axis, idx, src, reduce, idx_mask, are_ids_contiguous) + # Only src requires gradient + return None, None, None, grad_src_dtensor, None, None, None + + +def distributed_scatter_reduce( + output_size_per_rank: int, + axis: int, + idx: DTensor, + src: DTensor, + reduce: str, + idx_mask: DTensor | None = None, + are_ids_contiguous: bool = False, +) -> DTensor: + """Distributed scatter reduce. + + Scatters values from src into an output tensor at positions specified by idx, + applying a reduction operation for duplicate indices. The output is initialized + to zeros. + + This is a P2P-based alternative to: + ``torch.zeros(output_shape).scatter_reduce_(axis, idx.full_tensor(), src.full_tensor(), reduce)`` + + Args: + output_size_per_rank: Size of the output's scatter axis per rank. The global + output shape will be ``(*batch, output_size_per_rank * num_shards, *features)``. + axis: Axis corresponding to the scatter dimension in output and src. + idx: DTensor with shape ``(*batch, N_src)`` that provides scatter indices + into the output's scatter dimension. Values must be in range + [0, output_size_per_rank * num_shards). + src: DTensor with shape ``(*batch, N_src, *features)`` - source data to scatter. + Must have same placements as idx. + reduce: Reduction operation - "sum" or "mean". + idx_mask: Optional DTensor with shape ``(*batch, N_src)`` and same device_mesh + and placements as ``idx``. Elements with True indicate valid indices, + elements with False indicate invalid indices that should be ignored. + are_ids_contiguous: Heuristic for send/recv strategy. Currently only True is supported. + + Returns: + DTensor: Result with shape ``(*batch, output_size_per_rank * num_shards, *features)`` + and same placements as idx. + """ + return DistributedScatterReduce.apply(output_size_per_rank, axis, idx, src, reduce, idx_mask, are_ids_contiguous) diff --git a/src/boltz/distributed/model/layers/sharded_op.py b/src/boltz/distributed/model/layers/sharded_op.py new file mode 100644 index 000000000..0d356300f --- /dev/null +++ b/src/boltz/distributed/model/layers/sharded_op.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard + +from boltz.distributed.utils import update_exhaustive_strides + + +class _ShardedSumImpl(torch.autograd.Function): + """Distributed implementation of sharded summation aggregation on pair input using DTensors.""" + + @staticmethod + def forward(ctx, x: DTensor, dim: tuple[int, ...] | int, keepdim: bool = False) -> DTensor: + """Forward pass of distributed sharded summation aggregation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + x : DTensor + Input tensor. Can have any shape and sharding strategy. + dim : tuple[int, ...] | int + Dimensions to reduce over. All of which must be sharded. + keepdim : bool, default=False + Whether to keep the reduced dimensions in the output. + + Returns + ------- + DTensor + Output tensor with reduced dimensions, maintaining the same device mesh + and placement strategy as the input tensor. + + Raises + ------ + TypeError + If input is not a DTensor. + ValueError + If Partial placements are used (not supported), or if dims are invalid. + """ + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + + device_mesh = x.device_mesh + input_placements = x.placements + input_shape = x.shape + + # Validate placements and cache sharded dims + sharded_dims = set() + for i_dim_device_mesh, placement in enumerate(input_placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if input_shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input_shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + sharded_dims.add(placement.dim) + if isinstance(dim, int): + dim = (dim,) + dims = dim + + if not dims: + raise ValueError("Received empty dims argument") + + reduced_dims = [] + for dim in dims: + if dim < 0: + dim = len(input_shape) + dim + if dim >= len(input_shape): + raise ValueError(f"Input tensor has {len(input_shape)} dimensions but got dims {dims}") + if dim not in sharded_dims: + raise ValueError(f"Expected all dims are sharded but got {dims} for placements: {input_placements}") + reduced_dims.append(dim) + reduced_dims = tuple(sorted(reduced_dims)) + + # remap sharded dims due to shape change when keepdim=False + if not keepdim: + map_dims = {} + counter = 0 + for dim in range(x.ndim): + map_dims[dim] = dim - counter + counter += dim in reduced_dims + + x_local = x.to_local() + output_local = torch.sum(x_local, dim=reduced_dims, keepdim=keepdim) # new copy + + output_placements = [] + for placement, placement_group in zip(input_placements, device_mesh.get_all_groups()): + if isinstance(placement, Shard) and placement.dim in reduced_dims: + torch.distributed.all_reduce( + output_local, + op=torch.distributed.ReduceOp.SUM, + group=placement_group, + ) + output_placements.append(Replicate()) + + elif keepdim: # Shortcut for keepdim=True + output_placements.append(placement) + continue + + # Shift placement dimensions when keepdim=False + elif isinstance(placement, Shard): + output_placements.append(Shard(map_dims[placement.dim])) + elif isinstance(placement, Replicate): + output_placements.append(placement) + + if x.requires_grad: + ctx.device_mesh = device_mesh + ctx.input_placements = input_placements + ctx.input_local_shape = x_local.shape + ctx.input_shape = x.shape + ctx.input_stride = x.stride() + ctx.keepdim = keepdim + ctx.reduced_dims = reduced_dims + + # Compute output shape and stride + # Shape stays the same as if keepdim=True, just reduced dimensions become size 1 + shape_output = list(x.shape) + for dim in reduced_dims: + shape_output[dim] = 1 + shape_output = tuple(shape_output) + # Use update_exhaustive_strides to compute new strides + strides_output = update_exhaustive_strides(x.shape, x.stride(), shape_output) + if not keepdim: + shape_output_reduced = [] + strides_output_reduced = [] + for i_dim, dim in enumerate(shape_output): + if i_dim not in reduced_dims: + shape_output_reduced.append(dim) + strides_output_reduced.append(strides_output[i_dim]) + shape_output = tuple(shape_output_reduced) + strides_output = tuple(strides_output_reduced) + + out = DTensor.from_local( + output_local, + device_mesh=device_mesh, + placements=output_placements, + shape=shape_output, + stride=strides_output, + ) + return out + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor | None, None, None]: + """Backward pass of distributed sharded summation aggregation. + + The gradient of sum(x, dims) with respect to x is simple broadcasting: + - grad_output is broadcasted to the original tensor shape + - Since we only reduce over sharded dimensions, the gradient is just broadcasted back + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradients of the loss with respect to the output tensor. + + Returns + ------- + tuple[DTensor | None, None, None] + Gradients with respect to x, dims, and keepdim. + """ + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor but got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh}." + ) + + dx_local = grad_output.to_local() + if not ctx.keepdim: + for dim in ctx.reduced_dims: + dx_local = dx_local.unsqueeze(dim) + dx_local = dx_local.expand(ctx.input_local_shape).clone(memory_format=torch.contiguous_format) + + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh, + placements=ctx.input_placements, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + return dx, None, None + + +def sharded_sum(x: DTensor, dim: tuple[int, ...] | int, keepdim: bool = False) -> DTensor: + """Perform sharded summation aggregation. + + Behave similarly to torch.sum but expect all reduced dimensions to be sharded. + + Parameters + ---------- + x: DTensor + Input distributed tensor. + dim: tuple[int, ...] | int + Dimensions to reduce over. + keepdim: bool, default=False + Whether to keep the reduced dimensions in the output. + + Returns + ------- + DTensor + Output distributed tensor with dimensions and placements reduced. + """ + return _ShardedSumImpl.apply(x, dim, keepdim) # type: ignore diff --git a/src/boltz/distributed/model/layers/shardwise_op.py b/src/boltz/distributed/model/layers/shardwise_op.py new file mode 100644 index 000000000..cb704dd5d --- /dev/null +++ b/src/boltz/distributed/model/layers/shardwise_op.py @@ -0,0 +1,1368 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from enum import Enum +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Shard + +from boltz.distributed.model.layers.dtensor_metadata_tools import ( + raise_if_incorrect_dtensor_metadata_args, +) +from boltz.distributed.utils import LayoutRightMap, update_exhaustive_strides + + +class _ShardwiseSumImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + dim: int, + keepdim: Optional[bool] = None, + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor. + dim : int + The dimension to sum over. + keepdim : Optional[bool] + Whether to keep the dimension when summing. + + Returns + ------- + DTensor + DTensor after sum operation. + """ + # Type checking + if not isinstance(x, DTensor): + raise TypeError(f"Expected DTensor, got {type(x)}") + if not isinstance(dim, int): + raise TypeError(f"Expected int for dim, got {type(dim)}") + + device_mesh_input = x.device_mesh + placements_input = x.placements + + # Check placements and handle sharded dimensions + actual_dim = dim if dim >= 0 else len(x.shape) + dim + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + # Check that sharded dimensions are evenly divided + if x.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == actual_dim: + raise ValueError(f"Sum along sharded dimension {dim} is not supported") + + x_local = x.to_local() + + # Perform operation on local tensors + output_local = ( + torch.sum(x_local, dim=dim, keepdim=keepdim) if keepdim is not None else torch.sum(x_local, dim=dim) + ) + + shape_output = list(x.shape) + shape_output[actual_dim] = 1 + + # keep the layout mapping but with a new shape + strides_output = update_exhaustive_strides(x.shape, x.stride(), shape_output) + if not keepdim: + # remove the singleton dimension + shape_output = shape_output[:actual_dim] + shape_output[actual_dim + 1 :] + strides_output = strides_output[:actual_dim] + strides_output[actual_dim + 1 :] + + # Create output DTensor using input tensor's device mesh and placements + result: DTensor = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=tuple(shape_output), + stride=strides_output, + ) + + # Save information for backward pass + if x.requires_grad: + ctx.dim = dim + ctx.keepdim = keepdim + ctx.input_local_shape = x_local.shape + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.output_shape = result.shape + ctx.shape_input = x.shape + ctx.stride_input = x.stride() + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + *grad_outputs, + ) -> tuple[DTensor, None, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_outputs : tuple + Gradients of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None, None] + Gradient with respect to input, None for dim and keepdim parameters. + """ + grad_output = grad_outputs[0] + + # Check that grad_output has the expected shape, device_mesh and placements + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="_ShardwiseSumImpl.backward grad_output", + expected_shape=ctx.output_shape, + expected_device_mesh=ctx.device_mesh_input, + expected_placements=ctx.placements_input, + ) + + # Perform backward pass on local tensors + grad_output_local = grad_output.to_local() + + # For sum operation, gradient is broadcasted back to original shape + input_local_shape = list(ctx.input_local_shape) + if ctx.keepdim: + dx_local = grad_output_local.expand(input_local_shape) + else: + grad_expanded = grad_output_local.unsqueeze(ctx.dim) + dx_local = grad_expanded.expand(input_local_shape) + + # Create output DTensor using the saved metadata + grad_input = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.shape_input, + stride=ctx.stride_input, + ) + + return grad_input, None, None + + +class _ShardwiseOneHotImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + input: DTensor, + num_classes: int = -1, + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + input : DTensor + Input DTensor containing class indices. + num_classes : int + Number of classes for one-hot encoding. If -1, inferred from input. + + Returns + ------- + DTensor + DTensor after one-hot encoding. + """ + # Type checking + if not isinstance(input, DTensor): + raise TypeError(f"Expected DTensor, got {type(input)}") + if not isinstance(num_classes, int): + raise TypeError(f"Expected int for num_classes, got {type(num_classes)}") + + device_mesh_input = input.device_mesh + placements_input = input.placements + + # Check placements and handle sharded dimensions + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + # Check that sharded dimensions are evenly divided + if input.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + input_local = input.to_local() + + # Perform one-hot operation on local tensors + output_local = F.one_hot(input_local, num_classes=num_classes) + + # Compute output shape and stride (one-hot adds a dimension at the end) + shape_output = input.shape + (output_local.shape[-1],) + + # For one-hot, we append a new dimension, so we can use LayoutRightMap + layout_right = LayoutRightMap(shape_output) + strides_output = layout_right.strides + + # Create output DTensor + result: DTensor = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=shape_output, + stride=strides_output, + ) + ctx.mark_non_differentiable(result) + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + *grad_outputs, + ) -> tuple[None, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_outputs : tuple + Gradients of the loss with respect to the output. + + Returns + ------- + tuple[None, None] + None gradients for input and num_classes (one_hot is not differentiable w.r.t. indices). + """ + # one_hot is not differentiable with respect to the input indices + # Return None for both input and num_classes gradients + return None, None + + +class _ShardwiseDistogramImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + d: DTensor, + boundaries: torch.Tensor, + ) -> DTensor: + """Forward pass: bin distances into a distogram. + + Computes ``(d.unsqueeze(-1) > boundaries).sum(dim=-1).long()`` element-wise + on the local shard and wraps the result back into a DTensor with the same + placements and shape as the input. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + d : DTensor + Input DTensor of pairwise distances. + boundaries : torch.Tensor + 1-D tensor of bin boundaries (not a DTensor). + + Returns + ------- + DTensor + Long DTensor of bin indices, same shape and placements as ``d``. + """ + device_mesh = d.device_mesh + placements = d.placements + + d_local = d.to_local() + output_local = (d_local.unsqueeze(-1) > boundaries).sum(dim=-1).long() + + stride_output = update_exhaustive_strides(output_local.shape, output_local.stride(), d.shape) + + result: DTensor = DTensor.from_local( + output_local, + device_mesh=device_mesh, + placements=placements, + shape=d.shape, + stride=stride_output, + ) + ctx.mark_non_differentiable(result) + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + *grad_outputs, + ) -> tuple[None, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_outputs : tuple + Gradients of the loss with respect to the output. + + Returns + ------- + tuple[None, None] + None gradients for d and boundaries (distogram binning is not differentiable). + """ + return None, None + + +class _ShardwiseSoftmaxImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + dim: int = -1, + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor. + dim : int + The dimension to apply softmax over. + """ + device_mesh_input = x.device_mesh + placements_input = x.placements + + # Check placements and handle sharded dimensions + actual_dim = dim if dim >= 0 else len(x.shape) + dim + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + # Check that the softmax dim is not sharded - must be Replicate() + if placement.dim == actual_dim: + raise ValueError(f"Softmax along sharded dimension {dim} is not supported") + + # Check that sharded dimensions are evenly divided + if x.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + x_local = x.to_local().detach().requires_grad_(x.requires_grad) + + with torch.enable_grad(): + output_local = F.softmax(x_local, dim=dim) + + shape_output = x.shape + stride_output = x.stride() + + # Save information for backward pass + if x.requires_grad: + ctx.dim = dim + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.output_shape = shape_output + ctx.save_for_backward(x_local, output_local) # need x_local here for autograd.grad() in bwd + + output: DTensor = DTensor.from_local( + output_local.detach(), + device_mesh=device_mesh_input, + placements=placements_input, + shape=shape_output, + stride=stride_output, + ) + + return output + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None] + Gradient with respect to input, None for dim parameter. + """ + # Check that grad_output has the expected shape, device_mesh and placements + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="_ShardwiseSoftmaxImpl.backward grad_output", + expected_shape=ctx.output_shape, + expected_device_mesh=ctx.device_mesh_input, + expected_placements=ctx.placements_input, + ) + + # Perform backward pass on local tensors using saved subgraph + grad_output_local = grad_output.to_local() + x_local, softmax_output_local = ctx.saved_tensors + + (d_x_local,) = torch.autograd.grad( + outputs=[softmax_output_local], + inputs=[x_local], + grad_outputs=[grad_output_local], + retain_graph=False, + ) + + # Create output DTensor using the saved metadata + grad_input = DTensor.from_local( + d_x_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + + return grad_input, None + + +class _ShardwiseLogSoftmaxImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + dim: int = -1, + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor. + dim : int + The dimension to apply log_softmax over. + + Returns + ------- + DTensor + DTensor after log_softmax operation. + """ + # Type checking + if not isinstance(x, DTensor): + raise TypeError(f"Expected DTensor, got {type(x)}") + if not isinstance(dim, int): + raise TypeError(f"Expected int for dim, got {type(dim)}") + + device_mesh_input = x.device_mesh + placements_input = x.placements + + # Check placements and handle sharded dimensions + actual_dim = dim if dim >= 0 else len(x.shape) + dim + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + # Check that sharded dimensions are evenly divided + if x.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == actual_dim: + raise ValueError(f"Log_softmax along sharded dimension {dim} is not supported, must be Replicate") + + x_local = x.to_local() + + # Perform operation on local tensors + output_local = F.log_softmax(x_local, dim=dim) + + # Create output DTensor using input tensor's device mesh and placements + shape_output = x.shape + stride_output = x.stride() + output: DTensor = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=shape_output, + stride=stride_output, + ) + + # Save information for backward pass + if x.requires_grad: + ctx.dim = dim + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.output_shape = output.shape + ctx.save_for_backward(output_local) + + return output + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, None]: + """Backward pass. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_outputs : tuple + Gradients of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None] + Gradient with respect to input, None for dim parameter. + """ + # Check that grad_output has the expected shape, device_mesh and placements + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="_ShardwiseLogSoftmaxImpl.backward grad_output", + expected_shape=ctx.output_shape, + expected_device_mesh=ctx.device_mesh_input, + expected_placements=ctx.placements_input, + ) + + # Perform backward pass on local tensors + grad_output_local = grad_output.to_local() + (log_softmax_output_local,) = ctx.saved_tensors + + # For log_softmax(x), the gradient is: + # grad_input = grad_output - exp(log_softmax(x)) * sum(grad_output, dim=dim, keepdim=True) + # This is equivalent to: grad_input = grad_output - softmax(x) * sum(grad_output, dim=dim, keepdim=True) + grad_sum = grad_output_local.sum(dim=ctx.dim, keepdim=True) + dx_local = grad_output_local - torch.exp(log_softmax_output_local) * grad_sum + + # Create output DTensor using the saved metadata + grad_input = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + + return grad_input, None + + +class _ShardwiseArgmaxImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + dim: int, + keepdim: Optional[bool] = None, + ) -> DTensor: + """Forward pass for shardwise argmax.""" + if not isinstance(x, DTensor): + raise TypeError(f"Expected DTensor, got {type(x)}") + if not isinstance(dim, int): + raise TypeError(f"Expected int for dim, got {type(dim)}") + + device_mesh_input = x.device_mesh + placements_input = x.placements + + actual_dim = dim if dim >= 0 else len(x.shape) + dim + + # Validate placements and sharding + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if x.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == actual_dim: + raise ValueError(f"Argmax along sharded dimension {dim} is not supported") + + x_local = x.to_local() + if keepdim is None: + output_local = torch.argmax(x_local, dim=dim) + else: + output_local = torch.argmax(x_local, dim=dim, keepdim=keepdim) + + shape_output = list(x.shape) + shape_output[actual_dim] = 1 + strides_output = update_exhaustive_strides(x.shape, x.stride(), shape_output) + if not keepdim: + shape_output = shape_output[:actual_dim] + shape_output[actual_dim + 1 :] + strides_output = strides_output[:actual_dim] + strides_output[actual_dim + 1 :] + + result: DTensor = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=tuple(shape_output), + stride=strides_output, + ) + ctx.mark_non_differentiable(result) + return result + + @staticmethod + def backward(ctx: FunctionCtx, *grad_outputs) -> tuple[None, None, None]: + return None, None, None + + +class _ShardwiseOffsetImpl(torch.autograd.Function): + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + dim: int, + offset_per_rank: Any, + ) -> DTensor: + """Forward pass for shardwise offset. + + This function adds a rank-dependent offset to each shard of the input tensor. + The offset for each shard is computed as: rank_on_mesh_axis * offset_per_rank, + where rank_on_mesh_axis is the rank of the current process along the device mesh + axis that shards the specified dimension. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + Input DTensor with dimension `dim` sharded. + dim : int + The dimension that must be sharded. + offset_per_rank : Any + The offset value per rank. Can be a scalar or tensor that broadcasts + with x_local. + + Returns + ------- + DTensor + DTensor with offset applied: x + rank * offset_per_rank + + Raises + ------ + TypeError + If inputs are not of correct type. + ValueError + If the specified dimension is not sharded, partial placements exist, + or uneven sharding is detected. + """ + # Type checking + if not isinstance(x, DTensor): + raise TypeError(f"Expected DTensor, got {type(x)}") + if not isinstance(dim, int): + raise TypeError(f"Expected int for dim, got {type(dim)}") + + device_mesh_input = x.device_mesh + placements_input = x.placements + + # Normalize negative dim + actual_dim = dim if dim >= 0 else len(x.shape) + dim + + # Find which device_mesh axis shards the specified dim + mesh_axis_for_dim = None + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + # Check that sharded dimensions are evenly divided + if x.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == actual_dim: + mesh_axis_for_dim = i_dim_device_mesh + + # Check that the specified dimension IS sharded + if mesh_axis_for_dim is None: + raise ValueError(f"Dimension {dim} must be sharded for shardwise_offset, but it is not") + + x_local = x.to_local() + + # Get the rank along the mesh axis that shards dim + rank_on_mesh_axis = device_mesh_input.get_local_rank(mesh_axis_for_dim) + + # Compute offset: x_local + rank * offset_per_rank + output_local = x_local + rank_on_mesh_axis * offset_per_rank + + # Create output DTensor with same shape and stride as input + result: DTensor = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=x.shape, + stride=x.stride(), + ) + + return result + + @staticmethod + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor, None, None]: + """Backward pass for shardwise offset. + + Since the offset is a constant addition (rank * offset_per_rank is constant + for each shard), the gradient passes through unchanged. + + Parameters + ---------- + ctx : FunctionCtx + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output. + + Returns + ------- + tuple[DTensor, None, None] + Gradient with respect to input (same as grad_output), None for dim + and offset_per_rank parameters. + """ + # Gradient passes through unchanged since offset is constant + return grad_output, None, None + + +def shardwise_offset(x: DTensor, dim: int, offset_per_rank: Any) -> DTensor: + """Add a rank-dependent offset to a DTensor along a sharded dimension. + + This function adds an offset to each shard based on its rank along the + device mesh axis that shards the specified dimension. The offset for + each shard is: rank_on_mesh_axis * offset_per_rank. + + Parameters + ---------- + x : DTensor + Input DTensor with dimension `dim` sharded. + dim : int + The dimension that must be sharded. The function identifies which + device mesh axis shards this dimension and uses the rank along that + axis to compute the offset. + offset_per_rank : Any + The offset value per rank. Can be a scalar or tensor that broadcasts + with x_local. The actual offset added is rank * offset_per_rank. + + Returns + ------- + DTensor + DTensor with offset applied: x + rank * offset_per_rank + + Raises + ------ + TypeError + If inputs are not of correct type. + ValueError + If the specified dimension is not sharded, partial placements exist, + or uneven sharding is detected. + + Examples + -------- + >>> # Example: Adding position offsets to sharded sequence indices + >>> # If dim 1 is sharded across 2 ranks with local size 100: + >>> # - Rank 0: adds 0 * offset_per_rank to its shard + >>> # - Rank 1: adds 1 * offset_per_rank to its shard + >>> indices = ... # DTensor with dim 1 sharded, local shape (B, 100, D) + >>> offset_per_rank = 100 # Each rank's shard represents 100 elements + >>> global_indices = shardwise_offset(indices, dim=1, offset_per_rank=offset_per_rank) + """ + return _ShardwiseOffsetImpl.apply(x, dim, offset_per_rank) + + +def shardwise_sum(x: DTensor, dim: int, keepdim: Optional[bool] = None) -> DTensor: + """Sum elements of a DTensor along a specified dimension. + + This function sums elements of a DTensor along the specified dimension. + The sum operation is performed on local tensor chunks while maintaining + gradient computation capabilities. + + Parameters + ---------- + x : DTensor + Input DTensor to sum. + dim : int + Dimension to sum along. + keepdim : Optional[bool] + Whether to keep the dimension when summing. + + Returns + ------- + DTensor + DTensor after sum operation. + + Raises + ------ + TypeError + If inputs are not of correct type. + ValueError + If validation errors occur, such as summing along a sharded dimension. + """ + return _ShardwiseSumImpl.apply(x, dim, keepdim) + + +def shardwise_one_hot(input: DTensor, num_classes: int = -1) -> DTensor: + """One-hot encode a DTensor of class indices. + + This function performs one-hot encoding on a DTensor of class indices. + The operation is performed on local tensor chunks while maintaining + the distributed tensor structure. The new one-hot dimension is added + as the last dimension and is replicated across all devices. + + Parameters + ---------- + input : DTensor + Input DTensor containing class indices (integer values). + num_classes : int + Number of classes for one-hot encoding. If -1, inferred from input. + + Returns + ------- + DTensor + DTensor after one-hot encoding with shape input.shape + (num_classes,). + + Raises + ------ + TypeError + If inputs are not of correct type. + ValueError + If validation errors occur, such as partial placements or uneven sharding. + + Notes + ----- + The one-hot operation is not differentiable with respect to the input indices, + so gradients will be None for the input tensor. + """ + return _ShardwiseOneHotImpl.apply(input, num_classes) + + +def shardwise_distogram(d: DTensor, boundaries: torch.Tensor) -> DTensor: + """Bin pairwise distances into a distogram. + + Computes ``(d.unsqueeze(-1) > boundaries).sum(dim=-1).long()`` element-wise + on local shards while preserving the DTensor placements and shape. + + Parameters + ---------- + d : DTensor + Input DTensor of pairwise distances. + boundaries : torch.Tensor + 1-D tensor of bin boundaries (regular tensor, not a DTensor). + + Returns + ------- + DTensor + Long DTensor of bin indices, same shape and placements as ``d``. + + Notes + ----- + The distogram binning operation is not differentiable, so gradients + will be None for both inputs. + """ + return _ShardwiseDistogramImpl.apply(d, boundaries) + + +def shardwise_softmax(x: DTensor, dim: int = -1) -> DTensor: + """Apply softmax to a DTensor along a specified dimension. + + Parameters + ---------- + x : DTensor + Input DTensor to apply softmax to. + dim : int + Dimension to apply softmax over. Default is -1 (last dimension). + + Returns + ------- + DTensor + DTensor after softmax operation. + """ + return _ShardwiseSoftmaxImpl.apply(x, dim) + + +def shardwise_log_softmax(x: DTensor, dim: int = -1) -> DTensor: + """Apply log_softmax to a DTensor along a specified dimension. + + This function applies log_softmax to a DTensor along the specified dimension. + The log_softmax operation is performed on local tensor chunks while maintaining + gradient computation capabilities. + + Parameters + ---------- + x : DTensor + Input DTensor to apply log_softmax to. + dim : int + Dimension to apply log_softmax over. Default is -1 (last dimension). + + Returns + ------- + DTensor + DTensor after log_softmax operation. + + Raises + ------ + TypeError + If inputs are not of correct type. + ValueError + If validation errors occur, such as applying log_softmax along a sharded dimension. + + Notes + ----- + The log_softmax operation is differentiable and supports gradient computation. + The operation cannot be applied along sharded dimensions as it requires + communication across devices to compute the softmax normalization. + """ + return _ShardwiseLogSoftmaxImpl.apply(x, dim) + + +def shardwise_argmax(x: DTensor, dim: int, keepdim: Optional[bool] = None) -> DTensor: + """ + Compute argmax of a DTensor along a specified dimension (per shard). + + This is a shard-local argmax: it does not communicate across shards, so the + resulting indices correspond to the local shard slices. Use only when the + reduced dimension is not sharded. + + Parameters + ---------- + x : DTensor + Input DTensor. + dim : int + Dimension to take argmax over. + keepdim : Optional[bool] + Whether to retain the reduced dimension with size 1. + + Returns + ------- + DTensor + DTensor of dtype long containing indices of the local argmax. + + Raises + ------ + TypeError + If inputs are not of correct type. + ValueError + If validation errors occur, such as argmax along a sharded dimension + or partial placements. + """ + return _ShardwiseArgmaxImpl.apply(x, dim, keepdim) + + +class ShardwiseOuterOp(Enum): + """Supported operations for shardwise outer operations. + + These operations support broadcasting between tensors with singleton + dimensions for computing pairwise operations efficiently. + """ + + SUBTRACT = "subtract" + """Element-wise subtraction: x - y. Differentiable.""" + + ADD = "add" + """Element-wise addition: x + y. Differentiable.""" + + LOGICAL_AND = "logical_and" + """Element-wise logical AND: x & y. Non-differentiable.""" + + EQUAL = "equal" + """Element-wise equality: x == y. Non-differentiable.""" + + +class _ShardwiseOuterOpImpl(torch.autograd.Function): + """Unified shardwise outer operation at a specified axis. + + This autograd function handles all outer operations (subtract, addition, logical_and, equal) + with a shared code path, differing only in the actual math operation performed. + Supports gradient computation for differentiable operations (SUBTRACT, ADD). + + The operation computes pairwise combinations at the specified axis: + - x: (..., L, ...) at axis + - y: (..., R, ...) at axis + - Result: (..., L, R, ...) with one additional dimension + """ + + @staticmethod + def forward( + ctx: FunctionCtx, + x: DTensor, + y: DTensor, + op: ShardwiseOuterOp, + axis: int, + ) -> DTensor: + """Forward pass for outer operations at specified axis. + + Parameters + ---------- + ctx : FunctionCtx + Context object for saving information needed in backward pass. + x : DTensor + First input tensor with shape (..., L, ...) where L is at position `axis`. + y : DTensor + Second input tensor with shape (..., R, ...) where R is at position `axis`. + op : ShardwiseOuterOp + The operation to perform (SUBTRACT, ADD, LOGICAL_AND, or EQUAL). + axis : int + The axis at which to perform the outer operation. + + Returns + ------- + DTensor + Result of the operation with shape (..., L, R, ...). + The output has one more dimension than the inputs. + + Raises + ------ + TypeError + If inputs are not DTensors or axis is not an int. + ValueError + If device_mesh, placements don't match, or an unsupported operation is specified. + """ + # ========== Type checking ========== + if not isinstance(x, DTensor): + raise TypeError(f"shardwise_outer_op: Expected DTensor for x, got {type(x)}") + if not isinstance(y, DTensor): + raise TypeError(f"shardwise_outer_op: Expected DTensor for y, got {type(y)}") + if not isinstance(op, ShardwiseOuterOp): + raise TypeError(f"shardwise_outer_op: Expected ShardwiseOuterOp for op, got {type(op)}") + if not isinstance(axis, int): + raise TypeError(f"shardwise_outer_op: Expected int for axis, got {type(axis)}") + + # ========== Validate device_mesh and placements match ========== + if x.device_mesh != y.device_mesh: + raise ValueError("shardwise_outer_op: x and y must have the same device_mesh") + if x.placements != y.placements: + raise ValueError("shardwise_outer_op: x and y must have the same placements") + + device_mesh = x.device_mesh + placements_input = x.placements + + # ========== Validate placements ========== + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("shardwise_outer_op: Partial placements are not supported") + elif isinstance(placement, Shard): + # The outer operation axis must not be sharded - this is a shardwise op + # so the outer product must be computed locally without cross-shard communication + if placement.dim == axis: + raise ValueError( + f"shardwise_outer_op: Cannot shard dimension {axis} (the outer operation axis) " + f"with Shard({placement.dim}). The outer operation must be computed locally " + f"on each shard without cross-shard communication." + ) + + x_dim_size = x.shape[placement.dim] + y_dim_size = y.shape[placement.dim] + mesh_dim_size = device_mesh.shape[i_dim_device_mesh] + + if x_dim_size % mesh_dim_size != 0: + raise ValueError( + f"shardwise_outer_op: Uneven sharding of x tensor dimension {placement.dim} " + f"of size {x_dim_size} along device mesh dimension " + f"{i_dim_device_mesh} of size {mesh_dim_size} is not supported" + ) + if y_dim_size % mesh_dim_size != 0: + raise ValueError( + f"shardwise_outer_op: Uneven sharding of y tensor dimension {placement.dim} " + f"of size {y_dim_size} along device mesh dimension " + f"{i_dim_device_mesh} of size {mesh_dim_size} is not supported" + ) + + # ========== Get local tensors and unsqueeze ========== + x_local = x.to_local() + y_local = y.to_local() + + # Unsqueeze local tensors to create broadcast-compatible shapes + # x: (..., L, ...) → (..., L, 1, ...) (insert singleton after axis) + # y: (..., R, ...) → (..., 1, R, ...) (insert singleton at axis) + x_local = x_local.unsqueeze(axis + 1) + y_local = y_local.unsqueeze(axis) + + # Compute output shape: (..., L, R, ...) + shape_output = list(x.shape) + shape_output.insert(axis + 1, y.shape[axis]) + shape_output = tuple(shape_output) + + # Adjust placements for the new dimension + # Shard dimensions > axis need to be incremented + placements_output = list(placements_input) + for i_dim_device_mesh, p in enumerate(placements_input): + if isinstance(p, Shard) and p.dim > axis: + placements_output[i_dim_device_mesh] = Shard(p.dim + 1) + placements_output = tuple(placements_output) + + # ========== Compute output stride ========== + layout_right = LayoutRightMap(shape_output) + stride_output = layout_right.strides + + # ========== Perform operation based on op type ========== + if op == ShardwiseOuterOp.SUBTRACT: + output_local = x_local - y_local + elif op == ShardwiseOuterOp.ADD: + output_local = x_local + y_local + elif op == ShardwiseOuterOp.LOGICAL_AND: + output_local = x_local & y_local + elif op == ShardwiseOuterOp.EQUAL: + output_local = x_local == y_local + else: + raise ValueError(f"shardwise_outer_op: Unsupported operation: {op}") + + # ========== Create output DTensor ========== + result = DTensor.from_local( + output_local, + device_mesh=device_mesh, + placements=placements_output, + shape=shape_output, + stride=stride_output, + ) + + # ========== Handle gradient context ========== + is_differentiable = op in (ShardwiseOuterOp.SUBTRACT, ShardwiseOuterOp.ADD) + if is_differentiable and (x.requires_grad or y.requires_grad): + # Differentiable ops (SUBTRACT, ADD) - save context for backward + ctx.op = op + ctx.axis = axis + ctx.device_mesh = device_mesh + ctx.placements_input = placements_input + ctx.placements_output = placements_output + ctx.x_shape = x.shape + ctx.y_shape = y.shape + ctx.x_stride = x.stride() + ctx.y_stride = y.stride() + ctx.output_shape = shape_output + ctx.x_requires_grad = x.requires_grad + ctx.y_requires_grad = y.requires_grad + else: + # Non-differentiable operations or no grad required + ctx.mark_non_differentiable(result) + ctx.op = op + + return result + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: DTensor) -> tuple[Optional[DTensor], Optional[DTensor], None, None]: + """Backward pass for outer operations. + + Differentiable operations: + - SUBTRACT: For z = x - y: + grad_x = grad_z (summed over axis+1 where x was broadcast) + grad_y = -grad_z (summed over axis where y was broadcast) + - ADD: For z = x + y: + grad_x = grad_z (summed over axis+1 where x was broadcast) + grad_y = grad_z (summed over axis where y was broadcast) + + Parameters + ---------- + ctx : FunctionCtx + Context object with saved information from forward pass. + grad_output : DTensor + Gradient with respect to the output. + + Returns + ------- + tuple[Optional[DTensor], Optional[DTensor], None, None] + Gradients for x, y, None for op parameter, and None for axis parameter. + """ + grad_x = None + grad_y = None + + # ========== Op-specific gradient computation ========== + if ctx.op == ShardwiseOuterOp.SUBTRACT: + # For z = x - y: grad_x = grad_z, grad_y = -grad_z + # Both need to be summed over their broadcast dimensions + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="_ShardwiseOuterOpImpl.backward grad_output", + expected_shape=ctx.output_shape, + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.placements_output, + ) + + grad_output_local = grad_output.to_local() + + if ctx.x_requires_grad: + # Sum over axis+1 (where x was broadcast) and squeeze + grad_x_local = grad_output_local.sum(dim=ctx.axis + 1, keepdim=False) + grad_x = DTensor.from_local( + grad_x_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements_input, + shape=ctx.x_shape, + stride=ctx.x_stride, + ) + + if ctx.y_requires_grad: + # Sum over axis (where y was broadcast), negate, and squeeze + grad_y_local = -grad_output_local.sum(dim=ctx.axis, keepdim=False) + grad_y = DTensor.from_local( + grad_y_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements_input, + shape=ctx.y_shape, + stride=ctx.y_stride, + ) + + elif ctx.op == ShardwiseOuterOp.ADD: + # For z = x + y: grad_x = grad_z, grad_y = grad_z + # Both need to be summed over their broadcast dimensions + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=grad_output, + dtensor_name="_ShardwiseOuterOpImpl.backward grad_output", + expected_shape=ctx.output_shape, + expected_device_mesh=ctx.device_mesh, + expected_placements=ctx.placements_output, + ) + + grad_output_local = grad_output.to_local() + + if ctx.x_requires_grad: + # Sum over axis+1 (where x was broadcast) and squeeze + grad_x_local = grad_output_local.sum(dim=ctx.axis + 1, keepdim=False) + grad_x = DTensor.from_local( + grad_x_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements_input, + shape=ctx.x_shape, + stride=ctx.x_stride, + ) + + if ctx.y_requires_grad: + # Sum over axis (where y was broadcast) and squeeze (no negation for addition) + grad_y_local = grad_output_local.sum(dim=ctx.axis, keepdim=False) + grad_y = DTensor.from_local( + grad_y_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements_input, + shape=ctx.y_shape, + stride=ctx.y_stride, + ) + + elif ctx.op == ShardwiseOuterOp.LOGICAL_AND: + # Non-differentiable + pass + + elif ctx.op == ShardwiseOuterOp.EQUAL: + # Non-differentiable + pass + + else: + raise ValueError(f"shardwise_outer_op backward: Unsupported operation: {ctx.op}") + + return grad_x, grad_y, None, None + + +def shardwise_outer_op(lhs: DTensor, rhs: DTensor, axis: int, op: ShardwiseOuterOp) -> DTensor: + """Compute outer operation at specified axis. + + This function performs an outer operation (subtract, logical_and, equal) between + two tensors at a specified axis. The operation creates pairwise combinations + of elements along the specified axis. + + The function internally unsqueezes the inputs to create broadcast-compatible + shapes and then performs the operation: + - lhs: (..., L, ...) → unsqueeze to (..., L, 1, ...) + - rhs: (..., R, ...) → unsqueeze to (..., 1, R, ...) + - Result: (..., L, R, ...) + + Parameters + ---------- + lhs : DTensor + First input tensor with shape (..., L, ...) where L is at position `axis`. + rhs : DTensor + Second input tensor with shape (..., R, ...) where R is at position `axis`. + axis : int + The axis at which to perform the outer operation. Must be a valid + dimension index for both tensors. + op : ShardwiseOuterOp + The operation to perform: + - SUBTRACT: lhs - rhs (differentiable) + - LOGICAL_AND: lhs & rhs (non-differentiable, boolean inputs) + - EQUAL: lhs == rhs (non-differentiable) + + Returns + ------- + DTensor + Result tensor with shape (..., L, R, ...) where: + - L (from lhs) is at position `axis` + - R (from rhs) is at position `axis + 1` + - All other dimensions match the input dimensions. + + Raises + ------ + TypeError + If inputs are not DTensors or op is not a ShardwiseOuterOp. + ValueError + If device_mesh, placements don't match, axis is invalid, or shapes + are incompatible for the outer operation. + + Examples + -------- + >>> # Window batching example + >>> lhs = ... # DTensor with shape (B, K, W, D) + >>> rhs = ... # DTensor with shape (B, K, H, D) + >>> result = shardwise_outer_op(lhs, rhs, axis=2, op=ShardwiseOuterOp.SUBTRACT) + >>> # result has shape (B, K, W, H, D) + + >>> # Computing pairwise differences + >>> queries = ... # DTensor with shape (B, N, D) + >>> keys = ... # DTensor with shape (B, M, D) + >>> diffs = shardwise_outer_op(queries, keys, axis=1, op=ShardwiseOuterOp.SUBTRACT) + >>> # diffs has shape (B, N, M, D) + + Notes + ----- + This function provides a cleaner API compared to manually unsqueezing + tensors before calling the operation. It is the preferred way to compute + outer operations when the input tensors don't already have singleton + dimensions for broadcasting. + """ + # Validate axis bounds + if not isinstance(axis, int): + raise TypeError(f"shardwise_outer_op: Expected int for axis, got {type(axis)}") + + if not isinstance(lhs, DTensor): + raise TypeError(f"shardwise_outer_op: Expected DTensor for lhs, got {type(lhs)}") + if not isinstance(rhs, DTensor): + raise TypeError(f"shardwise_outer_op: Expected DTensor for rhs, got {type(rhs)}") + + # Normalize negative axis + ndim = lhs.ndim + if axis < 0: + axis = ndim + axis + + if axis < 0 or axis >= ndim: + raise ValueError(f"shardwise_outer_op: axis {axis} is out of bounds for tensor with {ndim} dimensions") + + if rhs.ndim != ndim: + raise ValueError( + f"shardwise_outer_op: lhs and rhs must have the same number of dimensions, " + f"got lhs.ndim={lhs.ndim} and rhs.ndim={rhs.ndim}" + ) + + # Pass axis to the autograd function - unsqueezing happens on local tensors + return _ShardwiseOuterOpImpl.apply(lhs, rhs, op, axis) diff --git a/src/boltz/distributed/model/layers/sigmoid_gate.py b/src/boltz/distributed/model/layers/sigmoid_gate.py new file mode 100644 index 000000000..cf3678b8b --- /dev/null +++ b/src/boltz/distributed/model/layers/sigmoid_gate.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +from torch.distributed.tensor import DTensor, Partial, Shard + + +class _SigmoidGateImpl(torch.autograd.Function): + """Distributed implementation of sigmoid gating using DTensors. + + This autograd function implements a distributed sigmoid gating operation that applies + a sigmoid-activated gate to an input tensor. The operation is performed element-wise + across distributed tensors while maintaining proper gradient computation. + + The sigmoid gate computes: + output = x * sigmoid(g) + + Key features: + - Distributed computation across device meshes with various sharding strategies + - Memory-efficient implementation that operates on local tensor chunks + - Supports gradient computation through custom backward pass + - Validates tensor compatibility (device mesh, placements, shapes) + + Notes + ----- + Input tensors must be DTensors with: + - Identical device mesh and placements + - Compatible shapes (x and g must have the same shape) + - No Partial placements (not currently supported) + """ + + @staticmethod + def forward(ctx, x: DTensor, g: DTensor) -> DTensor: + """Forward pass of distributed sigmoid gating. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + x : DTensor + Input tensor to be gated. Can have any shape and sharding strategy. + g : DTensor + Gate tensor with pre-sigmoid values. Must have identical shape, + device mesh, and placements as x. + + Returns + ------- + DTensor + Output tensor with shape identical to input tensors. + Contains the result of x * sigmoid(g). + + Raises + ------ + TypeError + If inputs are not DTensors. + ValueError + If tensors have incompatible device meshes, placements, or if + Partial placements are used (not supported). + """ + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + if not isinstance(g, DTensor): + raise TypeError(f"Input 'g' must be of type DTensor. Got type {type(g)}.") + + device_mesh_input = x.device_mesh + if g.device_mesh != device_mesh_input: + raise ValueError( + f"Input tensors 'x' and 'g' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {g.device_mesh}." + ) + + placements_input = x.placements + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + if isinstance(placement, Shard): + if x.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {x.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size " + f"{device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + if g.placements != placements_input: + raise ValueError( + f"Input tensors 'x' and 'g' must have identical placements. " + f"Got placements {placements_input} and {g.placements}." + ) + + input_shape = x.shape + if input_shape != g.shape: + raise ValueError( + f"Input tensors 'x' and 'g' must have identical shapes. Got shapes {input_shape} and {g.shape}." + ) + + g_local = g.to_local().sigmoid() + x_gated_local = x.to_local() * g_local + + ctx.save_for_backward(x_gated_local, g_local) + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.input_shape = input_shape + + out = DTensor.from_local( + x_gated_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=x.shape, + stride=x.stride(), + ) + return out + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, DTensor]: + """Backward pass of distributed sigmoid gating. + + Computes gradients with respect to both input tensor x and gate tensor g. + + The gradients are: + - dx = grad_output * sigmoid(g) + - dg = grad_output * x * sigmoid(g) * (1 - sigmoid(g)) + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradient of the loss with respect to the output tensor. + Must have identical device mesh and placements as the input tensors. + + Returns + ------- + tuple[DTensor, DTensor] + Gradients with respect to x and g respectively. + Both have the same shape and distribution as their corresponding inputs. + + Raises + ------ + TypeError + If grad_output is not a DTensor. + ValueError + If grad_output has incompatible device mesh or placements compared + to the input tensors from the forward pass. + """ + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensor. " + f"Got placements {grad_output.placements} and {ctx.placements_input}." + ) + + if grad_output.shape != ctx.input_shape: + raise ValueError( + f"Input 'grad_output' must have the same shape as the input tensor. " + f"Got shapes {grad_output.shape} and {ctx.input_shape}." + ) + + x_gated_local, g_local = ctx.saved_tensors + grad_output_local = grad_output.to_local() + + dx_local = grad_output_local * g_local + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + + dg_local = grad_output_local * x_gated_local + dg_local *= 1 - g_local + dg = DTensor.from_local( + dg_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=grad_output.shape, + stride=grad_output.stride(), + ) + + return dx, dg + + +def sigmoid_gate(x: DTensor, g: DTensor) -> DTensor: + """Apply sigmoid gating to a distributed tensor. + + This function performs element-wise sigmoid gating: x * sigmoid(g), where both + input and gate tensors are distributed across multiple devices. The operation + is performed efficiently using local tensor operations while maintaining + gradient computation capabilities. + + Parameters + ---------- + x : DTensor + Input tensor to be gated. Can have any shape and sharding strategy. + g : DTensor + Gate tensor with pre-sigmoid values. Must have identical shape, + device mesh, and placements as x. + + Returns + ------- + DTensor + Gated output tensor with shape identical to input tensors. + Contains the result of x * sigmoid(g). + + Examples + -------- + >>> # Assume we have distributed tensors x and g with shape (B, N, D) + >>> output = sigmoid_gate(x, g) + >>> # output = x * torch.sigmoid(g), computed in distributed fashion + + Notes + ----- + - Both input tensors must be DTensors with compatible device meshes and placements + - Partial placements are not currently supported + - The function is differentiable and supports gradient computation + - The operation is performed on local tensor chunks for efficiency + """ + return _SigmoidGateImpl.apply(x, g) diff --git a/src/boltz/distributed/model/layers/squeeze.py b/src/boltz/distributed/model/layers/squeeze.py new file mode 100644 index 000000000..97f7a228f --- /dev/null +++ b/src/boltz/distributed/model/layers/squeeze.py @@ -0,0 +1,414 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +from torch.distributed.tensor import DTensor, Partial, Shard + + +class _ShardwiseUnsqueezeImpl(torch.autograd.Function): + """Custom autograd function for performing unsqueeze operations on sharded distributed tensors. + + This function implements a differentiable unsqueeze operation that is compatible with + PyTorch's distributed tensor (DTensor) framework. It handles the complexities of + updating tensor placements and dimensions when adding a singleton dimension to + a sharded tensor across multiple devices. + + The implementation ensures that: + - Shard placements are correctly adjusted when dimensions shift due to unsqueeze + - Partial placements are not supported (will raise an error) + - Gradient computation is properly handled in the backward pass + """ + + @staticmethod + def forward(ctx, x: DTensor, dim: int) -> DTensor: + """Forward pass: performs unsqueeze operation on a distributed tensor. + + Args: + ctx: PyTorch autograd context for saving information needed in backward pass + x (DTensor): Input distributed tensor to unsqueeze + dim (int): Dimension at which to insert the singleton dimension. + Can be negative (counted from the end) + + Returns: + DTensor: Output tensor with an additional singleton dimension at the specified position + + Raises: + TypeError: If x is not a DTensor or dim is not an int + ValueError: If tensor has Partial placements (not supported) or if there's + uneven sharding that would prevent proper distribution + + Note: + The function follows PyTorch's unsqueeze semantics for shape and stride computation. + For sharded tensors, it updates the shard dimension indices when they are affected + by the dimension insertion. + """ + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + if not isinstance(dim, int): + raise TypeError(f"Input 'dim' must be of type int. Got type {type(dim)}.") + + device_mesh_input = x.device_mesh + placements_input = x.placements + shape_input = x.shape + stride_input = x.stride() + + dim_to_insert = dim if dim >= 0 else x.ndim + 1 + dim + placements_output = list(placements_input) + for i_dim_device_mesh, p in enumerate(placements_input): + if isinstance(p, Partial): + raise ValueError("Partial placements are not supported") + if isinstance(p, Shard): + if x.shape[p.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {p.dim} of size {x.shape[p.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size " + f"{device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + if p.dim >= dim_to_insert: + placements_output[i_dim_device_mesh] = Shard(p.dim + 1) + + # update shape and stride according to pytorch unsqueeze() function: + # https://github.com/pytorch/pytorch/blob/2c16eb9f3db0ba68520e5832d8bb6d3d875bdaeb/aten/src/ATen/native/TensorShape.cpp#L3879-L3890 + shape_output = list(shape_input) + shape_output.insert(dim_to_insert, 1) + stride_output = list(stride_input) + stride_to_insert = 1 if dim_to_insert >= x.ndim else shape_input[dim_to_insert] * stride_input[dim_to_insert] + stride_output.insert(dim_to_insert, stride_to_insert) + + # Perform unsqueeze on local tensor + input_local = x.to_local() + output_local = input_local.unsqueeze(dim_to_insert) + + # Save necessary information for backward pass + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.placements_output = tuple(placements_output) + ctx.dim_to_squeeze = dim_to_insert + ctx.shape_input = shape_input + ctx.stride_input = stride_input + + # Create output DTensor + out = DTensor.from_local( + output_local, + shape=tuple(shape_output), + stride=tuple(stride_output), + device_mesh=device_mesh_input, + placements=ctx.placements_output, + ) + return out + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None]: + """Backward pass: computes gradients by performing squeeze operation. + + This method implements the reverse operation of unsqueeze for gradient computation. + It takes the gradient with respect to the output and computes the gradient with + respect to the input by squeezing the dimension that was added in the forward pass. + + Args: + ctx: PyTorch autograd context containing saved information from forward pass + grad_output (DTensor): Gradient with respect to the output tensor + + Returns: + tuple[DTensor, None]: Tuple containing: + - Gradient with respect to input tensor (DTensor or None if not needed) + - None for the dim parameter (int parameters don't need gradients) + + Raises: + TypeError: If grad_output is not a DTensor + ValueError: If grad_output has incompatible device mesh or placements + compared to the original input tensor + """ + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_output: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensor. " + f"Got placements {grad_output.placements} and {ctx.placements_output}." + ) + + if ctx.needs_input_grad[0]: + # Perform squeeze on gradient + grad_output_local = grad_output.to_local() + grad_input_local = grad_output_local.squeeze(ctx.dim_to_squeeze) + + grad_input = DTensor.from_local( + grad_input_local, + shape=ctx.shape_input, + stride=ctx.stride_input, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + ) + else: + grad_input = None + + return grad_input, None + + +class _ShardwiseSqueezeImpl(torch.autograd.Function): + """Custom autograd function for performing squeeze operations on sharded distributed tensors. + + This function implements a differentiable squeeze operation that is compatible with + PyTorch's distributed tensor (DTensor) framework. It handles the complexities of + updating tensor placements and dimensions when removing singleton dimensions from + a sharded tensor across multiple devices. + + The implementation ensures that: + - Shard placements are correctly adjusted when dimensions shift due to squeeze + - Partial placements are not supported (will raise an error) + - Gradient computation is properly handled in the backward pass + - Only singleton dimensions (size 1) can be squeezed + """ + + @staticmethod + def forward(ctx, x: DTensor, dim: int) -> DTensor: + """Forward pass: performs squeeze operation on a distributed tensor. + + Args: + ctx: PyTorch autograd context for saving information needed in backward pass + x (DTensor): Input distributed tensor to squeeze + dim (int): Dimension to squeeze (must be a singleton dimension). + Can be negative (counted from the end) + + Returns: + DTensor: Output tensor with the singleton dimension removed + + Raises: + TypeError: If x is not a DTensor or dim is not an int + ValueError: If tensor has Partial placements (not supported), if there's + uneven sharding, or if the dimension to squeeze is not singleton + """ + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + if not isinstance(dim, int): + raise TypeError(f"Input 'dim' must be of type int. Got type {type(dim)}.") + + device_mesh_input = x.device_mesh + placements_input = x.placements + shape_input = x.shape + stride_input = x.stride() + + # Convert negative dim to positive + dim_to_squeeze = dim if dim >= 0 else x.ndim + dim + + # Check if dimension is valid and singleton + if dim_to_squeeze < 0 or dim_to_squeeze >= x.ndim: + raise ValueError(f"Dimension {dim} is out of range for tensor with {x.ndim} dimensions") + if shape_input[dim_to_squeeze] != 1: + raise ValueError(f"Cannot squeeze dimension {dim_to_squeeze} with size {shape_input[dim_to_squeeze]}") + + placements_output = list(placements_input) + for i_dim_device_mesh, p in enumerate(placements_input): + if isinstance(p, Partial): + raise ValueError("Partial placements are not supported") + if isinstance(p, Shard): + # Check if trying to squeeze a sharded dimension + if p.dim == dim_to_squeeze: + raise ValueError(f"Cannot squeeze dimension {dim_to_squeeze} as it is sharded") + if x.shape[p.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {p.dim} of size {x.shape[p.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size " + f"{device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + if p.dim > dim_to_squeeze: + placements_output[i_dim_device_mesh] = Shard(p.dim - 1) + + # Update shape and stride according to pytorch squeeze() function + shape_output = list(shape_input) + shape_output.pop(dim_to_squeeze) + stride_output = list(stride_input) + stride_output.pop(dim_to_squeeze) + + # Perform squeeze on local tensor + input_local = x.to_local() + output_local = input_local.squeeze(dim_to_squeeze) + + # Save necessary information for backward pass + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.placements_output = tuple(placements_output) + ctx.dim_to_unsqueeze = dim_to_squeeze + ctx.shape_input = shape_input + ctx.stride_input = stride_input + + # Create output DTensor + out = DTensor.from_local( + output_local, + shape=tuple(shape_output), + stride=tuple(stride_output), + device_mesh=device_mesh_input, + placements=ctx.placements_output, + ) + return out + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None]: + """Backward pass: computes gradients by performing unsqueeze operation. + + This method implements the reverse operation of squeeze for gradient computation. + It takes the gradient with respect to the output and computes the gradient with + respect to the input by unsqueezing the dimension that was removed in the forward pass. + + Args: + ctx: PyTorch autograd context containing saved information from forward pass + grad_output (DTensor): Gradient with respect to the output tensor + + Returns: + tuple[DTensor, None]: Tuple containing: + - Gradient with respect to input tensor (DTensor or None if not needed) + - None for the dim parameter (int parameters don't need gradients) + + Raises: + TypeError: If grad_output is not a DTensor + ValueError: If grad_output has incompatible device mesh or placements + compared to the original input tensor + """ + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensor. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_output: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensor. " + f"Got placements {grad_output.placements} and {ctx.placements_output}." + ) + + if ctx.needs_input_grad[0]: + # Perform unsqueeze on gradient + grad_output_local = grad_output.to_local() + grad_input_local = grad_output_local.unsqueeze(ctx.dim_to_unsqueeze) + + grad_input = DTensor.from_local( + grad_input_local, + shape=ctx.shape_input, + stride=ctx.stride_input, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + ) + else: + grad_input = None + + return grad_input, None + + +def shardwise_squeeze(x: DTensor, dim: int) -> DTensor: + """Performs a squeeze operation on a sharded distributed tensor. + + This function removes a singleton dimension from a distributed tensor at the specified + position while maintaining proper sharding across multiple devices. It's designed + to work seamlessly with PyTorch's autograd system for gradient computation. + + Args: + x (DTensor): Input distributed tensor to squeeze + dim (int): Dimension to squeeze (must be a singleton dimension with size 1). + Can be negative (counted from the end). Valid range is + [-x.ndim, x.ndim-1] where negative values are converted + to positive using: x.ndim + dim + + Returns: + DTensor: New distributed tensor with the singleton dimension removed. + The output tensor will have one fewer dimension than the input, + with all other dimensions unchanged. + + Raises: + TypeError: If x is not a DTensor or dim is not an int + ValueError: If the tensor has unsupported placement types, incompatible + sharding configurations, or if the dimension to squeeze is not singleton + + Examples: + >>> # Assuming we have a 3D distributed tensor of shape (4, 1, 6) + >>> x = ... # DTensor with shape (4, 1, 6) + >>> y = shardwise_squeeze(x, dim=1) # Remove singleton dimension at position 1 + >>> print(y.shape) # (4, 6) + + >>> # Using negative indexing + >>> z = shardwise_squeeze(x, dim=-2) # Remove singleton dimension at position 1 (from end) + >>> print(z.shape) # (4, 6) + + Note: + This function is a wrapper around the _ShardwiseSqueezeImpl autograd function, + making it more convenient to use while maintaining full gradient support. + + The function handles the complexity of updating shard placements when dimensions + are shifted due to the squeeze operation, ensuring correct distributed behavior. + + Unlike PyTorch's squeeze() which can squeeze all singleton dimensions when no dim + is specified, this function requires an explicit dimension to be specified. + """ + return _ShardwiseSqueezeImpl.apply(x, dim) + + +def shardwise_unsqueeze(x: DTensor, dim: int) -> DTensor: + """Performs an unsqueeze operation on a sharded distributed tensor. + + This function adds a singleton dimension to a distributed tensor at the specified + position while maintaining proper sharding across multiple devices. It's designed + to work seamlessly with PyTorch's autograd system for gradient computation. + + Args: + x (DTensor): Input distributed tensor to unsqueeze + dim (int): Dimension at which to insert the singleton dimension. + Can be negative (counted from the end). Valid range is + [-x.ndim-1, x.ndim] where negative values are converted + to positive using: x.ndim + 1 + dim + + Returns: + DTensor: New distributed tensor with an additional singleton dimension. + The output tensor will have one more dimension than the input, + with all other dimensions unchanged. + + Raises: + TypeError: If x is not a DTensor or dim is not an int + ValueError: If the tensor has unsupported placement types or incompatible + sharding configurations + + Examples: + >>> # Assuming we have a 2D distributed tensor of shape (4, 6) + >>> x = ... # DTensor with shape (4, 6) + >>> y = shardwise_unsqueeze(x, dim=1) # Insert at dimension 1 + >>> print(y.shape) # (4, 1, 6) + + >>> # Using negative indexing + >>> z = shardwise_unsqueeze(x, dim=-1) # Insert at last position + >>> print(z.shape) # (4, 6, 1) + + Note: + This function is a wrapper around the _ShardwiseUnsqueezeImpl autograd function, + making it more convenient to use while maintaining full gradient support. + + The function handles the complexity of updating shard placements when dimensions + are shifted due to the unsqueeze operation, ensuring correct distributed behavior. + """ + return _ShardwiseUnsqueezeImpl.apply(x, dim) diff --git a/src/boltz/distributed/model/layers/swiglu.py b/src/boltz/distributed/model/layers/swiglu.py new file mode 100755 index 000000000..d59819bff --- /dev/null +++ b/src/boltz/distributed/model/layers/swiglu.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from torch.distributed.tensor import DTensor +from torch.nn import Module + +from boltz.distributed.model.layers.cat_and_chunk import shardwise_chunk +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op +from boltz.distributed.model.layers.sigmoid_gate import sigmoid_gate + + +class SwiGLU(Module): + """SwiGLU implemented as a module, for DTensor inputs. + + Gradient computation is implemented in the backward method for the + implementations of shardwise_chunk, sigmoid_gate, and + mult_for_same_placement_and_shape. + + See src/boltz/model/modules/utils.py for reference implementation. + """ + + def forward(self, x: DTensor) -> DTensor: + """Forward pass of SwiGLU. + + DTensor metadata checking is performed in the implementation of + each of shardwise_chunk, sigmoid_gate, and + mult_for_same_placement_and_shape. + + Parameters + ---------- + x : DTensor + Input tensor. + + Returns + ------- + DTensor + Output tensor. + + Raises + ------ + ValueError + See _check_forward_input_for_impl_scope for more details. + """ + self.check_forward_input_for_impl_scope(x) + y, z = shardwise_chunk(x, chunks=2, dim=-1) + a: DTensor = sigmoid_gate(x=z, g=z) + a: DTensor = elementwise_op(a, y, ElementwiseOp.PROD) + + return a + + def check_forward_input_for_impl_scope(self, x: DTensor): + """Check that the input tensor is compatible with the SwiGLU operation. + + The SwiGLU operations is defined only if the size of the last dimension + is a multiple of 2. + + Parameters + ---------- + x : DTensor + Input tensor. + + Raises + ------ + ValueError + If the size of the last dimension is not a multiple of 2. + """ + if hasattr(x, "shape") and x.shape[-1] % 2 != 0: + raise ValueError( + ", ".join( + [ + "SwiGLU operation defined only if the size of the last dimension", + f"is a multiple of 2, whereas x.shape={x.shape}", + ] + ) + ) diff --git a/src/boltz/distributed/model/layers/transition.py b/src/boltz/distributed/model/layers/transition.py new file mode 100644 index 000000000..3322c0e42 --- /dev/null +++ b/src/boltz/distributed/model/layers/transition.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.sigmoid_gate import sigmoid_gate +from boltz.model.layers.transition import Transition as SerialTransition + + +class Transition(nn.Module): + """Distributed two-layer MLP using DTensor.""" + + def __init__( + self, + layer: SerialTransition, + device_mesh: DeviceMesh, + ) -> None: + """Initialize the distributed Transition module. + + Parameters + ---------- + layer : SerialTransition + The serial transition layer containing weights to be distributed. + device_mesh : DeviceMesh + Device mesh defining the distributed computation topology. + """ + super().__init__() + self.device_mesh = device_mesh + self.hidden = layer.hidden + + # Map serial layers to distributed versions + self.norm = LayerNormParamsReplicated(layer.norm, self.device_mesh) + self.fc1 = LinearParamsReplicated(layer.fc1, self.device_mesh) + self.fc2 = LinearParamsReplicated(layer.fc2, self.device_mesh) + self.fc3 = LinearParamsReplicated(layer.fc3, self.device_mesh) + + def forward(self, x: DTensor) -> DTensor: + """Perform a forward pass. + + Parameters + ---------- + x : DTensor + The input data of shape (..., D) + + Returns + ------- + DTensor + The output data of shape (..., D) + """ + x = self.norm(x) + + fc1_out = self.fc1(x) + fc2_out = self.fc2(x) + + # SwiGLU activation: silu(fc1_out) * fc2_out + # Since SiLU(x) = x * sigmoid(x), we use sigmoid_gate(fc1_out, fc1_out) * fc2_out + # NOTE: self.fc1 and self.fc2 have the same dimensionality mapping: dim -> hidden + # so fc1_out and fc2_out are of the same shape + x = sigmoid_gate(fc1_out, fc1_out) + x = elementwise_op(x, fc2_out, ElementwiseOp.PROD) + + x = self.fc3(x) + return x diff --git a/src/boltz/distributed/model/layers/triangular_attention.py b/src/boltz/distributed/model/layers/triangular_attention.py new file mode 100644 index 000000000..e4d5200f9 --- /dev/null +++ b/src/boltz/distributed/model/layers/triangular_attention.py @@ -0,0 +1,1658 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import math +from enum import Enum, auto +from typing import Optional + +import torch +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import Ring2DCommTriAttn +from boltz.distributed.model.layers.layernorm import _LayerNormParamsReplicatedImpl +from boltz.distributed.model.layers.linear import LinearParamsReplicated, _LinearParamsReplicatedImpl +from boltz.distributed.model.layers.sigmoid_gate import sigmoid_gate +from boltz.distributed.model.modules.utils import TriAttnBackend +from boltz.distributed.utils import tiled_softmax_attention_update, update_exhaustive_strides +from boltz.model.layers.triangular_attention.attention import ( + TriangleAttentionEndingNode as SerialTriangleAttentionEndingNode, +) +from boltz.model.layers.triangular_attention.attention import ( + TriangleAttentionStartingNode as SerialTriangleAttentionStartingNode, +) +from boltz.model.layers.triangular_attention.primitives import ( + Attention, +) +from boltz.model.layers.triangular_attention.primitives import LayerNorm as SerialLayerNormNoAutoCastBF16 +from boltz.model.layers.triangular_attention.primitives import Linear as SerialLinearNoAutoCastBF16 +from boltz.model.layers.triangular_attention.utils import ( + permute_final_dims, +) + +try: + import cuequivariance_torch.primitives.triangle as cueq_triangle + + cueq_is_installed = True +except ImportError: + cueq_is_installed = False + +try: + from trifast.torch import _triangle_attention as trifast_triangle_attention + from trifast.torch import triangle_attention_bwd as trifast_triangle_attention_bwd + + trifast_is_installed = True +except ImportError: + trifast_is_installed = False + + +def can_run_cueq_triattn_sm100f( + device: torch.device, + dtype: torch.dtype, + dim_token: int, + dim_hidden: int, + is_fwd: bool, +) -> bool: + """Check whether the cuEq SM100f triangle-attention kernel can run. + + Parameters + ---------- + device : torch.device + Target device (must be CUDA with SM100 or SM103 compute capability). + dtype : torch.dtype + Data type of q/k tensors. + dim_token : int + Token (sequence) dimension — ``q.shape[-2]`` or ``kT.shape[3]``. + dim_hidden : int + Per-head hidden dimension — ``q.shape[-1]``, i.e. ``c_hidden``. + is_fwd : bool + ``True`` for the forward pass, ``False`` for backward. + + Returns + ------- + bool + ``True`` when all SM100f constraints are satisfied. + """ + if device.type != "cuda": + return False + if dtype not in (torch.bfloat16, torch.float16): + return False + device_cc = torch.cuda.get_device_capability(device) + if device_cc not in ((10, 0), (10, 3)): + return False + if dim_hidden > 128 or dim_hidden % 8 != 0: + return False + if not is_fwd or dim_token % 8 == 0: + return True + return False + + +class LayerNormParamsReplicatedNoAutoCastBF16(nn.Module): + """ + A LayerNorm module with replicated parameters for distributed training and disabled autocast in BF16. + + This module wraps around `_LayerNormParamsReplicatedImpl` to provide a user-friendly interface + for LayerNorm operations using the DTensor API. It supports distributed training with replicated + and sharded placements for input tensors and replicated placements for weight and bias tensors. + + Args: + layer_local (nn.LayerNorm): An already-initialized nn.LayerNorm instance. + device_mesh (DeviceMesh): The device mesh for distributed training. + """ + + def __init__(self, layer_local: SerialLayerNormNoAutoCastBF16, device_mesh: DeviceMesh) -> None: + if not isinstance(layer_local, SerialLayerNormNoAutoCastBF16): + raise ValueError( + f"layer_local is not an instance of SerialLayerNormNoAutoCastBF16 but got {type(layer_local)}" + ) + if layer_local.weight.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.weight and device_mesh are not on the same device type: " + f"{layer_local.weight.device.type} != {device_mesh.device_type}" + ) + if layer_local.bias is not None and layer_local.bias.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.bias and device_mesh are not on the same device type: " + f"{layer_local.bias.device.type} != {device_mesh.device_type}" + ) + + super().__init__() + self.c_in = layer_local.c_in + self.normalized_shape = list(self.c_in) + self.eps = layer_local.eps + self.device_mesh = device_mesh + + all_replicate_placements = [Replicate()] * device_mesh.ndim + + if layer_local.weight is None: + self.register_parameter("weight", None) + else: + self.weight = nn.Parameter( + distribute_tensor(layer_local.weight.data, device_mesh, all_replicate_placements) + ) + if layer_local.bias is None: + self.register_parameter("bias", None) + else: + self.bias = nn.Parameter(distribute_tensor(layer_local.bias.data, device_mesh, all_replicate_placements)) + + def forward(self, x: DTensor) -> DTensor: + """ + Forward pass of LayerNormParamsReplicated. + + Args: + x (DTensor): Input tensor. + + Returns: + DTensor: The normalized output tensor. + """ + d = x.dtype + if d is torch.bfloat16: + with torch.autocast("cuda", enabled=False): + out = _LayerNormParamsReplicatedImpl.apply( + x, self.normalized_shape, self.weight, self.bias, self.eps, True + ) + else: + out = _LayerNormParamsReplicatedImpl.apply(x, self.normalized_shape, self.weight, self.bias, self.eps) + return out + + +class LinearParamsReplicatedNoAutoCastBF16(nn.Module): + """ + Distributed linear layer with parameters replicated across all device mesh dimensions and disabled autocast in BF16. + + This is almost equivalent to + ```python + layer = torch.distributed.tensor.distribute_module(layer_local, device_mesh) + ``` + with the exception that the torch.distributed.tensor.distribute_module version will incur + significant overhead due to the unnecessary replication of the output tensor along certain + device mesh dimensions. + + This class avoids such unnecessary overhead by using the custom _LinearParamsReplicatedImpl + autograd function for forward and backward pass computation instead of relying on the distributed + module's forward implementation. + + Key requirements: + 1. Parameters (weight and bias) will be replicated on all device mesh dimensions + 2. Input tensor and parameters must be on the same device mesh + 3. Feature/hidden dimension of the input must not be sharded across the device mesh + 4. Partial reduction along any input dimension is not supported + 5. Input and outputs must be on the same device mesh with the same placements + 6. Gradients of the input have the same placements on the same device mesh as the input + 7. Gradients of the weight and bias have Partial("sum") placements along the input's Shard placements' + dimension so that the all-reduce will be performed along those device-grid dimensions + + """ + + def __init__(self, layer_local: SerialLinearNoAutoCastBF16, device_mesh: DeviceMesh): + """ + Initialize the distributed linear layer. + + Args: + layer_local: nn.Linear to be distributed + device_mesh: Device mesh for distributed computation + """ + if not isinstance(layer_local, SerialLinearNoAutoCastBF16): + raise ValueError( + f"layer_local is not an instance of SerialLinearNoAutoCastBF16 but got {type(layer_local)}" + ) + if layer_local.weight.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.weight and device_mesh are not on the same device type: " + f"{layer_local.weight.device.type} != {device_mesh.device_type}" + ) + if layer_local.bias is not None and layer_local.bias.device.type != device_mesh.device_type: + raise ValueError( + f"layer_local.bias and device_mesh are not on the same device type: " + f"{layer_local.bias.device.type} != {device_mesh.device_type}" + ) + super().__init__() + all_replicate_placements = [Replicate()] * device_mesh.ndim + self.weight = nn.Parameter(distribute_tensor(layer_local.weight.data, device_mesh, all_replicate_placements)) + if layer_local.bias is None: + self.register_parameter("bias", None) + else: + self.bias = nn.Parameter(distribute_tensor(layer_local.bias.data, device_mesh, all_replicate_placements)) + + def forward(self, input: DTensor) -> DTensor: + """ + Forward pass for the distributed linear layer. + + Uses the custom _LinearParamsReplicatedImpl autograd function to perform the computation + efficiently while preserving correct autograd behavior for distributed tensors. + + Args: + input: Input DTensor with appropriate placement strategy + + Returns: + Output DTensor with same placement strategy as input + """ + d = input.dtype + if d is torch.bfloat16: + with torch.autocast("cuda", enabled=False): + return _LinearParamsReplicatedImpl.apply(input, self.weight, self.bias, True) + else: + return _LinearParamsReplicatedImpl.apply(input, self.weight, self.bias) + + +class _RingMultiHeadTriangleAttentionImpl(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q_x: DTensor, + kv_x: DTensor, + mask: Optional[DTensor], + triangle_bias: DTensor, + weight_q: DTensor, + weight_k: DTensor, + weight_v: DTensor, + no_heads: int, + c_hidden: int, + ring_comm: Ring2DCommTriAttn, + inf: float, + triattn_backend: TriAttnBackend, + ) -> DTensor: + """This function does the linear projection to prepare the q, k and v tensors + and use them later to compute triangular attention. + + Linear projection and initial data shard redistribution: + - Triangle bias is reorganized in two stages to avoid cross-rail traffic + - Key/value pairs are initially shuffled to offset computation along + attention matrix diagonal + + Stage 1 Triangle Bias Redistribution: Flatten diagonals onto rows/columns + Original Data Ownership After Stage 1 (for axis_cp=1) + ┌───┬───┬───┐ ┌───┬───┬───┐ + │0,0│0,1│0,2│ │0,0│1,1│2,2│ original lower diagonal 0 + ├───┼───┼───┤ ├───┼───┼───┤ + │1,0│1,1│1,2│ → │1,0│2,1│0,2│ original lower diagonal 1 + ├───┼───┼───┤ ├───┼───┼───┤ + │2,0│2,1│2,2│ │2,0│0,1│1,2│ original lower diagonal 2 + └───┴───┴───┘ └───┴───┴───┘ + + Stage 2 Triangle Bias Redistribution: Rotate elements to meet ring attention requirements + After Stage 2 (for axis_cp=1) + ┌───┬───┬───┐ + │0,0│1,1│2,2│ original lower diagonal 0 + ├───┼───┼───┤ + │0,2│1,0│2,1│ original lower diagonal 1 + ├───┼───┼───┤ + │0,1│1,2│2,0│ original lower diagonal 2 + └───┴───┴───┘ + + Forward pass for Multi-Head Triangle Attention using tri-axial + virtual all_gather, reduce and all_gather. It uses a ring + communication pattern across devices: + - Each device maintains double buffers for k, v, triangle bias and mask + - Data is shifted along the context parallelism axis in a ring pattern + - Communication is overlapped with computation for efficiency + + Data sharding Diagram: + ``` + The algorithm can be summarized as the 2-tuple (i, j) indexing the input data + ownership of the triangle bias, where subsequent data ownership of other data + is constrained by matching the corresponding i index of the triangle bias if + the data contributes to the rows of the attention matrix (Q index) or the corresponding + j index of the triangle bias if the data contributes to the columns of the + attention matrix (K index). + + Initial Distribution - For axis_cp=1: (See Ring2DCommTriAttn for more explanation) + + initialized by _PrepQKVCommBiasImpl: + ┌───┬───┬───┐ + │0,0│1,1│2,2│ + ├───┼───┼───┤ + │0,2│1,0│2,1│ + ├───┼───┼───┤ + │0,1│1,2│2,0│ + └───┴───┴───┘ + + At each step, the 2-tuple are updated by upshifted 1 along each column + e.g., at the end of step 0: + ┌───┬───┬───┐ + │0,2│1,0│2,1│ # Column 0: 0,0→0,2→0,1 (wrapped) + ├───┼───┼───┤ + │0,1│1,2│2,0│ # Column 1: 1,1→1,0→1,2 + ├───┼───┼───┤ + │0,0│1,1│2,2│ # Column 2: 2,2→2,1→2,0 + └───┴───┴───┘ + which gives the corresponding Q and K indices for each rank required for step 1 + + At the end of step 2, the data is rotated to its "initialized" state, + where the corresponding tensor along the Q and K axes are saved for + backward pass + + Args: + ctx: Context object for autograd + q_x: Query input tensor + kv_x: Key-value input tensor + mask: Optional mask tensor (None creates all-ones mask) + triangle_bias: Triangle bias tensor + weight_q: Query projection weights + weight_k: Key projection weights + weight_v: Value projection weights + no_heads: Number of attention heads + c_hidden: Hidden dimension size + ring_comm: Ring2DCommTriAttn object for distributed communication + inf: Infinity value for mask bias computation + triattn_backend: Triangular attention backend to use + Returns: + DTensor: Output tensor of shape [*, H, Q, C_hidden] + """ + has_mask = mask is not None + # Check if inputs are of type DTensor + if not isinstance(q_x, DTensor): + raise TypeError(f"Input 'q_x' must be of type DTensor. Got type {type(q_x)}.") + if not isinstance(kv_x, DTensor): + raise TypeError(f"Input 'kv_x' must be of type DTensor. Got type {type(kv_x)}.") + if has_mask and not isinstance(mask, DTensor): + raise TypeError(f"Input 'mask' must be of type DTensor or None. Got type {type(mask)}.") + if not isinstance(triangle_bias, DTensor): + raise TypeError(f"Input 'triangle_bias' must be of type DTensor. Got type {type(triangle_bias)}.") + if not isinstance(weight_q, DTensor): + raise TypeError(f"Input 'weight_q' must be of type DTensor. Got type {type(weight_q)}.") + if not isinstance(weight_k, DTensor): + raise TypeError(f"Input 'weight_k' must be of type DTensor. Got type {type(weight_k)}.") + if not isinstance(weight_v, DTensor): + raise TypeError(f"Input 'weight_v' must be of type DTensor. Got type {type(weight_v)}.") + if not isinstance(triattn_backend, TriAttnBackend): + raise TypeError( + f"Input 'triattn_backend' must be of type TriAttnBackend. Got type {type(triattn_backend)}." + ) + + if triattn_backend not in [ + TriAttnBackend.CUEQ, + TriAttnBackend.TRIFAST, + TriAttnBackend.REFERENCE, + TriAttnBackend.CUEQ_FWD_TRIFAST_BWD, + ]: + # to prevent accidental usage of unsupported backend + # so that we don't have to handle the unsupported backend case in the later code + # every time we have a backend selection logic + raise NotImplementedError( + f"Input 'triattn_backend' must be one of {TriAttnBackend.CUEQ, TriAttnBackend.TRIFAST, TriAttnBackend.REFERENCE, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD}. " + f"Got {triattn_backend}." + ) + + if triattn_backend in (TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD) and not has_mask: + raise ValueError( + "trifast or cueq_fwd_trifast_bwd backend requires a mask but mask is None. " + "Please provide a all-zeros mask to indicate all-valid elements for trifast" + ) + + if inf > torch.finfo(q_x.dtype).max: + raise ValueError( + f"Input 'inf'={inf} is larger than max value of dtype {q_x.dtype}: {torch.finfo(q_x.dtype).max}" + ) + + # Check if inputs have identical device mesh + device_mesh_input = q_x.device_mesh + if device_mesh_input != kv_x.device_mesh: + raise ValueError( + f"Input tensors 'q_x' and 'kv_x' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {kv_x.device_mesh}." + ) + if has_mask and device_mesh_input != mask.device_mesh: + raise ValueError( + f"Input tensors 'q_x' and 'mask' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {mask.device_mesh}." + ) + if device_mesh_input != triangle_bias.device_mesh: + raise ValueError( + f"Input tensors 'q_x' and 'triangle_bias' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {triangle_bias.device_mesh}." + ) + if device_mesh_input != weight_q.device_mesh: + raise ValueError( + f"Input tensors 'q_x' and 'weight_q' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {weight_q.device_mesh}." + ) + if device_mesh_input != weight_k.device_mesh: + raise ValueError( + f"Input tensors 'q_x' and 'weight_k' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {weight_k.device_mesh}." + ) + if device_mesh_input != weight_v.device_mesh: + raise ValueError( + f"Input tensors 'q_x' and 'weight_v' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {weight_v.device_mesh}." + ) + + # Check if q_x, kv_x, and mask_bias have the expected placements (Shard(0), Shard(1), Shard(2)) + + expected_input_placements = (Shard(0), Shard(1), Shard(2)) + + if q_x.placements != expected_input_placements: + raise ValueError( + f"Input tensor 'q_x' must have placements {expected_input_placements}. Got placements {q_x.placements}." + ) + if kv_x.placements != expected_input_placements: + raise ValueError( + f"Input tensor 'kv_x' must have placements {expected_input_placements}. " + f"Got placements {kv_x.placements}." + ) + + placements_input = q_x.placements + + # Create placement mapping for weight gradients + # Weight gradients should have Partial("sum") placements corresponding to input's Shard placements + placements_dweights = [Partial("sum"), Partial("sum"), Partial("sum")] + + # Check mask placements - should match q_x and kv_x + if has_mask and mask.placements != expected_input_placements: + raise ValueError( + f"Input tensor 'mask' must have the same placements as 'q_x' and 'kv_x'. " + f"Expected placements {expected_input_placements}, got {mask.placements}." + ) + + # Check triangle_bias placements - should be sharded along batch, I, and J dimensions + # triangle_bias should have shape [B, I, J, H] and be sharded on dimensions 0, 1, and 2 + if triangle_bias.placements != placements_input: + raise ValueError( + f"Input tensor 'triangle_bias' must be sharded along batch, I, and J dimensions. " + f"Expected placements {placements_input}, got {triangle_bias.placements}." + ) + + # Check weight placements - should be all replicated + all_replicate_placements = [Replicate()] * device_mesh_input.ndim + if weight_q.placements != tuple(all_replicate_placements): + raise ValueError( + f"Weight tensor 'weight_q' must have all replicated placements. " + f"Expected {tuple(all_replicate_placements)}, got {weight_q.placements}." + ) + if weight_k.placements != tuple(all_replicate_placements): + raise ValueError( + f"Weight tensor 'weight_k' must have all replicated placements. " + f"Expected {tuple(all_replicate_placements)}, got {weight_k.placements}." + ) + if weight_v.placements != tuple(all_replicate_placements): + raise ValueError( + f"Weight tensor 'weight_v' must have all replicated placements. " + f"Expected {tuple(all_replicate_placements)}, got {weight_v.placements}." + ) + + # Check ring_comm consistency + coord_device_mesh_input = device_mesh_input.get_coordinate() + if coord_device_mesh_input is None: + raise ValueError( + f"ring_comm.coord_2d {ring_comm.coord_2d} is not on device_mesh_input {device_mesh_input}." + ) + if ring_comm.coord_2d != (coord_device_mesh_input[1], coord_device_mesh_input[2]): + raise ValueError( + f"Input ring_comm's coord_2d {ring_comm.coord_2d} does not match the " + f"device mesh's rank coordinates {coord_device_mesh_input} for the sharded dimensions." + ) + + if q_x.shape != kv_x.shape: + raise ValueError( + f"Input tensors 'q_x' and 'kv_x' must have the same shape. Got shapes {q_x.shape} and {kv_x.shape}." + ) + + if has_mask and mask.shape != q_x.shape[:-1]: + raise ValueError( + f"Input tensor 'mask' must have the same shape as 'q_x' and 'kv_x' except the last dimension. " + f"Got shapes {mask.shape} and {q_x.shape[:-1]}." + ) + + if triangle_bias.shape != q_x.shape[:-1] + (no_heads,): + raise ValueError( + f"Input tensor 'triangle_bias' must have the same shape as 'q_x' and 'kv_x' " + f"except the last dimension and the last dimension must be equal to no_heads. " + f"Got shapes {triangle_bias.shape} and {q_x.shape[:-1] + (no_heads,)}." + ) + + # To accommodate the hybrid TriAttnBackend.CUEQ_FWD_TRIFAST_BWD, we dedicated + # two working flags for the respective fwd and bwd cases in order to reuse the + # cueq and trifast logics for the hybrid mode. But we need to modify the tensor shape + # in the hybrid mode due to the different requirements of the two backends. + # The shapes for the input between cueq and trifast are: + # q: [B, I, H, Q, C_hidden] vs [B, H, I, Q, C_hidden] + # kT: [B, I, H, K, C_hidden] vs [B, H, I, K, C_hidden] + # v: [B, I, H, V, C_hidden] vs [B, H, I, V, C_hidden] + # triangle_bias: [B, 1, H, I, J] vs [B, H, I, J] + # mask: [B, I, 1, 1, J] vs [B, I, J] + triattn_backend_fwd = ( + TriAttnBackend.CUEQ + if triattn_backend in (TriAttnBackend.CUEQ, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD) + else triattn_backend + ) + triattn_backend_bwd = ( + TriAttnBackend.TRIFAST + if triattn_backend in (TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD) + else triattn_backend + ) + + if triattn_backend == TriAttnBackend.REFERENCE: + # we manage the scale in this function scope using torch op in fwd and bwd + apply_scale = True + elif triattn_backend in (TriAttnBackend.CUEQ, TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD): + # we let the TriAttnBackend handle the scale internally in fwd and bwd + apply_scale = False + + if has_mask: + ctx.mark_non_differentiable(mask) + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.placements_dweights = placements_dweights + ctx.input_shape = q_x.shape + ctx.input_stride = q_x.stride() + ctx.weight_q_shape = weight_q.shape + ctx.weight_q_stride = weight_q.stride() + ctx.weight_k_shape = weight_k.shape + ctx.weight_k_stride = weight_k.stride() + ctx.weight_v_shape = weight_v.shape + ctx.weight_v_stride = weight_v.stride() + ctx.triangle_bias_shape = triangle_bias.shape + ctx.triangle_bias_stride = triangle_bias.stride() + ctx.ring_comm = ring_comm + q_scale = 1.0 / math.sqrt(c_hidden) # multiplicative factor for q + ctx.q_scale = q_scale + ctx.apply_scale = apply_scale + ctx.no_heads = no_heads + ctx.c_hidden = c_hidden + ctx.has_mask = has_mask + ctx.triattn_backend_bwd = triattn_backend_bwd + + # Store the mode based on ring_comm.axis_cp + ctx.mode = _Mode.Ending if ring_comm.axis_cp == 0 else _Mode.Starting + + # Convert DTensors to local tensors for computation + q_x_local = q_x.to_local() + kv_x_local = kv_x.to_local() + mask_local = mask.to_local() if has_mask else None + triangle_bias_local = triangle_bias.to_local() + + # Handle transpose for ending mode (when ring_comm.axis_cp == 0) + if ctx.mode == _Mode.Ending: + # Transpose input tensors from [*, I, J, C] to [*, J, I, C] + q_x_local = q_x_local.transpose(-2, -3).contiguous() + kv_x_local = kv_x_local.transpose(-2, -3).contiguous() + # Transpose triangle_bias from [*, I, J, H] to [*, J, I, H] + triangle_bias_local = triangle_bias_local.transpose(-2, -3) + # Transpose mask from [*, I, J] to [*, J, I] if mask exists + if has_mask: + mask_local = mask_local.transpose(-1, -2) + + if has_mask: + if triattn_backend_fwd == TriAttnBackend.CUEQ: + # if not casting to bool here, cueq will cast it internally anyway + # so might as well do it here to save some communication bandwidth. + # Also convert mask to mask_bias: [*, I, J] -> [*, I, 1, 1, J] for cueq + # TODO: we should cast mask to bool from the dataloader + mask_bias_local = mask_local[..., :, None, None, :].to( + dtype=torch.bool, memory_format=torch.contiguous_format + ) + elif triattn_backend_fwd == TriAttnBackend.TRIFAST: + # TRIFAST mask convention: True for invalid positions, False for valid positions + # TRIFAST mask is of shape [*, I, J] + mask_bias_local = ~(mask_local.to(dtype=torch.bool, memory_format=torch.contiguous_format)) + elif triattn_backend_fwd == TriAttnBackend.REFERENCE: + # REFERENCE mask is of shape [*, I, 1, 1, J] and it's an additive mask bias + mask_bias_local = inf * (mask_local - 1) + mask_bias_local = mask_bias_local[..., :, None, None, :].contiguous() + else: + mask_bias_local = None + + if triattn_backend_fwd == TriAttnBackend.TRIFAST: + # Convert triangle_bias from [*, I, J, H] to [*, H, I, J] for TRIFAST + triangle_bias_local = permute_final_dims(triangle_bias_local, (2, 0, 1)).contiguous() + else: + # Convert triangle_bias from [*, I, J, H] to [*, 1, H, I, J] for CUEQ and REFERENCE + triangle_bias_local = permute_final_dims(triangle_bias_local, (2, 0, 1)).unsqueeze(-4).contiguous() + weight_q_local = weight_q.to_local() + weight_k_local = weight_k.to_local() + weight_v_local = weight_v.to_local() + + # send kv_x before computing k and v to avoid sending the latters + # because in general no_heads > 1 so the linear projection expands + # the size of kv_x by no_heads times + kv_x_recv = ring_comm.comm_k_init.enqueue_to_dispatch(kv_x_local) + + # send mask along the axis_cp + if has_mask: + mask_bias_recv = ring_comm.comm_mask_init.enqueue_to_dispatch(mask_bias_local) + else: + mask_bias_recv = None + + # initialize triangle_bias comm for stage 1 + triangle_bias_recv = ring_comm.comm_bias_init0.enqueue_to_dispatch(triangle_bias_local) + + # [*, Q/K/V, H * C_hidden] + q = torch.nn.functional.linear(q_x_local, weight_q_local) + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (no_heads, -1)) + + if triattn_backend_fwd == TriAttnBackend.TRIFAST: + # [B, I, Q/K, H, C_hidden] --> [B, H, I, Q/K, C_hidden] + q = permute_final_dims(q, (2, 0, 1, 3)) + else: + # Both CUEQ and REFERENCE expect q to be of shape [*, H, Q/K, C_hidden] + q = q.transpose(-2, -3) + + if apply_scale: + q *= q_scale + + ring_comm.comm_k_init.wait_until_finished() + # kv_x_recv is ready + + # compute q, k and v + # kT == k.T is returned + # [*, Q/K/V, H * C_hidden] + k = torch.nn.functional.linear(kv_x_recv, weight_k_local) + # [*, Q/K, H, C_hidden] + k = k.view(k.shape[:-1] + (no_heads, -1)) + if triattn_backend_fwd == TriAttnBackend.CUEQ: + # cueq expects k instead of its transpose + # get kT (virtually k) of shape [*, H, K, C_hidden] + kT = k.transpose(-2, -3).contiguous() + elif triattn_backend_fwd == TriAttnBackend.TRIFAST: + # [B, I, Q/K, H, C_hidden] --> [B, H, I, Q/K, C_hidden] + kT = permute_final_dims(k, (2, 0, 1, 3)).contiguous() + elif triattn_backend_fwd == TriAttnBackend.REFERENCE: + # get k.T of shape [*, H, C_hidden, K] + kT = permute_final_dims(k, (1, 2, 0)) + # torch.distributed data transfer requires contiguous tensor + kT = kT.contiguous() + + # wait and initialize triangle_bias comm for stage 2 + ring_comm.comm_bias_init0.wait_until_finished() + + # Due to the two-stage communication, triangle_bias_recv needs an additional + # buffer. To make the parent module RingMultiHeadTriangleAttention satisfy + # the requirements of not modifying the input tensor, an additional copy + # of triangle_bias is required + triangle_bias_recv1 = ring_comm.comm_bias_init1.enqueue_to_dispatch(triangle_bias_recv) + + # [*, Q/K/V, H * C_hidden] + v = torch.nn.functional.linear(kv_x_recv, weight_v_local) + # [*, Q/K, H, C_hidden] + v = v.view(v.shape[:-1] + (no_heads, -1)) + if triattn_backend_fwd == TriAttnBackend.TRIFAST: + # [B, I, Q/K, H, C_hidden] --> [B, H, I, Q/K, C_hidden] + v = permute_final_dims(v, (2, 0, 1, 3)) + else: + # Both CUEQ and REFERENCE expect v to be of shape [*, H, Q/K, C_hidden] + v = v.transpose(-2, -3) + # torch.distributed data transfer requires contiguous tensor + v = v.contiguous() + + # initial triangle_bias should be ready by now + # TODO: move this wait inside the loop right before it's needed + if has_mask: + ring_comm.comm_mask_init.wait_until_finished() + ring_comm.comm_bias_init1.wait_until_finished() + + # triangle_bias_recv1 is ready + # mask_bias_recv is ready + + i_ready = 0 + i_recv = i_ready ^ 1 + kT_buffer = [kT, torch.empty_like(kT)] + v_buffer = [v, torch.empty_like(v)] + triangle_bias_buffer = [triangle_bias_recv1, torch.empty_like(triangle_bias_recv)] + if has_mask: + mask_bias_buffer = [mask_bias_recv, torch.empty_like(mask_bias_recv)] + else: + mask_bias_buffer = None + o, lse_m, amax = None, None, None + n_steps = ring_comm.group_layout.shape[ring_comm.axis_cp] + for step in range(n_steps): + # launch send/recv for the next round + # This is done even for the last step to enable saving the tensors for the backward pass + kT_buffer[i_recv] = ring_comm.comm_k.enqueue_to_dispatch(kT_buffer[i_ready], kT_buffer[i_recv]) + v_buffer[i_recv] = ring_comm.comm_v.enqueue_to_dispatch(v_buffer[i_ready], v_buffer[i_recv]) + if has_mask: + mask_bias_buffer[i_recv] = ring_comm.comm_mask.enqueue_to_dispatch( + mask_bias_buffer[i_ready], mask_bias_buffer[i_recv] + ) + triangle_bias_buffer[i_recv] = ring_comm.comm_bias.enqueue_to_dispatch( + triangle_bias_buffer[i_ready], triangle_bias_buffer[i_recv] + ) + # proceed with current k, v and triangle_bias + # NOTE: B is batch size; H is head; I and J are pair repr N_token + # C_hidden is q/k/v embedding dim; Q/K/V are attention dim (N_token) + # kT.shape == [*, H, C_hidden, K] (default torch variant) or [*, H, K, C_hidden] (cueq variant) + # q.shape == [*, H, Q, C_hidden] + + if triattn_backend_fwd == TriAttnBackend.CUEQ: + o_block, lse_block, amax_block = cueq_triangle.triangle_attention( + q, + kT_buffer[i_ready], + v_buffer[i_ready], + triangle_bias_buffer[i_ready], + mask_bias_buffer[i_ready] if has_mask else None, + scale=1.0 if apply_scale else q_scale, + return_aux=True, + ) + amax_block = amax_block.unsqueeze(-1) + lse_m_block = lse_block.unsqueeze(-1) - amax_block + # cueq TriAttn returns lse and amax in FP32 so we need to cast them back + # but the lse and amax returned can contain 1e9 values, which can overflow + # fp16 so we need to clamp them to the max value of fp16 + if q.dtype == torch.float16: + amax_block = amax_block.clamp( + min=torch.finfo(torch.float16).min, max=torch.finfo(torch.float16).max + ) + lse_m_block = lse_m_block.clamp( + min=torch.finfo(torch.float16).min, max=torch.finfo(torch.float16).max + ) + # TODO: verify if there is actually benefit in using FP32 lse_m and amax + # in tiled_softmax_attention_update + amax_block = amax_block.to(dtype=q.dtype) + lse_m_block = lse_m_block.to(dtype=q.dtype) + elif triattn_backend_fwd == TriAttnBackend.TRIFAST: + # has_mask == False would have raised before reaching here with TRIFAST backend + # o_block is of shape [B, H, I, K, C_hidden] + # lse_block is of shape [B, H, I, Q] + o_block, lse_block = trifast_triangle_attention( + q, kT_buffer[i_ready], v_buffer[i_ready], triangle_bias_buffer[i_ready], mask_bias_buffer[i_ready] + ) + amax_block = None + # TRIFAST returns lse directly (not lse - amax) in FP32. This is known to cause accuracy issues + # due to lower dynamic range in tiled softmax update. + # Pad a singleton K axis to lse_block to be used inside tiled_softmax_attention_update + # Here we don't need to canonicalize the shape of o and lse to the REFERENCE backend's shape + # because the tiled_softmax_attention_update effectively treats the leading axes as virtual + # batch axes. + lse_m_block = lse_block.to(dtype=q.dtype).unsqueeze(-1) + elif triattn_backend_fwd == TriAttnBackend.REFERENCE: + # [B, I, H, Q, K] + a = torch.matmul(q, kT_buffer[i_ready]) + + # biases[0].shape is [B, I, 1, 1, J] + if has_mask: + a += mask_bias_buffer[i_ready] + + # triangle_bias.shape is [B, 1, H, I, J] + a += triangle_bias_buffer[i_ready] + + # The following tries to stabilize pure -1e9 chunk + # in a, which could happen towards the last few + # chunks of the padding. This is done by keeping + # track of the max of a and the lse - max(a) during + # the accumulation. The tiled_softmax_attention_update will + # first attempt to compute amax_block - amax, which + # tends to cancel each other out, before updating + # the accumulators + amax_block = a.amax(dim=-1, keepdim=True) + # [*, H, Q, 1] + lse_m_block = torch.logsumexp(a - amax_block, dim=-1, keepdim=True).to(dtype=a.dtype) + # [*, H, Q, K] + a = torch.softmax(a, dim=-1) + + # [*, H, Q, C_hidden] + o_block = torch.matmul(a, v_buffer[i_ready]) + + o, lse_m, amax = tiled_softmax_attention_update(o_block, lse_m_block, amax_block, o, lse_m, amax) + # wait until next block is ready + ring_comm.comm_k.wait_until_finished() + ring_comm.comm_v.wait_until_finished() + ring_comm.comm_bias.wait_until_finished() + if has_mask: + ring_comm.comm_mask.wait_until_finished() + i_ready ^= 1 + i_recv ^= 1 + # NOTE: The last step's communication is done to reset the data's ownership to its initial state + # at the beginning of the forward pass + # NOTE: Although backward pass doesn't need to do block-wise renormalization with amax but only + # with lse, we need to subtract the masked attention matrix first by amax then lse_m to avoid + # numerical instability so we need to save both lse_m and amax terms for backward pass + if ( + q_x.requires_grad + or kv_x.requires_grad + or triangle_bias.requires_grad + or weight_q.requires_grad + or weight_k.requires_grad + or weight_v.requires_grad + ): + # This should be enough to avoid saving tensors when no gradient is needed + # TODO: tailor to the individual gradient in terms which intermediate tensors are saved + # but the challenge is how to avoid deadlocking in the backward pass in branching + if triattn_backend_fwd == triattn_backend_bwd: + # when fwd and bwd are the same backend, the processed tensors should satisfy + # the same shape requirements in the fwd and bwd pass + ctx.save_for_backward( + q_x_local, + kv_x_recv, + weight_q_local, + weight_k_local, + weight_v_local, + q, + kT_buffer[i_ready], + v_buffer[i_ready], + triangle_bias_buffer[i_ready], + mask_bias_buffer[i_ready] if ctx.has_mask else None, + o, + amax, + lse_m, + ) + elif triattn_backend_fwd == TriAttnBackend.CUEQ and triattn_backend_bwd == TriAttnBackend.TRIFAST: + # this implies triattn_backend == TriAttnBackend.CUEQ_FWD_TRIFAST_BWD and has_mask is True + # need to reshape some tensors as if the fwd pass was done with TRIFAST backend + ctx.save_for_backward( + q_x_local, + kv_x_recv, + weight_q_local, + weight_k_local, + weight_v_local, + q.transpose(-3, -4).contiguous(), # [B, I, H, Q, C_hidden] --> [B, H, I, Q, C_hidden] + kT_buffer[i_ready] + .transpose(-3, -4) + .contiguous(), # [B, I, H, K, C_hidden] --> [B, H, I, K, C_hidden] + v_buffer[i_ready] + .transpose(-3, -4) + .contiguous(), # [B, I, H, V, C_hidden] --> [B, H, I, V, C_hidden] + triangle_bias_buffer[i_ready].squeeze(-4).contiguous(), # [B, 1, H, I, J] --> [B, H, I, J] + ~( # trifast uses inverse boolean mask convention + mask_bias_buffer[i_ready].squeeze((-2, -3)).contiguous() + ), # [B, I, 1, 1, J] --> [B, I, J] + o.transpose(-3, -4).contiguous(), # [B, I, H, V, C_hidden] --> [B, H, I, V, C_hidden] + None, # trifast doesn't need amax + (lse_m + amax) + # trifast assume lse instead of lse_m + .transpose(-3, -4) # [B, I, H, Q, 1] --> [B, H, I, Q, 1] + .contiguous(), + ) + else: + raise NotImplementedError( + f"Unsupported backend {triattn_backend} with fwd backend {triattn_backend_fwd} and bwd backend {triattn_backend_bwd}." + ) + + if triattn_backend_fwd == TriAttnBackend.TRIFAST: + # o is of shape [B, H, I, Q, C_hidden] or [B, H, J, Q, C_hidden] (if ending mode) + # which needs to be transposed into [B, I, J, H * C_hidden] or [B, J, I, H * C_hidden] for + # consistency with input tensor axis semantics and placements + # and for downstream linear projection + o = permute_final_dims(o, (1, 2, 0, 3)).flatten(start_dim=-2) + else: + # o is of shape [B, I, H, Q, C_hidden] or [B, J, H, Q, C_hidden] (if ending mode) + # which needs to be transposed into [B, I, J, H * C_hidden] or [B, J, I, H * C_hidden] for + # consistency with input tensor axis semantics and placements + # and for downstream linear projection + o = o.transpose(-2, -3).flatten(start_dim=-2) + + # Handle transpose back for ending mode + if ctx.mode == _Mode.Ending: + # Transpose output from [B, J, I, H * C_hidden] to [B, I, J, H * C_hidden] + o = o.transpose(-2, -3) + + # Convert result back to DTensor + shape_output = ctx.input_shape[:-1] + (o.shape[-1],) + stride_output = update_exhaustive_strides(o.shape, o.stride(), shape_output) + output = DTensor.from_local( + o, + device_mesh=device_mesh_input, + placements=placements_input, + shape=shape_output, + stride=stride_output, + ) + return output + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx, do: DTensor + ) -> tuple[DTensor, DTensor, None, DTensor, DTensor, DTensor, DTensor, None, None, None, None, None, None]: + """Backward pass for Multi-Head Triangle Attention using tri-axial + virtual all_gather, reduce and all_reduce + + This implements the backward pass for the ring communication pattern used in the forward pass. + The data ownership follows the same pattern as the forward pass, with gradients being accumulated + and communicated in a ring pattern. + + Data Ownership Diagram: + ``` + The algorithm can be summarized as the 2-tuple (i, j) indexing the input data + ownership of the triangle bias, where subsequent data ownership of other data + is constrained by matching the corresponding i index of the triangle bias if + the data contributes to the rows of the attention matrix (Q index) or the corresponding + j index of the triangle bias if the data contributes to the columns of the + attention matrix (K index). + + Initial Distribution (saved from forward) - For axis_cp=1: (See Ring2DCommTriAttn for more explanation) + + initialized: end of step 0 + ┌───┬───┬───┐ ┌───┬───┬───┐ + │0,0│1,1│2,2│ │0,2│1,0│2,1│ + ├───┼───┼───┤ upshift ├───┼───┼───┤ upshift + │0,2│1,0│2,1│ ----> │0,1│1,2│2,0│ ----> ... + ├───┼───┼───┤ ├───┼───┼───┤ + │0,1│1,2│2,0│ │0,0│1,1│2,2│ + └───┴───┴───┘ └───┴───┴───┘ + + This indexing scheme is exactly the same as the forward pass to all the associated + intermediate tensors as well as the dependent gradients, whose Q and/or K indices + strictly matching the 2-tuple shown above. The only data tensor that is distributed + as the device grid index is the forward's output's gradient, + i.e., device (i, j) owns do(i, j) + + ``` + + Args: + ctx: Context object containing saved tensors and ring communication object + do: Gradient of output tensor of shape [B, I, J, H * C_hidden] + + Returns: + Tuple of gradients for: + - dq_x: Gradient of query input + - dkv_x_recv: Gradient of key-value input + - None: Placeholder for mask gradient (non-differentiable) + - dtriangle_bias: Gradient of triangle bias + - dweight_q: Gradient of query weight + - dweight_k: Gradient of key weight + - dweight_v: Gradient of value weight + - None: Placeholder for unused gradients + """ + # Check if input is of type DTensor + if not isinstance(do, DTensor): + raise TypeError(f"Input 'do' must be of type DTensor. Got type {type(do)}.") + + # Check if input has same device mesh and placements as forward inputs + if do.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'do' must have the same device mesh as the input tensors. " + f"Got device meshes {do.device_mesh} and {ctx.device_mesh_input}." + ) + if do.placements != ctx.placements_input: + raise ValueError( + f"Input 'do' must have the same placements as the input tensors. " + f"Got placements {do.placements} and {ctx.placements_input}." + ) + + # Convert gradient input to local tensor + do_local = do.to_local() + + # Handle transpose for ending mode gradient input + if ctx.mode == _Mode.Ending: + # Transpose gradient from [B, I, J, H * C_hidden] to [B, J, I, H * C_hidden] + do_local = do_local.transpose(-2, -3) + # else do_local is of shape [B, I, J, H * C_hidden] + + if ctx.triattn_backend_bwd == TriAttnBackend.TRIFAST: + # flatten and move the head dimension up to: [B, H, I/J, Q, C_hidden] + do_local = permute_final_dims( + do_local.unflatten(-1, (ctx.no_heads, ctx.c_hidden)), (2, 0, 1, 3) + ).contiguous() + else: + # flatten and move the head dimension up to: [B, I/J, H, Q, C_hidden] + do_local = do_local.unflatten(-1, (ctx.no_heads, ctx.c_hidden)).transpose(-2, -3).contiguous() + + ( + q_x, + kv_x_recv, + weight_q, + weight_k, + weight_v, + q, + kT_ready, + v_ready, + triangle_bias_ready, + mask_bias_ready, + o, + amax, + lse_m, + ) = ctx.saved_tensors + + if ctx.triattn_backend_bwd == TriAttnBackend.CUEQ: + can_run_sm100f = can_run_cueq_triattn_sm100f(q.device, q.dtype, kT_ready.shape[3], q.shape[-1], False) + # cueq uses lse instead of lse_m + amax + # it also requires the singleton K axis to be removed + lse = (lse_m + amax).squeeze(-1) + elif ctx.triattn_backend_bwd == TriAttnBackend.TRIFAST: + # trifast lse_m is actually lse, which must have shape [B, H, I, Q] + lse = lse_m.squeeze(-1) + + ring_comm: Ring2DCommTriAttn = ctx.ring_comm + # o is saved from the forward pass and + # is of shape [B, I, H, Q, C_hidden] + + if ctx.triattn_backend_bwd == TriAttnBackend.REFERENCE: + # Only needed by REFERENCE backend: doT is of shape [B, I, H, C_hidden, Q] + doT = do_local.transpose(-2, -1) + # Only needed by REFERENCE backend: qT is of shape [B, I, H, Q, C_hidden] for dkT computation + qT = q.transpose(-2, -1) + dq = torch.empty_like(q) + dkT = torch.empty_like(kT_ready) + if ctx.triattn_backend_bwd in (TriAttnBackend.CUEQ, TriAttnBackend.TRIFAST): + # this is virtually dv as will be returned from cueq triangle attention + # For CUEQ: dvT is of shape [B, I, H, K, C_hidden] + # For TRIFAST: dvT is of shape [B, H, I, K, C_hidden] + dvT = torch.empty_like(v_ready, memory_format=torch.contiguous_format) + elif ctx.triattn_backend_bwd == TriAttnBackend.REFERENCE: + # Instead of transposing the attention matrix, we transpose + # "do" to compute dvT instead + # dvT is of shape [B, I, H, C_hidden, K] + dvT = torch.empty( + v_ready.shape[:-2] + (v_ready.shape[-1], v_ready.shape[-2]), dtype=v_ready.dtype, device=v_ready.device + ) + dtriangle_bias = torch.empty_like(triangle_bias_ready) + + if ctx.triattn_backend_bwd in (TriAttnBackend.CUEQ, TriAttnBackend.TRIFAST): + # d is computed internally in cueq/trifast triangle attention + d = None + elif ctx.triattn_backend_bwd == TriAttnBackend.REFERENCE: + # d.shape is [B, I, H, Q, 1] + d = torch.linalg.vecdot(do_local, o, dim=-1).unsqueeze(-1) + # prevent d from promoting to fp32 if do_local is fp32 + d = d.to(dtype=o.dtype) + + i_ready = 0 + i_recv = i_ready ^ 1 + kT_buffer = [kT_ready, torch.empty_like(kT_ready)] + dkT_buffer = [dkT, torch.empty_like(dkT)] + v_buffer = [v_ready, torch.empty_like(v_ready)] + dvT_buffer = [dvT, torch.empty_like(dvT)] + triangle_bias_buffer = [triangle_bias_ready, torch.empty_like(triangle_bias_ready)] + dtriangle_bias_buffer = [dtriangle_bias, torch.empty_like(dtriangle_bias)] + if ctx.has_mask: + mask_bias_buffer = [mask_bias_ready, torch.empty_like(mask_bias_ready)] + else: + mask_bias_buffer = None + apply_scale = ctx.apply_scale + q_scale = ctx.q_scale + n_steps = ring_comm.group_layout.shape[ring_comm.axis_cp] + for step in range(n_steps): + is_last_step = step == n_steps - 1 + if not is_last_step: + # launch send/recv for the next round + # This is done even for the last step to enable saving the tensors for the backward pass + kT_buffer[i_recv] = ring_comm.comm_k.enqueue_to_dispatch(kT_buffer[i_ready], kT_buffer[i_recv]) + v_buffer[i_recv] = ring_comm.comm_v.enqueue_to_dispatch(v_buffer[i_ready], v_buffer[i_recv]) + if ctx.has_mask: + mask_bias_buffer[i_recv] = ring_comm.comm_mask.enqueue_to_dispatch( + mask_bias_buffer[i_ready], mask_bias_buffer[i_recv] + ) + triangle_bias_buffer[i_recv] = ring_comm.comm_bias.enqueue_to_dispatch( + triangle_bias_buffer[i_ready], triangle_bias_buffer[i_recv] + ) + + # proceed with current k, v and triangle_bias + # NOTE: B is batch size; H is head; I and J are pair repr N_token + # C_hidden is q/k/v embedding dim; Q/K/V are attention dim (N_token) + # kT.shape == [*, H, C_hidden, K] (default torch variant) or [*, H, K, C_hidden] (cueq variant) + # or [*, H, I, Q/K, C_hidden] (trifast variant) + # q.shape == [*, H, Q, C_hidden] (default torch or cueq variant) or [*, H, I, Q/K, C_hidden] (trifast variant) + + if ctx.triattn_backend_bwd == TriAttnBackend.CUEQ: + # dkT_block.shape is [B, I, H, K, C_hidden] + # dvT_block.shape is [B, I, H, K, C_hidden] + if can_run_sm100f: + # SM100f kernel accepts bias in the same dtype as q; + # keep the buffer's native dtype (no-op cast) and let + # cuEq's _convert_bias handle any necessary conversion. + bias_dtype = triangle_bias_buffer[i_ready].dtype + else: + # Non-SM100f cuEq backward requires float32 bias and lse + bias_dtype = torch.float32 + dq_block, dkT_block, dvT_block, dtriangle_bias_block_fp32 = ( + torch.ops.cuequivariance.triangle_attention_bwd( + do_local, + o, + q, + kT_buffer[i_ready], + v_buffer[i_ready], + triangle_bias_buffer[i_ready].to(dtype=bias_dtype), + mask_bias_buffer[i_ready] if ctx.has_mask else None, + lse.to(dtype=torch.float32), + 1.0 if apply_scale else q_scale, + ) + ) + dtriangle_bias_block = dtriangle_bias_block_fp32.to(dtype=triangle_bias_buffer[i_ready].dtype) + elif ctx.triattn_backend_bwd == TriAttnBackend.TRIFAST: + # A fake dmask tensor is also returned by trifast_triangle_attention_bwd as the last return value + dq_block, dkT_block, dvT_block, dtriangle_bias_block, _ = trifast_triangle_attention_bwd( + do_local, + q, + kT_buffer[i_ready], + v_buffer[i_ready], + triangle_bias_buffer[i_ready], + o, + lse.to(dtype=torch.float32), + mask_bias_buffer[i_ready], + ) + elif ctx.triattn_backend_bwd == TriAttnBackend.REFERENCE: + # [B, I, H, Q, K] + a = torch.matmul(q, kT_buffer[i_ready]) + + # biases[0].shape is [B, I, 1, 1, J] + if ctx.has_mask: + a += mask_bias_buffer[i_ready] + + # triangle_bias.shape is [B, 1, H, I, J] + a += triangle_bias_buffer[i_ready] + + # amax and lse_m shape is [B, I, H, Q, 1] + a -= amax + # lse_m is fp32 from logsumexp in fwd autocast, so we need to cast it to match a's dtype + # to avoid promoting a to fp32 + a -= lse_m.to(dtype=a.dtype) + + a = torch.exp(a) + + # dvT_block.shape is [B, I, H, C_hidden, K] + dvT_block = torch.matmul(doT, a) + + # da.shape is [B, I, H, Q, K] + da = torch.matmul(do_local, v_buffer[i_ready].transpose(-1, -2)) + # a is no longer needed so we can repurpose its memory for ds + # ds.shape is [B, I, H, Q, K] + ds = a + ds *= da - d + + # TODO: check if the cublas/cutlass backend is optimal with the + # non-ideal memory layout of kT_buffer[i_ready].transpose(-2, -1) + dq_block = torch.matmul(ds, kT_buffer[i_ready].transpose(-2, -1)) + + # dkT_block.shape is [B, I, H, C_hidden, K] + dkT_block = torch.matmul(qT, ds) + + # dtriangle_bias_block.shape is [B, 1, H, Q, K] + dtriangle_bias_block = ds.sum(dim=-4, keepdim=True, dtype=triangle_bias_buffer[i_ready].dtype) + + if step == 0: + dvT_buffer[i_ready] = dvT_block + dq = dq_block + dkT_buffer[i_ready] = dkT_block + dtriangle_bias_buffer[i_ready] = dtriangle_bias_block + else: + dq += dq_block + ring_comm.comm_dv.wait_until_finished() + dvT_buffer[i_ready] += dvT_block + ring_comm.comm_dk.wait_until_finished() + dkT_buffer[i_ready] += dkT_block + ring_comm.comm_dbias.wait_until_finished() + dtriangle_bias_buffer[i_ready] += dtriangle_bias_block + + dvT_buffer[i_recv] = ring_comm.comm_dv.enqueue_to_dispatch(dvT_buffer[i_ready], dvT_buffer[i_recv]) + dkT_buffer[i_recv] = ring_comm.comm_dk.enqueue_to_dispatch(dkT_buffer[i_ready], dkT_buffer[i_recv]) + dtriangle_bias_buffer[i_recv] = ring_comm.comm_dbias.enqueue_to_dispatch( + dtriangle_bias_buffer[i_ready], dtriangle_bias_buffer[i_recv] + ) + + if not is_last_step: + # wait until next block is ready + ring_comm.comm_k.wait_until_finished() + ring_comm.comm_v.wait_until_finished() + ring_comm.comm_bias.wait_until_finished() + if ctx.has_mask: + ring_comm.comm_mask.wait_until_finished() + i_ready ^= 1 + i_recv ^= 1 + + # dv, dkT and dtriangle_bias need the extra round of send/recv so that the + # data's ownership is transferred to the initial state at the beginning of the forward pass + i_ready ^= 1 + i_recv ^= 1 + ring_comm.comm_dv.wait_until_finished() + ring_comm.comm_dk.wait_until_finished() + ring_comm.comm_dbias.wait_until_finished() + + dkT = dkT_buffer[i_ready] + if ctx.triattn_backend_bwd == TriAttnBackend.CUEQ: + # dv is already of shape [B, I, H, K, C_hidden] + dv = dvT_buffer[i_ready] + elif ctx.triattn_backend_bwd == TriAttnBackend.TRIFAST: + # [B, H, I, K, C_hidden] --> [B, I, H, K, C_hidden] + dv = dvT_buffer[i_ready].transpose(-4, -3) + elif ctx.triattn_backend_bwd == TriAttnBackend.REFERENCE: + # dvT_buffer[i_ready] is of shape [B, I, H, C_hidden, K] + # reshaped to dv of shape [B, I, H, K, C_hidden] + dv = dvT_buffer[i_ready].transpose(-2, -1).contiguous() + dtriangle_bias = dtriangle_bias_buffer[i_ready] + + # the input tensors are sharded according to what's returned from _RingMHTAFunctorImpl.backward + # Here, q_x and dq didn't go through shuffling and its data ownership remain stationary + # kv_x_recv, dkT and dv are shuffled according to ring_comm's comm_k_init + # while dtriangle_bias is shuffled according to ring_comm's comm_bias_init0 and comm_bias_init1 + # The strategy here is to perform local computation first to get dkv_x, which is then shuffled, + # because dkv_x in general is smaller in size compared to dkvT and dvT due to no_heads > 1 + + # q_x and kv_x_recv are of shape [B, I, J, C_hidden] + + dtriangle_bias_recv = ring_comm.comm_dbias_final0.enqueue_to_dispatch(dtriangle_bias) + if ctx.triattn_backend_bwd in (TriAttnBackend.CUEQ, TriAttnBackend.REFERENCE): + # [B, I, Q, H, C_hidden] + dq_reshaped = dq.transpose(-2, -3) + elif ctx.triattn_backend_bwd == TriAttnBackend.TRIFAST: + # [B, H, I, Q, C_hidden] --> [B, I, Q, H, C_hidden] + dq_reshaped = permute_final_dims(dq, (1, 2, 0, 3)) + # [B, I, Q, H * C_hidden] + dq_reshaped = dq_reshaped.flatten(start_dim=-2) + if apply_scale: + dq_reshaped = dq_reshaped * q_scale + dq_x = torch.einsum("...z, zc -> ...c", dq_reshaped, weight_q) + ring_comm.comm_dbias_final0.wait_until_finished() + dtriangle_bias = ring_comm.comm_dbias_final1.enqueue_to_dispatch(dtriangle_bias_recv, dtriangle_bias) + + # dweight_q is of shape [*, H * C_hidden, C_hidden] + dweight_q = torch.einsum("...z, ...c -> zc", dq_reshaped, q_x) + + dweight_q_dtensor = DTensor.from_local( + dweight_q, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_dweights, + shape=ctx.weight_q_shape, + stride=ctx.weight_q_stride, + ) + + # reduce the dimensionality of H * C_hidden to C_hidden + # and sum up contribution from dv and dk before send/recv dkv_x + # [B, I, K, H, C_hidden] + dv_reshaped = dv.transpose(-2, -3) + # [B, I, K, H * C_hidden] + dv_reshaped = dv_reshaped.flatten(start_dim=-2) + + if ctx.triattn_backend_bwd == TriAttnBackend.CUEQ: + # dkT is of shape [B, I, H, K, C_hidden] + # dk_reshaped is of shape [B, I, K, H, C_hidden] + dk_reshaped = dkT.transpose(-2, -3) + elif ctx.triattn_backend_bwd == TriAttnBackend.TRIFAST: + # [B, H, I, K, C_hidden] --> [B, I, K, H, C_hidden] + dk_reshaped = permute_final_dims(dkT, (1, 2, 0, 3)) + elif ctx.triattn_backend_bwd == TriAttnBackend.REFERENCE: + # dkT is of shape [B, I, H, C_hidden, K] + # dk_reshaped is of shape [B, I, K, H, C_hidden] + dk_reshaped = permute_final_dims(dkT, (2, 0, 1)) + + # [B, I, K, H * C_hidden] + dk_reshaped = dk_reshaped.flatten(start_dim=-2) + + # kv_x is broadcasted to perform linear layer to get v and k + # so the gradients of kv_x need to be summed up from both contributions + dkv_x = torch.einsum("...z, zc -> ...c", dv_reshaped, weight_v) + torch.einsum( + "...z, zc -> ...c", dk_reshaped, weight_k + ) + + dkv_x_recv = ring_comm.comm_dk_final.enqueue_to_dispatch(dkv_x) + + dweight_v = torch.einsum("...z, ...c -> zc", dv_reshaped, kv_x_recv) + dweight_v_dtensor = DTensor.from_local( + dweight_v, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_dweights, + shape=ctx.weight_v_shape, + stride=ctx.weight_v_stride, + ) + + dweight_k = torch.einsum("...z, ...c -> zc", dk_reshaped, kv_x_recv) + dweight_k_dtensor = DTensor.from_local( + dweight_k, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_dweights, + shape=ctx.weight_k_shape, + stride=ctx.weight_k_stride, + ) + + ring_comm.comm_dbias_final1.wait_until_finished() + ring_comm.comm_dk_final.wait_until_finished() + + # Handle transpose back for ending mode gradients + if ctx.mode == _Mode.Ending: + # Transpose gradients from [*, J, I, C] back to [*, I, J, C] + dq_x = dq_x.transpose(-2, -3).contiguous() + dkv_x_recv = dkv_x_recv.transpose(-2, -3).contiguous() + # Transpose dtriangle_bias from [*, 1, H, J, I] to [*, 1, H, I, J] + dtriangle_bias = dtriangle_bias.transpose(-1, -2) + + # Convert gradients back to DTensors + dq_x_dtensor = DTensor.from_local( + dq_x, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + dkv_x_recv_dtensor = DTensor.from_local( + dkv_x_recv, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.input_shape, + stride=ctx.input_stride, + ) + if ctx.triattn_backend_bwd in (TriAttnBackend.CUEQ, TriAttnBackend.REFERENCE): + # Convert dtriangle_bias from [*, 1, H, I, J] back to [*, I, J, H] for output + dtriangle_bias_reshaped = permute_final_dims(dtriangle_bias.squeeze(-4), (1, 2, 0)).contiguous() + elif ctx.triattn_backend_bwd == TriAttnBackend.TRIFAST: + # [B, H, I, J] -> [B, I, J, H] + dtriangle_bias_reshaped = permute_final_dims(dtriangle_bias, (1, 2, 0)).contiguous() + dtriangle_bias_dtensor = DTensor.from_local( + dtriangle_bias_reshaped, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.triangle_bias_shape, + stride=ctx.triangle_bias_stride, + ) + + return ( + dq_x_dtensor, + dkv_x_recv_dtensor, + None, + dtriangle_bias_dtensor, + dweight_q_dtensor, + dweight_k_dtensor, + dweight_v_dtensor, + None, + None, + None, + None, + None, + None, + ) + + +class RingMultiHeadTriangleAttention(nn.Module): + """ + Multi-head triangle attention using a ring communication pattern + (see the underlying autograd.Function for detail explanation of the algorithm) + """ + + def __init__( + self, + layer: Attention, + device_mesh: DeviceMesh, + ring_comm: Ring2DCommTriAttn, + inf: float = 1e9, + ): + """ + Args: + layer: + Serial Attention instance to convert to distributed version + device_mesh: + Device mesh for distributed computation across multiple GPUs + ring_comm: + The communication object to use for distributed computation with ring attention + inf: + Infinity value for mask bias computation + """ + super().__init__() + + self.c_q = layer.c_q + self.c_k = layer.c_k + self.c_v = layer.c_v + self.c_hidden = layer.c_hidden + self.no_heads = layer.no_heads + self.gating = layer.gating + self.ring_comm = ring_comm + self.device_mesh = device_mesh + self.inf = inf + + # Convert linear layers to DTensor-based counterparts + # linear_{q,k,v} mapped to LinearParamsReplicated + self.linear_q = LinearParamsReplicated(layer.linear_q, device_mesh) + self.linear_k = LinearParamsReplicated(layer.linear_k, device_mesh) + self.linear_v = LinearParamsReplicated(layer.linear_v, device_mesh) + + # linear_{o,g} mapped to LinearParamsReplicatedNoAutoCastBF16 + self.linear_o = LinearParamsReplicatedNoAutoCastBF16(layer.linear_o, device_mesh) + + self.linear_g = None + if self.gating and layer.linear_g is not None: + self.linear_g = LinearParamsReplicatedNoAutoCastBF16(layer.linear_g, device_mesh) + + def forward( + self, + q_x: DTensor, + kv_x: DTensor, + biases: list[DTensor], + triattn_backend: TriAttnBackend = TriAttnBackend.REFERENCE, + ) -> DTensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List containing mask (can be None) and triangle_bias + triattn_backend: + Triangular attention backend to use + Returns + [*, Q, C_q] attention update + """ + # Linear layer weights are already DTensors from LinearParamsReplicated + # compute q, k and v and launch initial shifting + # of biases + # kT == k.T is returned + # Handle optional mask + mask = biases[0] if biases[0] is not None else None + triangle_bias = biases[1] + + o = _RingMultiHeadTriangleAttentionImpl.apply( + q_x, + kv_x, + mask, + triangle_bias, + self.linear_q.weight, + self.linear_k.weight, + self.linear_v.weight, + self.no_heads, + self.c_hidden, + self.ring_comm, + self.inf, + triattn_backend, + ) + + if self.linear_g is not None: + # [B, I, J, H * C_hidden] + g = self.linear_g(q_x) + o = sigmoid_gate(o, g) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + +class _Mode(Enum): + Starting = auto() + Ending = auto() + + +class TriangleAttention(nn.Module): + """Distributed triangle attention layer. + + This layer implements a distributed version of the triangle attention operation, + which is used in attention mechanisms for protein structure prediction and other applications + requiring pairwise feature interactions. + + The layer performs the following operations: + 1. Layer normalization of input pairwise features + 2. Linear projection to create triangle bias + 3. Distributed triangle attention computation using ring communication + 4. Transpose operations for ending node configuration + + Parameters + ---------- + mode : _Mode + Whether this is a starting or ending triangle attention node. + layer : SerialTriangleAttentionStartingNode | SerialTriangleAttentionEndingNode + The serial triangle attention layer to convert to distributed version. + Used to initialize weights and normalization parameters. + device_mesh : DeviceMesh + The device mesh for distributed computation across multiple GPUs. + comm : Ring2DCommTriAttn + Ring communication object for efficient distributed triangle attention computation. + """ + + def __init__( + self, + mode: _Mode, + layer: SerialTriangleAttentionStartingNode | SerialTriangleAttentionEndingNode, + device_mesh: DeviceMesh, + comm: Ring2DCommTriAttn, + ) -> None: + """Initialize the distributed triangle attention layer.""" + super().__init__() + self.device_mesh = device_mesh + self.ring_comm = comm + self.mode = mode + + # Store layer parameters for distributed computation + self.c_in = layer.c_in + self.c_hidden = layer.c_hidden + self.no_heads = layer.no_heads + + self.inf = layer.inf + + # Replicate parameters across the device mesh + self.layer_norm = LayerNormParamsReplicatedNoAutoCastBF16(layer.layer_norm, self.device_mesh) + self.linear = LinearParamsReplicatedNoAutoCastBF16(layer.linear, self.device_mesh) + + # Use the ring-based multi-head attention for distributed computation + self.mha = RingMultiHeadTriangleAttention( + layer.mha, + self.device_mesh, + self.ring_comm, + self.inf, + ) + + # Validate mode consistency with ring comm + if mode == _Mode.Starting: + if not isinstance(layer, SerialTriangleAttentionStartingNode): + raise ValueError(f"StartingNode mode is inconsistent with layer type {type(layer)}") + if self.ring_comm.axis_cp != 1: + raise ValueError(f"StartingNode mode is inconsistent with ring_comm.axis_cp {self.ring_comm.axis_cp}") + elif mode == _Mode.Ending: + if not isinstance(layer, SerialTriangleAttentionEndingNode): + raise ValueError(f"EndingNode mode is inconsistent with layer type {type(layer)}") + if self.ring_comm.axis_cp != 0: + raise ValueError(f"EndingNode mode is inconsistent with ring_comm.axis_cp {self.ring_comm.axis_cp}") + else: + raise ValueError(f"Invalid mode {mode}") + + def forward( + self, x: DTensor, mask: Optional[DTensor] = None, triattn_backend: TriAttnBackend = TriAttnBackend.REFERENCE + ) -> DTensor: + """Forward pass of the distributed triangle attention layer. + + Parameters + ---------- + x : DTensor + Input pairwise tensor with shape (B, I, J, C_in). + Must be sharded on dimensions 1 and 2. + mask : DTensor, optional + Mask tensor with shape (B, I, J) indicating valid positions. + Must be sharded on dimensions 1 and 2. If None, creates a mask of all ones. + triattn_backend : TriAttnBackend + Triangular attention backend to use + Returns + ------- + DTensor + Output pairwise tensor with shape (B, I, J, C_in). + """ + # Validate input types + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + + if mask is not None: + if not isinstance(mask, DTensor): + raise TypeError(f"Input 'mask' must be of type DTensor or None. Got type {type(mask)}.") + if mask.shape != x.shape[:-1]: + raise ValueError( + f"Input tensor 'mask' must have the same shape as the first 3 dimensions of 'x'. " + f"Got mask shape: {mask.shape} vs x shape[:3]: {x.shape[:3]}" + ) + + if triattn_backend in (TriAttnBackend.CUEQ, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD) and not cueq_is_installed: + raise ValueError( + "cuequivariance_torch is not installed. For Triangle Attention support, " + "install using: pip install cuequivariance_ops_torch_cu13== cuequivariance_torch== " + "where the 'version' tag can be found in the pyproject.toml file" + ) + if ( + triattn_backend in (TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD) + and not trifast_is_installed + ): + raise ValueError( + "trifast is not installed. For Triangle Attention support, install using: pip install trifast" + ) + if triattn_backend == TriAttnBackend.CUEQ_FWD_TRIFAST_BWD and x.dtype != torch.float32: + raise ValueError(f"CUEQ_FWD_TRIFAST_BWD is only intended for FP32 usage. Got x.dtype {x.dtype}") + + # Normalize input - mask creation moved to MHA implementation + x = self.layer_norm(x) + + # Compute triangle bias + triangle_bias = self.linear(x) + + # Prepare biases for attention computation - mask_bias computation moved to MHA implementation + # Regardless of triattn_backend, the binary mask is passed to the MHA implementation + # where the default torch variant will convert it internally to a mask bias while + # the underlying cueq call will use the binary mask directly + biases = [mask, triangle_bias] + + # Apply distributed multi-head attention - transpose logic moved inside MHA implementation + output = self.mha(q_x=x, kv_x=x, biases=biases, triattn_backend=triattn_backend) + + return output + + +class TriangleAttentionStartingNode(TriangleAttention): + """Distributed triangle attention starting node layer.""" + + def __init__( + self, + layer: SerialTriangleAttentionStartingNode, + device_mesh: DeviceMesh, + comm: Ring2DCommTriAttn, + ) -> None: + """Initialize the distributed triangle attention starting node layer. + + Parameters + ---------- + layer : SerialTriangleAttentionStartingNode + The serial triangle attention layer to convert to distributed version. + device_mesh : DeviceMesh + The device mesh for distributed computation across multiple GPUs. + comm : Ring2DCommTriAttn + Ring communication object for efficient distributed triangle attention computation. + """ + if not layer.starting: + raise ValueError("Serial layer must be configured as starting=True for TriangleAttentionStartingNode") + super().__init__(_Mode.Starting, layer, device_mesh, comm) + + +class TriangleAttentionEndingNode(TriangleAttention): + """Distributed triangle attention ending node layer.""" + + def __init__( + self, + layer: SerialTriangleAttentionEndingNode, + device_mesh: DeviceMesh, + comm: Ring2DCommTriAttn, + ) -> None: + """Initialize the distributed triangle attention ending node layer. + + Parameters + ---------- + layer : SerialTriangleAttentionEndingNode + The serial triangle attention layer to convert to distributed version. + device_mesh : DeviceMesh + The device mesh for distributed computation across multiple GPUs. + comm : Ring2DCommTriAttn + Ring communication object for efficient distributed triangle attention computation. + """ + if layer.starting: + raise ValueError("Serial layer must be configured as starting=False for TriangleAttentionEndingNode") + super().__init__(_Mode.Ending, layer, device_mesh, comm) diff --git a/src/boltz/distributed/model/layers/triangular_mult.py b/src/boltz/distributed/model/layers/triangular_mult.py new file mode 100644 index 000000000..adc3f5044 --- /dev/null +++ b/src/boltz/distributed/model/layers/triangular_mult.py @@ -0,0 +1,760 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from enum import Enum, auto +from typing import Tuple + +import torch +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Shard + +from boltz.distributed.comm import Ring2DComm +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.sigmoid_gate import sigmoid_gate +from boltz.distributed.utils import update_exhaustive_strides +from boltz.model.layers.triangular_mult import ( + TriangleMultiplicationIncoming as SerialTriangleMultiplicationIncoming, +) +from boltz.model.layers.triangular_mult import ( + TriangleMultiplicationOutgoing as SerialTriangleMultiplicationOutgoing, +) + + +class _XposeArgs(Enum): + lhs = auto() + rhs = auto() + + +def _distributed_bmm( + lhs: torch.Tensor, + rhs: torch.Tensor, + comm: Ring2DComm, + permute_lhs: tuple[int, ...] | None = None, + permute_rhs: tuple[int, ...] | None = None, + permute_out: tuple[int, ...] | None = None, + xpose_args: _XposeArgs | None = None, +) -> torch.Tensor: + """Perform distributed batch matrix multiplication using ring communication. + + This function implements a memory-efficient distributed batch matrix + multiply operation across a 2D process grid using ring communication patterns. + It computes the matrix multiplication of two tensors while minimizing memory usage + through double buffering and overlapping computation with communication. + + The algorithm works by: + 1. Optionally permuting input tensors to desired layouts + 2. Setting up communication buffers based on transpose requirements + 3. Using ring communication to rotate tensor chunks across processes + 4. Performing overlapped computation and communication with double buffering + 5. Accumulating partial results to compute the final distributed bmm + + Communication Patterns + ---------------------- + The function uses Ring2DComm to implement sophisticated communication patterns + across a 2D process grid. Below are ASCII diagrams illustrating the key phases: + + **Phase 1: Initial 2D Grid Setup** + + For a 3x3 process grid, each process (i,j) initially owns tensor chunks: + ``` + ┌─────┬─────┬─────┐ + │(0,0)│(0,1)│(0,2)│ ← Row 0 + ├─────┼─────┼─────┤ + │(1,0)│(1,1)│(1,2)│ ← Row 1 + ├─────┼─────┼─────┤ + │(2,0)│(2,1)│(2,2)│ ← Row 2 + └─────┴─────┴─────┘ + ↑ ↑ ↑ + Col 0 Col 1 Col 2 + ``` + + **Phase 2: Transpose Communication (if xpose_args specified)** + + e.g., when xpose_args=_XposeArgs.rhs, RHS tensor is transposed across the 2D grid: + ``` + Original RHS Ownership After Transpose Communication + ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ + │ R00 │ R01 │ R02 │ │ R00 │ R10 │ R20 │ + ├─────┼─────┼─────┤ → ├─────┼─────┼─────┤ + │ R10 │ R11 │ R12 │ │ R01 │ R11 │ R21 │ + ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ + │ R20 │ R21 │ R22 │ │ R02 │ R12 │ R22 │ + └─────┴─────┴─────┘ └─────┴─────┴─────┘ + ``` + When xpose_args=_XposeArgs.lhs, LHS tensor is similarly transposed across the 2D grid + + **Phase 3: Initial Ring Setup** + + Row initialization (comm_row_init): Each row i shifts left by i positions + ``` + Before Row Init After Row Init + ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ + │ L00 │ L01 │ L02 │ ←shift 0│ L00 │ L01 │ L02 │ + ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ + │ L10 │ L11 │ L12 │ ←shift 1│ L11 │ L12 │ L10 │ + ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ + │ L20 │ L21 │ L22 │ ←shift 2│ L22 │ L20 │ L21 │ + └─────┴─────┴─────┘ └─────┴─────┴─────┘ + ``` + + Column initialization (comm_col_init): Each column j shifts up by j positions + ``` + Before Col Init After Col Init + ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ + │ R00 │ R01 │ R02 │ │ R00 │ R11 │ R22 │ + ├─────┼─────┼─────┤ shift ├─────┼─────┼─────┤ + │ R10 │ R11 │ R12 │ ↑ │ R10 │ R21 │ R02 │ + ├─────┼─────┼─────┤ 0,1,2 ├─────┼─────┼─────┤ + │ R20 │ R21 │ R22 │ │ R20 │ R01 │ R12 │ + └─────┴─────┴─────┘ └─────┴─────┴─────┘ + ``` + + **Phase 4: Ring Computation Loop** + + For each iteration k in range(grid_size): + 1. Compute partial matmul: out += matmul(lhs_chunk, rhs_chunk) + 2. Ring shift both tensors for next iteration + + Ring communication pattern (each step shifts by 1): + ``` + Step 0 → Step 1 → Step 2 (back to original) + + LHS Row Shifts (left by 1): + ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ + │ L00 │ L01 │ L02 │ → │ L01 │ L02 │ L00 │ → │ L02 │ L00 │ L01 │ + ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ + │ L11 │ L12 │ L10 │ → │ L12 │ L10 │ L11 │ → │ L10 │ L11 │ L12 │ + ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ + │ L22 │ L20 │ L21 │ → │ L20 │ L21 │ L22 │ → │ L21 │ L22 │ L20 │ + └─────┴─────┴─────┘ └─────┴─────┴─────┘ └─────┴─────┴─────┘ + + RHS Column Shifts (up by 1): + ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ + │ R00 │ R11 │ R22 │ │ R10 │ R21 │ R02 │ │ R20 │ R01 │ R12 │ + ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ + │ R10 │ R21 │ R02 │ → │ R20 │ R01 │ R12 │ → │ R00 │ R11 │ R22 │ + ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ ├─────┼─────┼─────┤ + │ R20 │ R01 │ R12 │ │ R00 │ R11 │ R22 │ │ R10 │ R21 │ R02 │ + └─────┴─────┴─────┘ └─────┴─────┴─────┘ └─────┴─────┴─────┘ + ``` + + **Double Buffering Strategy** + + The algorithm uses double buffering to overlap communication with computation: + ``` + Time → │ Compute │ Compute │ Compute │ + │ Buffer0 │ Buffer1 │ Buffer0 │ + │ ↓ │ ↓ │ ↓ │ + Comm → │ Send │ Send │ Send │ + │ Buffer1│ Buffer0 │ Buffer1 │ + │ Recv │ Recv │ Recv │ + │ Buffer1 │ Buffer0 │ Buffer1 │ + ``` + + This ensures that while one buffer is being used for computation, the other + buffer is being prepared through communication for the next iteration. + + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side tensor for matrix multiplication. + Typically has shape (B, ...) where B is batch dimension. + rhs : torch.Tensor + Right-hand side tensor for matrix multiplication. + Must be compatible with lhs for matrix multiplication after permutations. + comm : Ring2DComm + Ring communication object configured for 2D process grid communication. + Provides row and column communication groups for distributed computation. + permute_lhs : tuple[int, ...] | None, optional + Permutation indices to apply to lhs tensor before computation. Typically + the permutation with group the batch-like axes into leading axes and reshape + the last two axes into "N" and "K" dimensions (in the NMK notation) + If None, no permutation is applied. Default is None. + permute_rhs : tuple[int, ...] | None, optional + Permutation indices to apply to rhs tensor before computation. Typically + the permutation with group the batch-like axes into leading axes and reshape + the last two axes into "K" and "M" dimensions (in the NMK notation) + If None, no permutation is applied. Default is None. + permute_out : tuple[int, ...] | None, optional + Permutation indices to apply to output tensor after computation. Typically + the permutation reverts the resulting permutation of the output matrix + due to the permutation of the input lhs' and rhs' axes. + If None, no permutation is applied. Default is None. + xpose_args : _XposeArgs | None, optional + Specifies which tensor requires transpose communication: + - _XposeArgs.lhs: Transpose communication for left-hand side tensor + - _XposeArgs.rhs: Transpose communication for right-hand side tensor + - None: No transpose communication required + Default is None. + + Returns + ------- + torch.Tensor + Result of the distributed batch matrix multiplication. + Shape depends on input shapes and permutation arguments. + + Examples + -------- + Typical usage in triangle multiplication: + + >>> # For outgoing triangle multiplication + >>> result = _distributed_bmm( + ... lhs=tensor_a, + ... rhs=tensor_b, + ... comm=ring_comm, + ... permute_lhs=(0, 3, 1, 2), # (B, n, k, D) -> (B, D, n, k) + ... permute_rhs=(0, 3, 2, 1), # (B, m, k, D) -> (B, D, k, m) + ... permute_out=(0, 2, 3, 1), # (B, D, n, m) -> (B, n, m, D) + ... xpose_args=_XposeArgs.rhs + ... ) + """ + if permute_lhs is not None: + lhs = lhs.permute(permute_lhs) + # this enforces lhs and rhs to be a clone so that the in-place modification + # does not affect the input tensor + lhs = lhs.clone(memory_format=torch.contiguous_format) + if permute_rhs is not None: + rhs = rhs.permute(permute_rhs) + rhs = rhs.clone(memory_format=torch.contiguous_format) + + if xpose_args == _XposeArgs.lhs: + lhs_recv = comm.comm_2d_trans.enqueue_to_dispatch(lhs) + rhs_recv = rhs + rhs = torch.empty_like(rhs_recv) + elif xpose_args == _XposeArgs.rhs: + rhs_recv = comm.comm_2d_trans.enqueue_to_dispatch(rhs) + lhs_recv = lhs + lhs = torch.empty_like(lhs_recv) + elif xpose_args is None: + lhs_recv = lhs + lhs = torch.empty_like(lhs_recv) + rhs_recv = rhs + rhs = torch.empty_like(rhs_recv) + else: + raise ValueError(f"Invalid xpose_args: {xpose_args}") + + # post the comm_2d_trans.wait_until_finished() (or no wait if xpose_args is not None), + # *_recv are the correct tensors to operate on + i_ready = 0 + i_recv = i_ready ^ 1 + lhs_buffer = [lhs_recv, lhs] + rhs_buffer = [rhs_recv, rhs] + + if xpose_args is not None: + comm.comm_2d_trans.wait_until_finished() + + lhs_buffer[i_recv] = comm.comm_row_init.enqueue_to_dispatch(lhs_buffer[i_ready], lhs_buffer[i_recv]) + rhs_buffer[i_recv] = comm.comm_col_init.enqueue_to_dispatch(rhs_buffer[i_ready], rhs_buffer[i_recv]) + + i_ready ^= 1 + i_recv ^= 1 + + out = torch.zeros_like(lhs_buffer[i_ready]) + + comm.comm_row_init.wait_until_finished() + comm.comm_col_init.wait_until_finished() + + # Double buffering computation + for k_step in range(comm.group_layout.shape[1]): + lhs_ready = lhs_buffer[i_ready] + rhs_ready = rhs_buffer[i_ready] + if k_step < comm.group_layout.shape[1] - 1: + lhs_buffer[i_recv] = comm.comm_row.enqueue_to_dispatch(lhs_ready, lhs_buffer[i_recv]) + rhs_buffer[i_recv] = comm.comm_col.enqueue_to_dispatch(rhs_ready, rhs_buffer[i_recv]) + out = out + torch.matmul(lhs_ready, rhs_ready) + if k_step < comm.group_layout.shape[1] - 1: + comm.comm_row.wait_until_finished() + comm.comm_col.wait_until_finished() + i_ready = i_ready ^ 1 + i_recv = i_recv ^ 1 + + if permute_out is not None: + out = out.permute(permute_out) + return out + + +class _Direction(Enum): + Outgoing = auto() + Incoming = auto() + + +class _TriangleMultiplicationImpl(torch.autograd.Function): + """Distributed implementation of triangle multiplication using ring communication. + + This autograd function implements a memory-efficient distributed triangle multiplication + operation across a 2D process grid. The computation is parallelized using ring + communication patterns to reduce memory usage and communication overhead. + + The triangle multiplication computes: + + for Outgoing: + o = torch.einsum("bnkd,bmkd->bnmd", a * mask, b * mask) + + for Incoming: + o = torch.einsum("bknd,bkmd->bnmd", a * mask, b * mask) + + Key features: + - Distributed across a 2D grid with sharding on token dimensions (dim 1 and 2) + - Uses ring communication to rotate data chunks during computation + - Memory-efficient implementation that avoids materializing full tensors + - Supports gradient computation through custom backward pass + + Notes + ----- + Input tensors must be DTensors with: + - Shape: (B, N_token1, N_token2, c_hidden) for tensors a and b + - Shape: (B, N_token1, N_token2, 1) for mask tensor + - Sharding on dimensions 1 and 2 (Shard(1) and Shard(2) placements) + - Identical device mesh and placements across all inputs + + The algorithm uses a ring-based communication pattern where: + - Tensor b is transposed and rotated by row + - Tensor a is rotated by column + - Each process computes partial matrix products and accumulates results + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, x: DTensor, mask: DTensor, g: DTensor, comm: Ring2DComm, direction: _Direction) -> DTensor: + """Forward pass of distributed triangle multiplication computation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + x : DTensor + Input tensor with shape (B, N_token1, N_token2, c_hidden * 2). + Must be sharded on dimensions 1 and 2. + mask : DTensor + Mask tensor with shape (B, N_token1, N_token2) indicating valid positions. + Must be sharded on dimensions 1 and 2. + g : DTensor + pre-sigmoid gate tensor with shape (B, N_token1, N_token2, c_hidden * 2) indicating valid positions. + Must be sharded on dimensions 1 and 2. + comm : Ring2DComm + Ring communication object configured for the distributed computation. + direction : _Direction + Direction of the triangle multiplication, Outgoing or Incoming. + + Returns + ------- + DTensor + Output tensor with shape (B, N_token1, N_token2, c_hidden). + Contains the distributed triangle multiplication result. + """ + # Check if inputs are of type DTensor + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + if not isinstance(mask, DTensor): + raise TypeError(f"Input 'mask' must be of type DTensor. Got type {type(mask)}.") + if not isinstance(g, DTensor): + raise TypeError(f"Input 'g' must be of type DTensor. Got type {type(g)}.") + + # Check if inputs have identical device mesh + device_mesh_input = x.device_mesh + if device_mesh_input != mask.device_mesh: + raise ValueError( + f"Input tensors 'x' and 'mask' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {mask.device_mesh}." + ) + if device_mesh_input != g.device_mesh: + raise ValueError( + f"Input tensors 'x' and 'g' must have identical device mesh. " + f"Got device meshes {device_mesh_input} and {g.device_mesh}." + ) + + # Check if inputs have identical placements + placements_input = x.placements + if placements_input != mask.placements: + raise ValueError( + f"Input tensors 'x' and 'mask' must have identical placements. " + f"Got placements {placements_input} and {mask.placements}." + ) + if placements_input != g.placements: + raise ValueError( + f"Input tensors 'x' and 'g' must have identical placements. " + f"Got placements {placements_input} and {g.placements}." + ) + if placements_input != (Shard(0), Shard(1), Shard(2)): + # For debugging, we requires the placements to be (Shard(0), Shard(1), Shard(2)) + # TODO: remove this to only use the previous check + raise ValueError( + f"Input tensor 'x's placements are not (Shard(0), Shard(1), Shard(2)). " + f"Got placements {placements_input}." + ) + + # Check input shapes + if x.shape[-1] % 2 != 0: + raise ValueError(f"Input tensor 'x' must have an even number of hidden dimension size. Got {x.shape[-1]}") + + if x.ndim != 4: + raise ValueError(f"Input tensor 'x' must have 4 dimensions. Got {x.ndim} dimensions.") + + if mask.ndim != 3: + raise ValueError(f"Input tensor 'mask' must have 3 dimensions. Got {mask.ndim} dimensions.") + + if mask.shape != x.shape[:3]: + raise ValueError( + f"Input tensor 'mask' must have the same shape as the first 3 dimensions of 'x'. " + f"Got mask shape: {mask.shape} vs x shape[:3]: {x.shape[:3]}" + ) + if g.shape != x.shape: + raise ValueError( + f"Input tensor 'g' must have the same shape as 'x'. Got g shape: {g.shape} vs x shape: {x.shape}" + ) + + # Perform consistency check between the ring_comm and the device_mesh_input + i_tensor_dim_to_i_grid_axis = [-1] * x.ndim + for i_grid_axis, placement in enumerate(placements_input): + if isinstance(placement, Shard): + i_tensor_dim_to_i_grid_axis[placement.dim] = i_grid_axis + if i_tensor_dim_to_i_grid_axis[1] == -1 or i_tensor_dim_to_i_grid_axis[2] == -1: + raise ValueError(f"Input tensors' dimensions 1 and 2 must be sharded. Got placements {placements_input}.") + + # Check ring_comm consistency + if comm.group_col != device_mesh_input.get_group(i_tensor_dim_to_i_grid_axis[1]): + raise ValueError( + "Input ring_comm's group_col process group is not the same as the group sharding the input tensors' axis 1" + ) + + coord_device_mesh_input = device_mesh_input.get_coordinate() + if coord_device_mesh_input is None: + raise ValueError(f"ring_comm.coord_2d {comm.coord_2d} is not on device_mesh_input {device_mesh_input}.") + if comm.coord_2d != ( + coord_device_mesh_input[i_tensor_dim_to_i_grid_axis[1]], + coord_device_mesh_input[i_tensor_dim_to_i_grid_axis[2]], + ): + raise ValueError( + f"Input ring_comm's coord_2d {comm.coord_2d} does not match the " + f"device mesh's rank coordinates {coord_device_mesh_input} for the sharded dimensions." + ) + + ctx.mark_non_differentiable(mask) + + # Apply mask and prepare for computation + mask_local = mask.to_local().unsqueeze(-1) + g_local = g.to_local().sigmoid() + x_local = x.to_local() * mask_local + x_local *= g_local + + # the _distributed_bmm will permute a_local and b_local and make + # the resulting tensors contiguous so we don't need to clone them here + a_local, b_local = torch.chunk(x_local, 2, dim=-1) + + # Store tensors for backward pass + if x.requires_grad: + # here x_local is masked and gated + ctx.save_for_backward(a_local, b_local, mask_local, x_local, g_local) + ctx.comm = comm + ctx.shape_x_input = x.shape + ctx.stride_x_input = x.stride() + ctx.shape_g_input = g.shape + ctx.stride_g_input = g.stride() + ctx.placements_input = placements_input + ctx.device_mesh_input = device_mesh_input + ctx.direction = direction + + if direction == _Direction.Outgoing: + permute_lhs = (0, 3, 1, 2) # from (B, n, k, D) to (B, D, n, k) + permute_rhs = (0, 3, 2, 1) # from (B, m, k, D) to (B, D, k, m) + permute_out = (0, 2, 3, 1) # from (B, D, n, m) to (B, n, m, D) + xpose_args = _XposeArgs.rhs + elif direction == _Direction.Incoming: + permute_lhs = (0, 3, 2, 1) # from (B, k, n, D) to (B, D, n, k) + permute_rhs = (0, 3, 1, 2) # from (B, k, m, D) to (B, D, k, m) + permute_out = (0, 2, 3, 1) # from (B, D, n, m) to (B, n, m, D) + xpose_args = _XposeArgs.lhs + else: + raise ValueError(f"Invalid direction: {direction}") + + out_local = _distributed_bmm( + a_local, + b_local, + comm, + permute_lhs=permute_lhs, + permute_rhs=permute_rhs, + permute_out=permute_out, + xpose_args=xpose_args, + ).contiguous() + + shape_output = x.shape[:-1] + (out_local.shape[-1],) + stride_output = update_exhaustive_strides(x.shape, x.stride(), shape_output) + # Convert back to DTensor + out = DTensor.from_local( + out_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=shape_output, + stride=stride_output, + ) + return out + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, d_loss_d_out: DTensor) -> Tuple[DTensor, None, DTensor, None, None, None]: + """Backward pass of distributed triangle multiplication computation.""" + if not isinstance(d_loss_d_out, DTensor): + raise TypeError(f"Input 'd_loss_d_out' must be of type DTensor. Got type {type(d_loss_d_out)}.") + + if d_loss_d_out.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'd_loss_d_out' must have the same device mesh as the input tensors. " + f"Got device meshes {d_loss_d_out.device_mesh} and {ctx.device_mesh_input}." + ) + + if d_loss_d_out.placements != ctx.placements_input: + raise ValueError( + f"Input 'd_loss_d_out' must have the same placements as the input tensors. " + f"Got placements {d_loss_d_out.placements} and {ctx.placements_input}." + ) + + a, b, mask_local, x_masked_gated_local, g_local = ctx.saved_tensors + comm = ctx.comm + direction = ctx.direction + + # cast d_loss_d_out to the same dtype as a (saved tensor) to avoid type promotion to FP32 + # Note: torch.amp.custom_bwd disables autocast, so operations run in the input dtype. + # If the upstream adjoint (d_loss_d_out) arrives as FP32 (e.g. from loss scaling or downstream FP32 layers), + # mixed-precision ops with saved BF16 tensors would promote to FP32, causing potential communication + # buffer mismatches and NCCL hangs. Explicit casting ensures consistent precision. + d_loss_d_out_local = d_loss_d_out.to_local().to(dtype=a.dtype) + + if direction == _Direction.Outgoing: + lhs_da = d_loss_d_out_local + rhs_da = b + permute_lhs_da = (0, 3, 1, 2) # from (B, n, m, D) to (B, D, n, m) + permute_rhs_da = (0, 3, 1, 2) # from (B, m, k, D) to (B, D, m, k) + permute_out_da = (0, 2, 3, 1) # from (B, D, n, k) to (B, n, k, D) + xpose_args_da = None + + lhs_db = d_loss_d_out_local + rhs_db = a + permute_lhs_db = (0, 3, 2, 1) # from (B, n, m, D) to (B, D, m, n) + permute_rhs_db = (0, 3, 1, 2) # from (B, n, k, D) to (B, D, n, k) + permute_out_db = (0, 2, 3, 1) # from (B, D, m, k) to (B, m, k, D) + xpose_args_db = _XposeArgs.lhs + + elif direction == _Direction.Incoming: + lhs_da = b + rhs_da = d_loss_d_out_local + permute_lhs_da = (0, 3, 1, 2) # from (B, k, m, D) to (B, D, k, m) + permute_rhs_da = (0, 3, 2, 1) # from (B, n, m, D) to (B, D, m, n) + permute_out_da = (0, 2, 3, 1) # from (B, D, k, n) to (B, k, n, D) + xpose_args_da = _XposeArgs.rhs + + lhs_db = a + rhs_db = d_loss_d_out_local + permute_lhs_db = (0, 3, 1, 2) # from (B, k, n, D) to (B, D, k, n) + permute_rhs_db = (0, 3, 1, 2) # from (B, n, m, D) to (B, D, n, m) + permute_out_db = (0, 2, 3, 1) # from (B, D, k, m) to (B, k, m, D) + xpose_args_db = None + else: + raise ValueError(f"Invalid direction: {direction}") + + d_loss_d_a_local = _distributed_bmm( + lhs_da, + rhs_da, + comm, + permute_lhs=permute_lhs_da, + permute_rhs=permute_rhs_da, + permute_out=permute_out_da, + xpose_args=xpose_args_da, + ).contiguous() + + # Phase 2: d_loss_d_b + d_loss_d_b_local = _distributed_bmm( + lhs_db, + rhs_db, + comm, + permute_lhs=permute_lhs_db, + permute_rhs=permute_rhs_db, + permute_out=permute_out_db, + xpose_args=xpose_args_db, + ).contiguous() + + # concatenate and apply mask to gradients + dab_local = torch.cat([d_loss_d_a_local, d_loss_d_b_local], dim=-1) + + x_masked_gated_local *= 1 - g_local + dg_local = dab_local * x_masked_gated_local + + dg = DTensor.from_local( + dg_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.shape_g_input, + stride=ctx.stride_g_input, + ) + + dx_local = dab_local + dx_local *= mask_local + dx_local *= g_local + + # Convert gradients back to DTensors + dx = DTensor.from_local( + dx_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.shape_x_input, + stride=ctx.stride_x_input, + ) + + return dx, None, dg, None, None + + +class TriangleMultiplication(nn.Module): + """Distributed triangle multiplication layer. + + This layer implements a distributed version of the triangle multiplication operation, + which is used in attention mechanisms for protein structure prediction and other applications + requiring pairwise feature interactions. + + The layer performs the following operations: + 1. Layer normalization of input pairwise features + 2. Linear projections to create two representation streams (a and b) + 3. Distributed triangle multiplication computation using ring communication + 4. Output gating and final linear projection + + Parameters + ---------- + layer : SerialTriangleMultiplicationOutgoing | SerialTriangleMultiplicationIncoming + The serial triangle multiplication layer to convert to distributed version. + Used to initialize projection weights and normalization parameters. + device_mesh : DeviceMesh + The device mesh for distributed computation across multiple GPUs. + comm : Ring2DComm + Ring communication object for efficient distributed triangle multiplication computation. + """ + + def __init__( + self, + direction: _Direction, + layer: SerialTriangleMultiplicationOutgoing | SerialTriangleMultiplicationIncoming, + device_mesh: DeviceMesh, + comm: Ring2DComm, + ) -> None: + """Initialize the distributed triangle multiplication layer.""" + super().__init__() + self.device_mesh = device_mesh + self.ring_comm = comm + + self.norm_in = LayerNormParamsReplicated(layer.norm_in, self.device_mesh) + self.p_in = LinearParamsReplicated(layer.p_in, self.device_mesh) + self.g_in = LinearParamsReplicated(layer.g_in, self.device_mesh) + + self.norm_out = LayerNormParamsReplicated(layer.norm_out, self.device_mesh) + self.p_out = LinearParamsReplicated(layer.p_out, self.device_mesh) + self.g_out = LinearParamsReplicated(layer.g_out, self.device_mesh) + + if direction == _Direction.Outgoing: + if not isinstance(layer, SerialTriangleMultiplicationOutgoing): + raise ValueError(f"Invalid layer type for direction {direction}: {type(layer)}") + elif direction == _Direction.Incoming: + if not isinstance(layer, SerialTriangleMultiplicationIncoming): + raise ValueError(f"Invalid layer type for direction {direction}: {type(layer)}") + else: + raise ValueError(f"Invalid direction {direction}") + self._direction = direction + + def forward(self, x: DTensor, mask: DTensor) -> DTensor: + """Forward pass of the distributed triangle multiplication layer. + + Parameters + ---------- + x : DTensor + Input pairwise tensor with shape (B, N, N, D). + Must be sharded on dimensions 1 and 2. + mask : DTensor + Mask tensor with shape (B, N, N) indicating valid positions. + Must be sharded on dimensions 1 and 2. + + Returns + ------- + DTensor + Output pairwise tensor with shape (B, N, N, D). + """ + # Stabilize pair embedding tensor with layer norm + x = self.norm_in(x) + x_in = x + g_out = self.g_out(x_in) + + # Decompress: D -> 2D + g = self.g_in(x) + x = self.p_in(x) + + # Distributed triangular multiplication (mask is applied inside the implementation) + x = _TriangleMultiplicationImpl.apply(x, mask, g, self.ring_comm, self._direction) + + # Output gating + x = self.p_out(self.norm_out(x)) + x = sigmoid_gate(x, g_out) + + return x + + +class TriangleMultiplicationOutgoing(TriangleMultiplication): + """Distributed triangle multiplication outgoing layer.""" + + def __init__( + self, + layer: SerialTriangleMultiplicationOutgoing, + device_mesh: DeviceMesh, + comm: Ring2DComm, + ) -> None: + """Initialize the distributed triangle multiplication outgoing layer. + + Parameters + ---------- + layer : SerialTriangleMultiplicationOutgoing + The serial triangle multiplication outgoing layer to convert to distributed version. + device_mesh : DeviceMesh + The device mesh for distributed computation across multiple GPUs. + comm : Ring2DComm + Ring communication object for efficient distributed triangle multiplication computation. + """ + super().__init__(_Direction.Outgoing, layer, device_mesh, comm) + + +class TriangleMultiplicationIncoming(TriangleMultiplication): + """Distributed triangle multiplication incoming layer.""" + + def __init__( + self, + layer: SerialTriangleMultiplicationIncoming, + device_mesh: DeviceMesh, + comm: Ring2DComm, + ) -> None: + """Initialize the distributed triangle multiplication incoming layer. + + Parameters + ---------- + layer : SerialTriangleMultiplicationIncoming + The serial triangle multiplication incoming layer to convert to distributed version. + device_mesh : DeviceMesh + The device mesh for distributed computation across multiple GPUs. + comm : Ring2DComm + Ring communication object for efficient distributed triangle multiplication computation. + """ + super().__init__(_Direction.Incoming, layer, device_mesh, comm) diff --git a/src/boltz/distributed/model/layers/utils.py b/src/boltz/distributed/model/layers/utils.py new file mode 100644 index 000000000..e5cae77c6 --- /dev/null +++ b/src/boltz/distributed/model/layers/utils.py @@ -0,0 +1,2276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import math +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Partial, Placement, Shard + +from boltz.distributed.model.layers.flatten_and_unflatten import ( + shardwise_flatten, + shardwise_flatten_sharded, + shardwise_unflatten_sharded, +) +from boltz.distributed.model.modules.utils import validate_window_batching_parameters +from boltz.distributed.utils import update_exhaustive_strides + + +def get_query_window_key_range(W: int, H: int, K: int, ids_query_window: torch.Tensor) -> torch.Tensor: + """ + Get the range of half-window indices (j) that query windows attend to. + + Vectorized version that computes ranges for multiple query windows simultaneously. + + Parameters + ---------- + W : int + Atoms per query window (must be even) + H : int + Keys per query window (must be divisible by W//2) + K : int + Total number of query windows (must be > 0) + ids_query_window : torch.Tensor + Query window indices of any shape, with values in range [0, K-1] + + Returns + ------- + torch.Tensor + Shape [2, *ids_query_window.shape] where: + - result[0] contains j_min values + - result[1] contains j_max values + + Raises + ------ + AssertionError + If input parameters don't satisfy constraints + + Examples + -------- + >>> ids = torch.tensor([0, 2, 9]) + >>> ranges = get_query_window_key_range(32, 128, 10, ids) + >>> ranges.shape + torch.Size([2, 3]) + >>> ranges[0] # j_min values + tensor([0, 1, 15]) + >>> ranges[1] # j_max values + tensor([4, 8, 19]) + + >>> ids = torch.tensor([[0, 1], [2, 3]]) + >>> ranges = get_query_window_key_range(32, 128, 10, ids) + >>> ranges.shape + torch.Size([2, 2, 2]) + """ + # Validate inputs + assert W > 0 and W % 2 == 0, "W must be positive and even" + assert H > 0 and H % (W // 2) == 0, "H must be divisible by W//2" + assert K > 0, "K must be positive" + assert torch.all(ids_query_window >= 0) and torch.all( + ids_query_window < K + ), f"Query window indices must be in range [0, {K - 1}]" + + # Calculate h (half-windows per query window) + h = H // (W // 2) + + # Calculate range using proven formula (vectorized) + # Note: j_max >= j_min is guaranteed by the formula (proven mathematically) + + # j_min: Left-clamped to 0 for early windows when (2*i + 1 - h//2) < 0 + # This occurs when i < (h//2 - 1)/2, typically first few query windows + j_min = torch.maximum(torch.zeros_like(ids_query_window), 2 * ids_query_window + 1 - h // 2) + + # j_max: Right-clamped to 2*K-1 for late windows when (2*i + h//2) > 2*K-1 + # This occurs when i > K - 1 - h//4, typically last few query windows + j_max = torch.minimum(torch.full_like(ids_query_window, 2 * K - 1), 2 * ids_query_window + h // 2) + + # Co-occurrence: Both clamps can occur simultaneously for very small K (K < h//2 + 1) + # In such cases, the query window sees all available half-windows [0, 2K-1] + # For typical h=8, this happens when K ≤ 4 (rare in practice) + + # Stack j_min and j_max along new dimension at the front + return torch.stack([j_min, j_max], dim=0) + + +def gather_sliding_windows_backward( + grad_output: torch.Tensor, + window_start_offsets: torch.Tensor, + window_size: int, + axis: int, + input_shape: tuple, +) -> torch.Tensor: + """ + Backward pass for sliding window gathering operation. + + Computes gradient w.r.t. input given gradient w.r.t. output. + + Parameters + ---------- + grad_output : torch.Tensor + Gradient w.r.t. output, shape (..., n_windows, window_size, ...) + window_start_offsets : torch.Tensor + Window starting positions used in forward pass, shape (n_windows,) + window_size : int + Size of each window (h) + axis : int + Axis along which windowing was applied + input_shape : tuple + Shape of the original input tensor + + Returns + ------- + torch.Tensor + Gradient w.r.t. input, shape matching input_shape + + Notes + ----- + Uses index_add_ to accumulate overlapping gradients from multiple windows + that read the same input positions. + """ + # Validate input types + if not isinstance(grad_output, torch.Tensor): + raise TypeError(f"grad_output must be a torch.Tensor, got {type(grad_output)}") + if not isinstance(window_start_offsets, torch.Tensor): + raise TypeError(f"window_start_offsets must be a torch.Tensor, got {type(window_start_offsets)}") + + # Validate window_start_offsets shape + if window_start_offsets.ndim != 1: + raise ValueError( + f"window_start_offsets must be 1D, got {window_start_offsets.ndim}D with shape {window_start_offsets.shape}" + ) + + n_windows = window_start_offsets.shape[0] + + # Normalize and validate axis + ndim_input = len(input_shape) + if axis < 0: + axis_normalized = ndim_input + axis + else: + axis_normalized = axis + + if not (0 <= axis_normalized < ndim_input): + raise ValueError( + f"axis {axis} out of range for input with {ndim_input} dims " + f"(normalized to {axis_normalized}, valid range [0, {ndim_input}))" + ) + + # Build expected grad_output shape from input_shape + # Mapping: (..., in_len, ...) -> (..., n_windows, window_size, ...) + expected_shape = list(input_shape) + expected_shape[axis_normalized] = n_windows # Replace in_len with n_windows + expected_shape.insert(axis_normalized + 1, window_size) # Insert window_size after n_windows + + # Validate complete grad_output shape + if grad_output.shape != tuple(expected_shape): + raise ValueError( + f"grad_output shape mismatch:\n" + f" Expected: {tuple(expected_shape)}\n" + f" Got: {grad_output.shape}\n" + f" (Derived from input_shape={input_shape}, n_windows={n_windows}, " + f"window_size={window_size}, axis={axis_normalized})" + ) + + device = grad_output.device + + # Use normalized axis for subsequent operations + axis = axis_normalized + in_len = input_shape[axis] + + # ============================================================ + # Step A: Normalize grad_output shape for Scatter-Add + # ============================================================ + # Explanation of the Backward Logic: + # + # 1. Preparation (Flattening): Because the input can have arbitrary dimensions + # (e.g., Batch, Time, Features or Channels, Time, Height, Width), performing + # operations on specific dimensions is tricky. In Step A, we permute and + # flatten the tensor into a 2D matrix: [n_windows * window_size, Flat_Features]. + # This standardizes the problem regardless of the input shape. + + # Forward output was: (..., n_windows, window_size, ...) + # We want to isolate (n_windows, window_size) and flatten the rest. + + # 1. Move 'window_size' (currently at axis+1) back to end + # Current: (..., n_windows, window_size, ...) + # Result: (n_windows, window_size, Flattend_Features) + g = grad_output.moveaxis([axis, axis + 1], [0, 1]).flatten(2, -1) + + # 4. Final Flatten for Index Add: + # Current: (n_windows, window_size, Flattend_Features) + # Result: (Total_Elements, Flattened_Features) + # Total_Elements = n_windows * window_size + grad_source_flat = g.flatten(0, 1) + + # ============================================================ + # Step B: Generate Target Indices (Where to accumulate gradients) + # ============================================================ + # 2. Mapping (target_indices): This is the inverse of the forward index_select. + # + # Forward: "For window w, read index i." + # + # Backward: "For window w, gradient i belongs to index i+window_start_offsets[w]." + # We generate a full grid of these destination indices. + + # We need to map every element in (n_windows, window_size) back to + # the padded input vector index. + # Formula: index[w, i] = window_start_offsets[w] + pad_top + i + # where the rhs is exactly the indices of the padded input vector + # padded_vector in the forward pass, i.e., index[w, i] is the fwd + # output[..., w, i, ...]'s index in the original source padded_vector. + + # Determine padding (same logic as forward) + min_k = window_start_offsets.min().item() + pad_top = max(0, -min_k) + + # 1. Create Base Offsets for each window + # Shape: (n_windows, 1) + base_indices = (window_start_offsets + pad_top).unsqueeze(1) + + # 2. Create Window Steps (0, 1, ..., window_size-1) + # Shape: (1, window_size) + window_steps = torch.arange(window_size, device=device).unsqueeze(0) + + # 3. Broadcast to get full index map + # Shape: (n_windows, window_size) + target_indices = base_indices + window_steps + + # 4. Flatten indices to match the flattened gradients + # Shape: (Total_Elements,) + target_indices_flat = target_indices.view(-1) + + # ============================================================ + # Step C: Accumulate Gradients (Scatter Add) + # ============================================================ + # 3. Accumulation (index_add_): This is the crucial step. Since multiple windows + # might overlap (read from the same input index), their gradients must be summed. + # index_add_ handles this atomically. + # + # Note on Padding: We allocate a buffer that includes the padding size. + # Gradients computed for "padded zeros" are accumulated into the padding + # regions of this buffer. + + # Calculate padded length + max_k = window_start_offsets.max().item() + needed_length = max_k + window_size + pad_bottom = max(0, needed_length - in_len) + + # 1. Create a zero-filled buffer for the PADDED input gradient + # Length = pad_top + in_len + pad_bottom + total_padded_len = pad_top + in_len + pad_bottom + num_features = grad_source_flat.shape[1] + + grad_padded_buffer = torch.zeros((total_padded_len, num_features), device=device, dtype=grad_output.dtype) + + # 2. Perform the Accumulation + # This is the "Overlap-Add" magic. Gradients from overlapping windows + # are summed up automatically. + # Here grad_source_flat[i] is the upstream adjoint of the fwd output + # while target_indices_flat[i] is the index of the fwd padded_vector + # and the gradient thereof + grad_padded_buffer.index_add_(0, target_indices_flat, grad_source_flat) + + # ============================================================ + # Step D: Handle Padding & Reshape + # ============================================================ + # 4. Slicing: In the final step, we simply slice out the valid middle region + # (pad_top to pad_top + in_len), effectively discarding the gradients that + # accumulated in the "virtual" padded zones. + + # 1. Slice off the padding (Discard gradients that fell into the pad zones) + # We only keep indices [pad_top : pad_top + in_len] + # NOTE that in the actual usage case corresponding to the get_indexing_matrix + # and single_to_keys, there should be an input mask that would go thru + # the same single_to_keys operation as do the sequence data so the mask + # would also result in zeros corresponding to the pad zones in the padded_vector, + # which implies that it's safe to discard the gradients that fell into the pad zones here + grad_input_flat = grad_padded_buffer[pad_top : pad_top + in_len] + + # 2. Reshape back to the original input geometry + # We flattened (..., In_Len, ...) into (In_Len, Features). + # We need to reverse this. + + # A. Calculate dimensions before and after 'axis' + # input_shape = (D1, D2, In_Len, D3, D4) + # We need to reshuffle grad_input_flat (In_Len, D1*D2*D3*D4) + # back to that shape. + + # It is cleaner to use the original shape directly but permuted. + # Target Layout for reshape: (In_Len, ...) + + # Construct the permuted shape where 'axis' is at dim 0 + permuted_shape = list(input_shape) + permuted_shape.pop(axis) + permuted_shape.insert(0, in_len) + + # Reshape + grad_input = grad_input_flat.reshape(permuted_shape) + + # Inverse Permutation: Move dim 0 back to 'axis' + grad_input = grad_input.moveaxis(0, axis) + + return grad_input + + +class GatherSlidingWindows(torch.autograd.Function): + @staticmethod + def forward(ctx, input, window_start_offsets, window_size, axis): + """ + Gather overlapping sliding windows from input at specified starting positions. + + This operation implements efficient windowed attention by extracting windows + from the input sequence. The underlying mathematical structure is a block + Toeplitz matrix (see Theorems 1-6 in documentation). + + Example & Intuition: + -------------------- + Consider gathering 3 windows of size 8 from a sequence of length 10. + The operation can be viewed as a sparse matrix 'M' (n_windows=3, window_size=8, seq_len=10) + where each window gathers contiguous elements from the input. + + Semantic mapping to get_indexing_matrix: + - axis 0: query window index + - axis 1: slot within window, i.e., index in [0, h-1] where h == H // (W // 2) + - axis 2: input sequence position, index to the "2 * K" half-windows dimension + + 1. Visualizing as Sparse Matrix 'M': + NOTE: This corresponds to the transposed onehot tensor from get_indexing_matrix: + M == onehot.transpose(1, 0).transpose(-2, -1) + + The '1's represent which input index is gathered to the output. + Notice the diagonal pattern shifts by 2 between windows. + + Window 0 (starting at position -3): "Pad Left" + [0 0 0 0 0 0 0 0 0 0] <- Slot 0 (padded zero) + [0 0 0 0 0 0 0 0 0 0] <- Slot 1 (padded zero) + [0 0 0 0 0 0 0 0 0 0] <- Slot 2 (padded zero) + [1 0 0 0 0 0 0 0 0 0] <- Slot 3 (gathers Input[0]) + [0 1 0 0 0 0 0 0 0 0] <- Slot 4 (gathers Input[1]) + [0 0 1 0 0 0 0 0 0 0] ... + [0 0 0 1 0 0 0 0 0 0] + [0 0 0 0 1 0 0 0 0 0] + + Window 1 (starting at position -1): "Shifted +2 from Window 0" + [0 0 0 0 0 0 0 0 0 0] <- Slot 0 (padded zero) + [1 0 0 0 0 0 0 0 0 0] <- Slot 1 (gathers Input[0]) + [0 1 0 0 0 0 0 0 0 0] <- Slot 2 (gathers Input[1]) + [0 0 1 0 0 0 0 0 0 0] ... + [0 0 0 1 0 0 0 0 0 0] + [0 0 0 0 1 0 0 0 0 0] + [0 0 0 0 0 1 0 0 0 0] + [0 0 0 0 0 0 1 0 0 0] + + Window 2 (starting at position +1): "Shifted +2 from Window 1" + [0 1 0 0 0 0 0 0 0 0] <- Slot 0 (gathers Input[1]) + [0 0 1 0 0 0 0 0 0 0] <- Slot 1 (gathers Input[2]) + [0 0 0 1 0 0 0 0 0 0] ... + [0 0 0 0 1 0 0 0 0 0] + [0 0 0 0 0 1 0 0 0 0] + [0 0 0 0 0 0 1 0 0 0] + [0 0 0 0 0 0 0 1 0 0] + [0 0 0 0 0 0 0 0 1 0] + + 2. The "Unfold" Implementation: + Instead of explicit matrix multiplication, we use torch.unfold to create + sliding windows efficiently. If input is 2D (seq_len, features), gathering + windows via this operation is equivalent to applying the sparse matrix M. + + For computational efficiency, we slide a window over the (padded) input: + - Window 0 (offset=-3): reads positions [-3] to [4] (with padding) + - Window 2 (offset=+1): reads positions [1] to [8] + + Args: + input: Tensor of arbitrary shape (..., seq_len, ...) + window_start_offsets: (n_windows,) tensor of integer starting positions + window_size: int, the size of each output window (h) + axis: int, the dimension corresponding to sequence length + + Returns: + Tensor of shape (..., n_windows, window_size, ...) + 1. The 'n_windows' dimension replaces the original sequence dimension at 'axis' + 2. The 'window_size' dimension is placed immediately after 'n_windows' (at axis + 1) + """ + + # 1. Normalize and Save Axis/Shape info + ndim = input.ndim + if axis < 0: + axis += ndim + + in_len = input.shape[axis] + + # 2. Analyze Padding + min_k = window_start_offsets.min().item() + max_k = window_start_offsets.max().item() + + pad_top = max(0, -min_k) + needed_length = max_k + window_size + pad_bottom = max(0, needed_length - in_len) + + ctx.mark_non_differentiable(window_start_offsets) + + # Save context for backward + # We save shapes and integers, and the window_start_offsets tensor. + # We DO NOT save the input (saves memory). + ctx.save_for_backward(window_start_offsets) + ctx.params = { + "input_shape": input.shape, + "pad_top": pad_top, + "pad_bottom": pad_bottom, + "output_len": window_size, + "axis": axis, + "in_len": in_len, + } + + # 3. Forward Logic: Create sliding windows via unfold + if pad_top == 0 and pad_bottom == 0: + padded_vector = input + else: + pad_arg = [0] * (2 * ndim) + pad_idx_left = (ndim - 1 - axis) * 2 + pad_idx_right = pad_idx_left + 1 + pad_arg[pad_idx_left] = pad_top + pad_arg[pad_idx_right] = pad_bottom + padded_vector = torch.nn.functional.pad(input, pad_arg) + + # Shape: (..., num_windows, ..., window_size) + # - Position 'axis' now contains num_windows (replaces padded_len) + # - New dimension window_size added at position -1 (end) + windows = padded_vector.unfold(axis, window_size, 1) + + # Shape: (n_windows,) - translate window_start_offsets to padded coordinate system + slice_indices = window_start_offsets + pad_top + + # Shape: (..., n_windows, ..., window_size) - select specific windows along axis + selected_windows = windows.index_select(axis, slice_indices) + + # Permute: (..., n_windows, ..., window_size) -> (..., n_windows, window_size, ...) + result = selected_windows.moveaxis(-1, axis + 1) + + return result + + @staticmethod + def backward(ctx, grad_output): + """ + Explanation of the Backward Logic + + 1. Preparation (Flattening): Because the input can have arbitrary dimensions + (e.g., Batch, Time, Features or Channels, Time, Height, Width), performing + operations on specific dimensions is tricky. In Step A, we permute and + flatten the tensor into a 2D matrix: [N_Windows * Output_Len, Flat_Features]. + This standardizes the problem regardless of the input shape. + + 2. Mapping (target_indices): This is the inverse of the forward index_select. + + Forward: "For window w, read index i." + + Backward: "For window w, gradient i belongs to index i+offset[w]." + We generate a full grid of these destination indices. + + 3. Accumulation (index_add_): This is the crucial step. Since multiple windows + might overlap (read from the same input index), their gradients must be summed. + index_add_ handles this atomically. + + Note on Padding: We allocate a buffer that includes the padding size. + Gradients computed for "padded zeros" are accumulated into the padding + regions of this buffer. + + 4. Slicing: In the final step, we simply slice out the valid middle region + (pad_top to pad_top + in_len), effectively discarding the gradients that + accumulated in the "virtual" padded zones. + """ + # Retrieve context + (window_start_offsets,) = ctx.saved_tensors + params = ctx.params + input_shape = params["input_shape"] + window_size = params["output_len"] # Stored as output_len in context for backward compat + axis = params["axis"] + + # Call standalone backward function (includes validation) + grad_input = gather_sliding_windows_backward(grad_output, window_start_offsets, window_size, axis, input_shape) + + return grad_input, None, None, None + + +def compute_query_window_ownership(W: int, H: int, K: int, qw_start: int, qw_end: int) -> dict: + """ + Compute halo requirements for a given query window ownership. + + Parameters + ---------- + W : int + Atoms per query window (must be even) + H : int + Keys per query window (must be divisible by W//2) + K : int + Total number of query windows (global) + qw_start : int + First owned query window (inclusive) + qw_end : int + Last owned query window + 1 (exclusive) + + Returns + ------- + dict + { + 'hw_owned': (int, int), # Owned half-windows [2*qw_start, 2*qw_end) + 'hw_needed': (int, int), # All half-windows needed [start, end) + 'left_halo_size': int, # Half-windows needed from left neighbor + 'right_halo_size': int, # Half-windows needed from right neighbor + } + + Examples + -------- + >>> # Rank owns QW[4,8) for K=12 + >>> ownership = compute_query_window_ownership(32, 128, 12, 4, 8) + >>> ownership['hw_owned'] + (8, 16) # Owns HW[8-15] (inferred from QW range) + >>> ownership['hw_needed'] + (5, 19) # Needs HW[5-18] + >>> ownership['left_halo_size'] + 3 # Needs HW[5,6,7] from left neighbor + """ + assert W > 0 and W % 2 == 0 + assert H > 0 and H % (W // 2) == 0 + assert K > 0 + assert 0 <= qw_start <= qw_end <= K + + # Infer half-window ownership from query window ownership + # Query window i owns half-windows [2i, 2i+1] + hw_start = 2 * qw_start + hw_end = 2 * qw_end + + # Determine which half-windows are needed for owned query windows + if qw_start < qw_end: + owned_qw_ids = torch.arange(qw_start, qw_end) + ranges = get_query_window_key_range(W, H, K, owned_qw_ids) + + hw_need_start = ranges[0].min().item() + hw_need_end = ranges[1].max().item() + 1 # Exclusive end + + # Compute halo sizes + left_halo_size = max(0, hw_start - hw_need_start) + right_halo_size = max(0, hw_need_end - hw_end) + else: + # No owned query windows + hw_need_start = hw_start + hw_need_end = hw_start + left_halo_size = 0 + right_halo_size = 0 + + return { + "hw_owned": (hw_start, hw_end), + "hw_needed": (hw_need_start, hw_need_end), + "left_halo_size": left_halo_size, + "right_halo_size": right_halo_size, + } + + +def get_halo_from_neighbors( + rank: int, + size_group: int, + n_half_windows_local: int, + W: int, + H: int, + K: int, +) -> tuple[list, list]: + """ + Compute send/recv metadata for halo exchange (supports multi-hop). + + Returns: + tuple(recv_meta, send_meta) + - recv_meta: list of (peer_rank, halo_type, offset_in_halo, length) + where halo_type is 'left' or 'right'. + - send_meta: list of (peer_rank, offset_in_local, length) + """ + # 1. Validate inputs + assert W > 0 and W % 2 == 0, "W must be positive and even" + assert H > 0 and H % (W // 2) == 0, "H must be divisible by W//2" + assert H // (W // 2) % 2 == 0, "H // (W // 2) must be even" + assert K > 0, "K must be positive" + if K % size_group != 0: + raise ValueError(f"K {K} must be an integer multiple of the number of ranks {size_group}.") + + # 2. Vectorized ownership computation + # Rank ownership: [hw_start, hw_end) + rank_ids = torch.arange(size_group) + hw_owned_starts = rank_ids * n_half_windows_local + hw_owned_ends = (rank_ids + 1) * n_half_windows_local + + # Query window ownership: [qw_start, qw_end) + qw_starts = hw_owned_starts // 2 + qw_ends = hw_owned_ends // 2 + + # Needed half-windows: [hw_need_start, hw_need_end) + # Note: K is large, so we don't want to compute range for all query windows individually. + # But compute_query_window_ownership works per rank range. + # We can vectorize get_query_window_key_range over the start/end QWs of all ranks. + + # For a range of QWs [qs, qe), the needed range is union of needs of all q in [qs, qe). + # Since QWs are monotonic, needed range is [min(need(qs)), max(need(qe-1))]. + # Exception: if qs >= qe (empty rank), need range is empty/irrelevant. + + # Compute needs for first QW of each rank + ranges_start = get_query_window_key_range(W, H, K, qw_starts) + hw_need_starts = ranges_start[0] # j_min of first QW + + # Compute needs for last QW of each rank + # Use (qw_ends - 1) but clamp to >= 0 to avoid index -1 for empty ranks + last_qw_ids = (qw_ends - 1).clamp(min=0) + ranges_end = get_query_window_key_range(W, H, K, last_qw_ids) + hw_need_ends = ranges_end[1] + 1 # j_max + 1 of last QW (exclusive end) + + # Handle empty ranks: if qw_start >= qw_end, they own nothing and need nothing (locally) + # We set need = owned to result in 0 halo size + is_empty = qw_starts >= qw_ends + hw_need_starts = torch.where(is_empty, hw_owned_starts, hw_need_starts) + hw_need_ends = torch.where(is_empty, hw_owned_starts, hw_need_ends) + + # Get current rank's values + my_hw_start = hw_owned_starts[rank].item() + my_hw_end = hw_owned_ends[rank].item() + my_need_start = hw_need_starts[rank].item() + my_need_end = hw_need_ends[rank].item() + + recv_meta = [] # (peer, halo_type, offset_in_halo, length) + send_meta = [] # (peer, offset_in_local, length) + + # 3. Identify Neighbors via SearchSorted + + # --- RECV Left --- + # Need neighbors covering [my_need_start, my_hw_start) + if my_need_start < my_hw_start: + # Find first rank whose owned range ends after my_need_start + # i.e. hw_owned_ends[p] > my_need_start + p_start_idx = torch.searchsorted(hw_owned_ends, my_need_start, side="right") + # Find last rank whose owned range starts before my_hw_start + # i.e. hw_owned_starts[p] < my_hw_start + p_end_idx = torch.searchsorted(hw_owned_starts, my_hw_start, side="left") + + for peer in range(p_start_idx, p_end_idx): + if peer == rank: + continue + # Overlap: [max(my_need, peer_start), min(my_start, peer_end)) + l_start = max(my_need_start, hw_owned_starts[peer].item()) + l_end = min(my_hw_start, hw_owned_ends[peer].item()) + if l_start < l_end: + recv_meta.append((peer, "left", l_start - my_need_start, l_end - l_start)) + + # --- RECV Right --- + # Need neighbors covering [my_hw_end, my_need_end) + if my_hw_end < my_need_end: + # First rank whose owned range ends after my_hw_end + p_start_idx = torch.searchsorted(hw_owned_ends, my_hw_end, side="right") + # Last rank whose owned range starts before my_need_end + p_end_idx = torch.searchsorted(hw_owned_starts, my_need_end, side="left") + + for peer in range(p_start_idx, p_end_idx): + if peer == rank: + continue + # Overlap: [max(my_end, peer_start), min(my_need_end, peer_end)) + r_start = max(my_hw_end, hw_owned_starts[peer].item()) + r_end = min(my_need_end, hw_owned_ends[peer].item()) + if r_start < r_end: + recv_meta.append((peer, "right", r_start - my_hw_end, r_end - r_start)) + + # --- SEND Left (Peers needing me for their left halo) --- + # Condition: peer > rank AND peer_need_start < my_hw_end + # Range of peers: (rank + 1, ...) such that peer_need_start < my_hw_end + # Since hw_need_starts is sorted (monotonic with rank), we can search + if rank + 1 < size_group: + # Find last peer where peer_need_start < my_hw_end + # searchsorted on hw_need_starts to find insertion point of my_hw_end + limit_idx = torch.searchsorted(hw_need_starts, my_hw_end, side="left") + # Valid peers are in range [rank + 1, limit_idx) + # Note: limit_idx is where value >= my_hw_end starts, so up to limit_idx-1 are < my_hw_end + + for peer in range(rank + 1, limit_idx): + # Overlap: Peer needs [p_need_start, p_start), I own [my_start, my_end) + # Intersection: [max(p_need_start, my_start), min(p_start, my_end)) + p_need_start = hw_need_starts[peer].item() + p_hw_start = hw_owned_starts[peer].item() + + l_start = max(p_need_start, my_hw_start) + l_end = min(p_hw_start, my_hw_end) + if l_start < l_end: + send_meta.append((peer, l_start - my_hw_start, l_end - l_start)) + + # --- SEND Right (Peers needing me for their right halo) --- + # Condition: peer < rank AND peer_need_end > my_hw_start + # Range of peers: (..., rank - 1) such that peer_need_end > my_hw_start + # Since hw_need_ends is sorted, we can search + if rank > 0: + # Find first peer where peer_need_end > my_hw_start + # searchsorted on hw_need_ends with my_hw_start + start_idx = torch.searchsorted(hw_need_ends, my_hw_start, side="right") + # Valid peers are in range [start_idx, rank) + + for peer in range(start_idx, rank): + # Overlap: Peer needs [p_end, p_need_end), I own [my_start, my_end) + # Intersection: [max(p_end, my_start), min(p_need_end, my_end)) + p_hw_end = hw_owned_ends[peer].item() + p_need_end = hw_need_ends[peer].item() + + r_start = max(p_hw_end, my_hw_start) + r_end = min(p_need_end, my_hw_end) + if r_start < r_end: + send_meta.append((peer, r_start - my_hw_start, r_end - r_start)) + + return recv_meta, send_meta + + +def pack_and_pad( + input: torch.Tensor, mask: torch.Tensor, axis: int, W: int, keep_input_padding: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Left-pack valid elements from input and pad to the next multiple of W. + + This utility prepares variable-length inputs (with padding masks) for downstream + operations like gather_sliding_windows by: + 1. Left-packing valid elements (moving all True-masked values to the front) + 2. Padding the sequence length to a multiple of W + + Note: This function does NOT reshape the output. Any reshaping (e.g., to + (2*K, W//2) for half-window layout) should be done by the caller after + this function returns. + + Process: + 1. Validate inputs (shapes, broadcastability, W is positive and even) + 2. Determine target_len based on keep_input_padding flag + 3. Pad input/mask to target_len if needed + 4. Sort mask descending (stable) to left-pack valid elements + 5. Gather input elements using the sort indices + 6. Zero out invalid positions in the packed output + + Parameters + ---------- + input : torch.Tensor + Input tensor of arbitrary shape (..., seq_len, ...) + mask : torch.Tensor + Boolean mask with mask.shape[axis] == input.shape[axis]. + Other dimensions must be broadcastable with input. + True=valid element, False=padding to ignore. + axis : int + Dimension containing the sequence + W : int + Padding factor (must be positive and even). The output length along + axis will be padded to the next multiple of W. + keep_input_padding : bool, optional + If False (default), the output length is based on the maximum number + of valid elements across all slices orthogonal to axis, padded to + the next multiple of W. This removes as much padding as possible. + If True, the output length is based on input.shape[axis], padded + to the next multiple of W. This preserves the original sequence length. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + (packed_output, gather_indices, packed_mask) + + packed_output : torch.Tensor + Left-packed and padded tensor of shape (..., target_len, ...). + Valid elements are at the front, followed by zeros. + target_len is a multiple of W. + + gather_indices : torch.Tensor + Indices used for gathering, shape (..., target_len, ...). + Required for pack_and_pad_backward to scatter gradients back. + + packed_mask : torch.Tensor + Boolean mask for packed_output, shape (..., target_len, ...). + True for valid elements, False for padding. + """ + # 1. Validate inputs + if not isinstance(mask, torch.Tensor): + raise TypeError(f"mask must be a torch.Tensor, got {type(mask)}") + if mask.dtype != torch.bool: + raise TypeError(f"mask must have dtype torch.bool, got {mask.dtype}") + if W <= 0 or W % 2 != 0: + raise ValueError(f"W must be positive and even, got {W}") + if not isinstance(keep_input_padding, bool): + raise TypeError(f"Expected bool for keep_input_padding, got {type(keep_input_padding)}") + + # Normalize axis + ndim = input.ndim + if axis < 0: + axis += ndim + if not (0 <= axis < ndim): + raise ValueError(f"axis {axis} out of range for {ndim}D input") + + # Validate mask shape compatibility along axis + if input.ndim != mask.ndim: + raise ValueError(f"mask ndim {mask.ndim} must match input ndim {input.ndim}") + + if input.shape[axis] != mask.shape[axis]: + raise ValueError( + f"mask and input must match along axis={axis}: " + f"input.shape[{axis}]={input.shape[axis]}, mask.shape[{axis}]={mask.shape[axis]}" + ) + + # Check broadcastability (other dims) + try: + broadcasted_shape = torch.broadcast_shapes(input.shape, mask.shape) + except RuntimeError as e: + raise ValueError(f"mask shape {mask.shape} not broadcastable to input shape {input.shape}: {e}") + + if broadcasted_shape != input.shape: + raise ValueError( + f"mask shape {mask.shape} broadcasts to {broadcasted_shape}, which mismatches input {input.shape}" + ) + + # 2. Determine target length based on keep_input_padding flag + if keep_input_padding: + # Use original input length as basis for padding calculation + len_basis = input.shape[axis] + else: + # Use max valid count to minimize output length (removes excess padding) + valid_counts = mask.sum(dim=axis, dtype=torch.long) + max_valid = valid_counts.max().item() + len_basis = max_valid + + # Round up to next multiple of W + target_len = ((len_basis + W - 1) // W) * W + + # 3. Pad input and mask to target_len if needed + current_len = input.shape[axis] + if target_len > current_len: + pad_len = target_len - current_len + + # Build padding arguments for torch.nn.functional.pad + # Format: (left_dim_N, right_dim_N, ..., left_dim_0, right_dim_0) + pad_arg = [0] * (2 * input.ndim) + pad_idx = (input.ndim - 1 - axis) * 2 + 1 # Index for right-padding along axis + pad_arg[pad_idx] = pad_len + + input_padded = torch.nn.functional.pad(input, pad_arg) + mask_padded = torch.nn.functional.pad(mask, pad_arg) + else: + input_padded = input + mask_padded = mask + + # 4. Left-pack valid elements using stable descending sort on mask + # Sorting True (1) before False (0) in descending order moves valid elements to the front + mask_padded_sorted, argsort_mask_padded = torch.sort(mask_padded, dim=axis, descending=True, stable=True) + + # Slice to target_len (handles case where input was longer than target_len) + slices = [slice(None)] * mask_padded.ndim + slices[axis] = slice(0, target_len) + argsort_mask_padded = argsort_mask_padded[tuple(slices)] + mask_padded_sorted = mask_padded_sorted[tuple(slices)] + + # 5. Expand indices to match input dimensions for gathering + # torch.gather requires index tensor to have same ndim as input. + # Expanding broadcasts the sort indices across non-axis dimensions (e.g., features). + target_gather_shape = list(input_padded.shape) + target_gather_shape[axis] = target_len + argsort_mask_padded_expanded = argsort_mask_padded.expand(target_gather_shape) + mask_padded_sorted_expanded = mask_padded_sorted.expand(target_gather_shape) + + # Gather input elements according to the left-packing order + input_packed_padded = torch.gather(input_padded, axis, argsort_mask_padded_expanded) + + # 6. Zero out invalid positions + # The gather operation may have pulled "garbage" values from original invalid positions. + # Multiplying by the sorted mask ensures only valid elements are non-zero. + input_packed_padded = input_packed_padded * mask_padded_sorted_expanded.to(input.dtype) + + return input_packed_padded, argsort_mask_padded_expanded, mask_padded_sorted_expanded + + +def pack_and_pad_backward( + grad_output: torch.Tensor, + mask_output: torch.Tensor, + indices: torch.Tensor, + input_shape: tuple, + axis: int, +) -> torch.Tensor: + """ + Backward pass for pack_and_pad. + + Scatters gradients from the packed/padded output back to the original input shape. + This reverses the gather operation by using scatter_add with the same indices. + + Parameters + ---------- + grad_output : torch.Tensor + Gradient w.r.t. packed output, shape (..., target_len, ...). + This is the gradient flowing back from operations applied to pack_and_pad's output. + mask_output : torch.Tensor + Mask from pack_and_pad's forward pass, shape (..., target_len, ...). + Used to zero out gradients for invalid (padding) positions. + indices : torch.Tensor + Gather indices from pack_and_pad's forward pass (argsort_mask_padded_expanded), + shape (..., target_len, ...). Used to scatter gradients back to original positions. + input_shape : tuple + Shape of the original input tensor to pack_and_pad. + axis : int + Dimension containing the sequence (same as in forward pass). + + Returns + ------- + torch.Tensor + Gradient w.r.t. original input, shape matching input_shape. + """ + # Mask out gradients for invalid positions (padding) + grad_masked = grad_output * mask_output + + # Determine buffer size - may need extra space if target_len > original input length + target_len = indices.shape[axis] + current_len = input_shape[axis] + max_len = max(target_len, current_len) + + # Create gradient buffer for scatter operation + if max_len > current_len: + buffer_shape = list(input_shape) + buffer_shape[axis] = max_len + grad_input_padded = torch.zeros(buffer_shape, dtype=grad_masked.dtype, device=grad_masked.device) + else: + grad_input_padded = torch.zeros(input_shape, dtype=grad_masked.dtype, device=grad_masked.device) + + # Scatter gradients back to their original positions + # This reverses the gather operation from the forward pass + grad_input_padded.scatter_add_(axis, indices, grad_masked) + + # Slice back to original input shape if we used an extended buffer + if max_len > current_len: + slices = [slice(None)] * len(input_shape) + slices[axis] = slice(0, current_len) + grad_input = grad_input_padded[tuple(slices)] + else: + grad_input = grad_input_padded + + return grad_input + + +def gather_sliding_windows(input, window_start_offsets, window_size, axis): + """ + Gather sliding windows from input using specified offsets. + + This operation implements windowed attention by extracting overlapping windows + from the input sequence. The underlying mathematical structure is a block Toeplitz + matrix (see Theorems 1-6 in documentation). + + Args: + input: Input tensor of shape (..., sequence_len, ...) + window_start_offsets: Starting positions for each window, shape (n_windows,) + window_size: Size of each window (h) + axis: Dimension along which to gather windows + + Returns: + Tensor of shape (..., n_windows, window_size, ...) + """ + return GatherSlidingWindows.apply(input, window_start_offsets, window_size, axis) + + +def get_flattened_range_indices(start_ends: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate flattened row and column indices for a set of ranges per row. + + Args: + start_ends: (D, 2) tensor where col 0 is start, col 1 is end (exclusive). + device: torch device. + + Returns: + (row_indices, col_indices) tuple of 1D tensors. + """ + starts = start_ends[:, 0].to(torch.long) + ends = start_ends[:, 1].to(torch.long) + lengths = (ends - starts).clamp(min=0) + + if lengths.sum() == 0: + return torch.empty(0, dtype=torch.long, device=device), torch.empty(0, dtype=torch.long, device=device) + + # 1. Generate Row Indices + row_indices = torch.repeat_interleave(torch.arange(start_ends.shape[0], device=device), lengths) + + # 2. Generate Column Indices + # Global flat counter + total_length = lengths.sum() + flat_range = torch.arange(total_length, device=device) + + # Offsets for each row in the flat sequence + cum_lengths = torch.cumsum(lengths, dim=0) + shifts = torch.zeros_like(cum_lengths) + shifts[1:] = cum_lengths[:-1] + + # Expand shifts to match flat range + shifts_expanded = torch.repeat_interleave(shifts, lengths) + starts_expanded = torch.repeat_interleave(starts, lengths) + + # col_idx = start + relative_idx + col_indices = starts_expanded + (flat_range - shifts_expanded) + + return row_indices, col_indices + + +def _distributed_pack_and_pad( + input: Optional[DTensor], mask: DTensor, axis: int, W: int, keep_input_padding: bool = False +) -> tuple[Optional[DTensor], DTensor, torch.Tensor, torch.Tensor]: + """Distributed left-pack and pad operation for DTensor inputs. + + Left-packs valid elements (as indicated by mask) to the front along the specified axis, + then pads to the right (end) to the next multiple of (W * size_group) where size_group + is the number of ranks sharding the tensor along axis. Communication is performed to redistribute + elements across ranks so that the packed result is evenly sharded. + + Note: This function does NOT reshape the output. Any reshaping (e.g., to + (2*K, W//2) for half-window layout) should be done by the caller. + + Args: + input (Optional[DTensor]): Input DTensor of shape (..., seq_len, ...) sharded along axis, + or None. If None, input is set to mask (useful for computing metadata only). + mask (DTensor): Boolean mask DTensor with shape broadcastable to input. True indicates + valid elements, False indicates padding to ignore. + axis (int): Dimension containing the sequence to pack and pad. + W (int): Padding factor. The output length along axis will be padded to the next + multiple of (W * size_group). + keep_input_padding (bool, optional): If False (default), output length is based on + the maximum number of valid elements across all slices. If True, output length + is based on input.shape[axis]. Defaults to False. + + Returns: + tuple[Optional[DTensor], DTensor, torch.Tensor, torch.Tensor]: + output: Packed and padded DTensor, or None if input was None. Shape is + (..., target_len, ...) where target_len is a multiple of (W * size_group). + mask_output: Boolean mask DTensor for output, same shape as output. + argsort_mask_flat_local: Local argsort indices (2D tensor of shape + (shape_leading_flat_local, shape_axis_local)) used for backward pass. + valid_counts_all_ranks: Valid counts per rank, shape (size_group, shape_leading_flat_local). + """ + # 0. sanity checks + if input is not None and not isinstance(input, DTensor): + raise TypeError(f"Expected DTensor, got {type(input)}") + if not isinstance(mask, DTensor): + raise TypeError(f"Expected DTensor, got {type(mask)}") + if not isinstance(axis, int): + raise TypeError(f"Expected int for axis, got {type(axis)}") + if not isinstance(W, int): + raise TypeError(f"Expected int for W, got {type(W)}") + if not isinstance(keep_input_padding, bool): + raise TypeError(f"Expected bool for keep_input_padding, got {type(keep_input_padding)}") + + # Mask must be boolean to ensure correct valid count computation + if mask.dtype != torch.bool: + raise TypeError( + f"mask must have dtype torch.bool to avoid precision issues in valid count computation, " + f"got {mask.dtype}. Use mask.bool() to convert." + ) + + if input is None: + input = mask + has_input = False + else: + has_input = True + if input.device_mesh != mask.device_mesh: + raise ValueError( + f"input and mask must have the same device mesh but got {input.device_mesh} and {mask.device_mesh}" + ) + if input.placements != mask.placements: + raise ValueError( + f"input and mask must have the same placements but got {input.placements} and {mask.placements}" + ) + + placements = input.placements + device_mesh = input.device_mesh + + i_dim_device_mesh_shard_axis = None + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if input.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == axis: + i_dim_device_mesh_shard_axis = i_dim_device_mesh + + if i_dim_device_mesh_shard_axis is None: + raise ValueError(f"input is not sharded along axis {axis}") + + try: + shape_broadcast = torch.broadcast_shapes(input.shape, mask.shape) + except RuntimeError as e: + raise ValueError("input and mask shapes cannot be broadcasted") from e + + if shape_broadcast != input.shape: + raise ValueError(f"broadcasted shape {shape_broadcast} is not equal to input.shape {input.shape}") + + # 1. Get rank info from passed mesh + rank = device_mesh.get_local_rank(i_dim_device_mesh_shard_axis) + size_group = device_mesh.size(i_dim_device_mesh_shard_axis) + group = device_mesh.get_group(i_dim_device_mesh_shard_axis) + global_rank_peers = torch.distributed.get_process_group_ranks(group) + + local_input = input.to_local() + local_mask = mask.to_local() + device = local_input.device + + # Explicitly broadcast mask to match input shape locally + if local_mask.shape != local_input.shape: + local_mask = local_mask.expand(local_input.shape) + + # --- Phase 1: Forward Logic --- + + # 1.1 Reshape to 2D (shape_leading_flat_local, N) + # Move axis to last dim + local_input_moved = local_input.movedim(axis, -1) + local_mask_moved = local_mask.movedim(axis, -1) + + shape_leading_flat_local = local_input_moved.numel() // local_input_moved.shape[-1] + shape_axis_local = local_input_moved.shape[-1] + + local_input_2d = local_input_moved.reshape(shape_leading_flat_local, shape_axis_local) + local_mask_2d = local_mask_moved.reshape(shape_leading_flat_local, shape_axis_local) + + # 1.2 Local Sort/Pack + local_valid_counts = local_mask_2d.sum(dim=1, dtype=torch.long) # (shape_leading_flat_local,) + + # Sort mask to get valid elements to the left + # local_mask_sorted.shape == argsort_mask_flat_local.shape == local_mask_2d.shape == (shape_leading_flat_local, shape_axis_local) + local_mask_sorted, argsort_mask_flat_local = torch.sort(local_mask_2d, dim=1, descending=True, stable=True) + + # Gather valid data + local_valid_data = torch.gather(local_input_2d, 1, argsort_mask_flat_local) + + # Zero out invalid elements + local_valid_data = local_valid_data * local_mask_sorted.to(local_input_2d.dtype) + + # 1.3 Global Planning (Per Row) + valid_counts_all_ranks = torch.zeros(size_group, shape_leading_flat_local, device=device, dtype=torch.long) + dist.all_gather_into_tensor( + valid_counts_all_ranks, local_valid_counts.unsqueeze(0).to(dtype=torch.long), group=group + ) + + global_ends = valid_counts_all_ranks.cumsum(dim=0) + global_starts = global_ends - valid_counts_all_ranks + total_valid = global_ends[-1] # (shape_leading_flat_local,) + + my_global_start = global_starts[rank] # (shape_leading_flat_local,) + my_global_end = my_global_start + local_valid_counts # (shape_leading_flat_local,) + + # 1.4 Target Partitioning (Global Uniform) + + if keep_input_padding: + # NOTE: we use input.shape[axis] instead of max_total_valid =mask.sum(dim=axis).max() + # to stay consistent with the Boltz implementation. + # This is slightly inefficient because the trailing invalid elements can be reduced + # by directly pad towards (W * size_group) based on max_total_valid + # instead of the original input length + len_basis = input.shape[axis] + else: + # Use single global target length derived from MAX row length to ensure continuity. + # This prevents gaps in valid sequence data across rank boundaries for short rows. + max_total_valid = total_valid.max().item() + # while the other comm ops are within the "group", the (shape_leading_flat_local,) axis is virtually + # sharded if other axes than "axis" in the input are also sharded so we need to + # do a global all_reduce to get the max_total_valid across the sharded (shape_leading_flat_local,) axis + tensor_max_total_valid = torch.tensor(max_total_valid, device=device, dtype=torch.long) + for i_subgroup, subgroup in enumerate(device_mesh.get_all_groups()): + # we can't use the default world group because the input device_mesh can be + # a subgroup of the world group, e.g., device_mesh is a submesh. Also, DeviceMesh + # has no API to return the union of its subgroups so we need to iterate over all subgroups + if i_subgroup == i_dim_device_mesh_shard_axis: + # already reduce via all_gather and local max() + continue + torch.distributed.all_reduce(tensor_max_total_valid, op=torch.distributed.ReduceOp.MAX, group=subgroup) + len_basis = tensor_max_total_valid.item() + + target_len_global = math.ceil(len_basis / (W * size_group)) * (W * size_group) + target_len_local = target_len_global // size_group + + # Vectorized target ranges (Broadcast scalar to all rows) + all_target_starts = torch.arange(size_group, device=device) * target_len_local # (WS,) + all_target_ends = all_target_starts + target_len_local # (WS,) + + # 1.5 Vectorized Communication + + if has_input: + output_local = torch.zeros(shape_leading_flat_local, target_len_local, dtype=local_input.dtype, device=device) + ops = [] + recv_bufs = {} + + # We will save these indices for backward + # send_indices_map[peer] = (rows, cols) into local_valid_data + # recv_indices_map[peer] = (rows, cols) into output_local + send_indices_map = {} + recv_indices_map = {} + + for peer in range(size_group): + # --- SEND Logic --- + # Intersection: [my_start, my_end) AND [target_start, target_end) + p_target_start = all_target_starts[peer] # Scalar + p_target_end = all_target_ends[peer] # Scalar + + # Broadcast scalar target to (shape_leading_flat_local,) + overlap_start = torch.maximum(my_global_start, p_target_start) + overlap_end = torch.minimum(my_global_end, p_target_end) + + # Convert to local indices relative to local_valid_data (starts at 0) + # local_s = overlap_start - my_global_start + local_start = (overlap_start - my_global_start).clamp(min=0) + local_end = (overlap_end - my_global_start).clamp(min=0) + + # (shape_leading_flat_local, 2) + send_ranges = torch.stack([local_start, local_end], dim=1) + send_rows, send_cols = get_flattened_range_indices(send_ranges, device) + + send_indices_map[peer] = (send_rows, send_cols) + send_buf = local_valid_data[send_rows, send_cols] + + if peer != rank: + if send_buf.numel() > 0: + ops.append(dist.P2POp(dist.isend, send_buf, global_rank_peers[peer], group=group)) + + # --- RECV Logic --- + # Intersection: [my_target_start, my_target_end) AND [src_start, src_end) + # my_target_start is rank * target_len_local (scalar) + my_t_start = rank * target_len_local + my_t_end = (rank + 1) * target_len_local + + p_src_start = global_starts[peer] # (shape_leading_flat_local,) + p_src_end = global_ends[peer] # (shape_leading_flat_local,) + + recv_overlap_start = torch.maximum(torch.tensor(my_t_start, device=device), p_src_start) + recv_overlap_end = torch.minimum(torch.tensor(my_t_end, device=device), p_src_end) + + # Convert to local output indices (relative to my_t_start) + out_start = (recv_overlap_start - my_t_start).clamp(min=0) + out_end = (recv_overlap_end - my_t_start).clamp(min=0) + + # (shape_leading_flat_local, 2) + recv_ranges = torch.stack([out_start, out_end], dim=1) + recv_rows, recv_cols = get_flattened_range_indices(recv_ranges, device) + + recv_indices_map[peer] = (recv_rows, recv_cols) + recv_len = recv_rows.numel() + + if peer != rank: + if recv_len > 0: + recv_buf = torch.empty(recv_len, dtype=local_input.dtype, device=device) + ops.append(dist.P2POp(dist.irecv, recv_buf, global_rank_peers[peer], group=group)) + recv_bufs[peer] = recv_buf + else: + # Self-copy + output_local[recv_rows, recv_cols] = send_buf + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + for peer, buf in recv_bufs.items(): + r_rows, r_cols = recv_indices_map[peer] + output_local[r_rows, r_cols] = buf + + # 1.6 Reshape to Output Format + if local_input_moved.ndim > 1: + # output_local is (shape_leading_flat_local, target_len_local) + # Reconstruct non-axis dimensions + output_reshaped = output_local.unflatten(0, local_input_moved.shape[:-1]) + output_final_flat = output_reshaped.movedim(-1, axis) + else: + # output_local.shape is (shape_leading_flat_local=1, target_len_local) + # since (shape_leading_flat_local,) is an temporary axis we added in this function, we need to squeeze it out + output_final_flat = output_local.squeeze(0) + else: + output_final_flat = None + + # 1.7 Generate output masks + # Compute valid range per row for this rank + # total_valid is (shape_leading_flat_local,) - global valid count per row + # rank is local rank in group + # target_len_local is scalar - max length per rank + # Formula: i_end_valid_local = min(total_valid, (rank + 1) * target_len_local) - rank * target_len_local + # This tells us how many valid elements this rank "owns" in the global sequence, starting from 0. + i_end_valid_local = ( + torch.minimum(total_valid, torch.tensor((rank + 1) * target_len_local, device=device)) - rank * target_len_local + ) + i_end_valid_local = i_end_valid_local.clamp(min=0) # (shape_leading_flat_local,) + + # Create mask (shape_leading_flat_local, target_len_local) + idx_cols = torch.arange(target_len_local, device=device) # (target_len_local,) + mask_local_2d = idx_cols.unsqueeze(0) < i_end_valid_local.unsqueeze( + 1 + ) # (shape_leading_flat_local, target_len_local) + + # Reshape mask to match output_final structure + if local_input_moved.ndim > 1: + mask_reshaped = mask_local_2d.unflatten(0, local_input_moved.shape[:-1]) + mask_final_flat = mask_reshaped.movedim(-1, axis) + else: + mask_final_flat = mask_local_2d.squeeze(0) + + if output_final_flat is None: + # has_input==False means mask input and mask output + output_final_flat = mask_final_flat + + # compute global output shape + # This function doesn't modify any other input axes except for "axis" and for inserting + # an extra axis at position "axis + 1". The modification for "axis" is guaranteed to be + # evenly sharded by the i_dim_device_mesh_shard_axis. + shape_output = list(input.shape) + shape_output[axis] = output_final_flat.shape[axis] * size_group + shape_output = tuple(shape_output) + + strides_output = update_exhaustive_strides(output_final_flat.shape, output_final_flat.stride(), shape_output) + + output = DTensor.from_local(output_final_flat, device_mesh, placements, shape=shape_output, stride=strides_output) + + strides_mask_output = update_exhaustive_strides(mask_final_flat.shape, mask_final_flat.stride(), shape_output) + mask_output = DTensor.from_local( + mask_final_flat, device_mesh, placements, shape=shape_output, stride=strides_mask_output + ) + + return (output if has_input else None, mask_output, argsort_mask_flat_local, valid_counts_all_ranks) + + +def _distributed_unpad_and_unpack( + input: DTensor, + axis: int, + argsort_mask_flat_local: torch.Tensor, + valid_counts_all_ranks_unpacked: torch.Tensor, + shape_input_expected: torch.Size | None = None, + device_mesh_expected: DeviceMesh | None = None, + placements_expected: tuple[Placement, ...] | None = None, +) -> DTensor: + """Distributed unpad and unpack operation for DTensor inputs. + + Inverse of _distributed_pack_and_pad. Scatters elements from the packed/padded layout + back to their original positions using the argsort indices from the forward pass. + Communication is performed to redistribute elements across ranks. + + Args: + input (DTensor): Packed DTensor of shape (..., target_len, ...) sharded along axis. + axis (int): Dimension containing the packed sequence. + argsort_mask_flat_local (torch.Tensor): Local argsort indices from _distributed_pack_and_pad, + 2D tensor of shape (shape_leading_flat_local, shape_axis_local). + valid_counts_all_ranks_unpacked (torch.Tensor): Valid counts per rank from + _distributed_pack_and_pad, shape (size_group, shape_leading_flat_local). + shape_input_expected (torch.Size | None, optional): Expected input shape for validation. + device_mesh_expected (DeviceMesh | None, optional): Expected device mesh for validation. + placements_expected (tuple[Placement, ...] | None, optional): Expected placements for validation. + + Returns: + DTensor: Unpacked DTensor with elements scattered back to their original positions. + Shape is (..., original_len, ...) where original_len = argsort_mask_flat_local.shape[1] * size_group. + """ + device = input.device + + # sanity checks + if not isinstance(input, DTensor): + raise TypeError(f"Expected DTensor, got {type(input)}") + if device_mesh_expected is not None and input.device_mesh != device_mesh_expected: + raise ValueError(f"input device_mesh mismatch: expected {device_mesh_expected}, got {input.device_mesh}") + if placements_expected is not None and input.placements != placements_expected: + raise ValueError(f"input placements mismatch: expected {placements_expected}, got {input.placements}") + if shape_input_expected is not None and input.shape != shape_input_expected: + raise ValueError(f"input shape mismatch: expected {shape_input_expected}, got {input.shape}") + + if not isinstance(argsort_mask_flat_local, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(argsort_mask_flat_local)}") + + if argsort_mask_flat_local.ndim != 2: + raise ValueError(f"argsort_mask_flat_local must be a 2D tensor, got {argsort_mask_flat_local.ndim}D tensor") + + device_mesh = input.device_mesh + placements = input.placements + + i_dim_device_mesh_shard_axis = None + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if input.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == axis: + i_dim_device_mesh_shard_axis = i_dim_device_mesh + + if i_dim_device_mesh_shard_axis is None: + raise ValueError(f"input is not sharded along axis {axis}") + + size_group = device_mesh.size(i_dim_device_mesh_shard_axis) + if valid_counts_all_ranks_unpacked.shape != (size_group, argsort_mask_flat_local.shape[0]): + raise ValueError( + f"valid_counts_all_ranks_unpacked.shape {valid_counts_all_ranks_unpacked.shape} != " + f"(size_group, argsort_mask_flat_local.shape[0]) {size_group, argsort_mask_flat_local.shape[0]}" + ) + + rank = device_mesh.get_local_rank(i_dim_device_mesh_shard_axis) + group = device_mesh.get_group(i_dim_device_mesh_shard_axis) + global_rank_peers = torch.distributed.get_process_group_ranks(group) + + # 2. Reshape input to 2D (D, max_target_len) + input_local = input.to_local() + input_moved = input_local.movedim(axis, -1) + input_2d = input_moved.reshape(-1, input_moved.shape[-1]) + + # except for the potential difference in padding along the 'axis', + # argsort_mask_flat_local.shape[0] == input_2d.shape[0] + if argsort_mask_flat_local.shape[0] != input_2d.shape[0]: + raise ValueError( + f"argsort_mask_flat_local.shape[0] {argsort_mask_flat_local.shape[0]} != input_2d.shape[0] {input_2d.shape[0]}" + ) + + # 3. Reconstruct Masks (same as _distributed_pack_and_pad logic) + # NOTE: indexing the input with these global_{start,end} indices + # will automatically exclude the invalid elements' contributions to the output + global_ends = valid_counts_all_ranks_unpacked.cumsum(dim=0) + global_starts = global_ends - valid_counts_all_ranks_unpacked + my_global_start = global_starts[rank] + + # The 'target' length from the pad_and_pack logic correspond to the + # length of the input along 'axis' because this is the inverse of pad_and_pack + target_len_local = input_local.shape[axis] + all_target_starts = torch.arange(size_group, device=device) * target_len_local + all_target_ends = all_target_starts + target_len_local + + # 4. Reverse Communication + output_valid_2d = torch.zeros(argsort_mask_flat_local.shape, dtype=input_2d.dtype, device=device) + local_valid_counts = valid_counts_all_ranks_unpacked[rank] + + ops = [] + recv_bufs = {} + + for peer in range(size_group): + # --- REVERSE RECV (Corresponds to Forward SEND) --- + # Reconstruct send_mask intervals from forward + p_target_start = all_target_starts[peer] + p_target_end = all_target_ends[peer] + + overlap_start = torch.maximum(my_global_start, p_target_start) + # Reconstruct my_global_end + my_global_end = my_global_start + local_valid_counts + overlap_end = torch.minimum(my_global_end, p_target_end) + + local_start = (overlap_start - my_global_start).clamp(min=0) + local_end = (overlap_end - my_global_start).clamp(min=0) + + send_ranges = torch.stack([local_start, local_end], dim=1) + send_rows, send_cols = get_flattened_range_indices(send_ranges, device) + + expected_output_len = send_rows.numel() + + if peer != rank: + if expected_output_len > 0: + output_recv_buf = torch.empty(expected_output_len, dtype=input_2d.dtype, device=device) + ops.append(dist.P2POp(dist.irecv, output_recv_buf, global_rank_peers[peer], group=group)) + # Store indices to scatter later + recv_bufs[peer] = (send_rows, send_cols, output_recv_buf) + + # --- REVERSE SEND (Corresponds to Forward RECV) --- + # Reconstruct recv_mask intervals from forward + my_t_start = rank * target_len_local + my_t_end = (rank + 1) * target_len_local + + p_src_start = global_starts[peer] + p_src_end = global_ends[peer] + + recv_overlap_start = torch.maximum(torch.tensor(my_t_start, device=device), p_src_start) + recv_overlap_end = torch.minimum(torch.tensor(my_t_end, device=device), p_src_end) + + out_start = (recv_overlap_start - my_t_start).clamp(min=0) + out_end = (recv_overlap_end - my_t_start).clamp(min=0) + + recv_ranges = torch.stack([out_start, out_end], dim=1) + recv_rows, recv_cols = get_flattened_range_indices(recv_ranges, device) + + input_to_send = input_2d[recv_rows, recv_cols] + + if peer != rank: + if input_to_send.numel() > 0: + ops.append(dist.P2POp(dist.isend, input_to_send, global_rank_peers[peer], group=group)) + + if peer == rank: + # Self-copy: scatter directly + # output_valid_2d[send_rows, send_cols] = input_to_send + output_valid_2d[send_rows, send_cols] = input_to_send + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + for peer, (r, c, buf) in recv_bufs.items(): + output_valid_2d[r, c] = buf + + # 5. Scatter to Original Shape + # output_valid_2d is sorted and of same shape as argsort_mask_flat_local + # Scatter back to unsorted + output_2d = torch.zeros_like(output_valid_2d) + output_2d.scatter_(1, argsort_mask_flat_local, output_valid_2d) + + # Reshape D back to non-axis dims + if input.ndim > 1: + shape_no_axis = list(input_local.shape) + shape_no_axis.pop(axis) + output_reshaped = output_2d.unflatten(0, tuple(shape_no_axis)) + output_final = output_reshaped.movedim(-1, axis) + else: + # output_2d.shape is (D=1, argsort_mask_flat_local.shape[1]) + # since (D,) is an temporary axis we added in this function, we need to squeeze it out + output_final = output_2d.squeeze(0) + + # 'axis' is guaranteed sharded along device mesh dimension i_dim_device_mesh_shard_axis + shape_axis_output = output_final.shape[axis] * size_group + shape_output = input.shape[:axis] + (shape_axis_output,) + input.shape[axis + 1 :] + + strides_output = update_exhaustive_strides(output_final.shape, output_final.stride(), shape_output) + + output = DTensor.from_local(output_final, device_mesh, placements, shape=shape_output, stride=strides_output) + + return output + + +class DistributedPackAndPad(torch.autograd.Function): + """Autograd function for distributed pack and pad operation. + + Forward: Left-packs valid elements and pads to the right (end) to a multiple of (W * size_group). + Backward: Unpacks and scatters gradients back to original positions. + + See _distributed_pack_and_pad for detailed documentation. + """ + + @staticmethod + def forward( + ctx, input: DTensor, mask: DTensor, axis: int, W: int, keep_input_padding: bool = False + ) -> tuple[DTensor, DTensor]: + output, mask_output, argsort_mask_flat_local, valid_counts_all_ranks = _distributed_pack_and_pad( + input, mask, axis, W, keep_input_padding + ) + + # Context Saving + ctx.save_for_backward(argsort_mask_flat_local, valid_counts_all_ranks) + ctx.axis = axis + ctx.device_mesh = input.device_mesh + ctx.placements = input.placements + ctx.shape_output_fwd = output.shape + ctx.mark_non_differentiable(mask_output, argsort_mask_flat_local, valid_counts_all_ranks) + + return output, mask_output + + @staticmethod + def backward(ctx, grad_output: DTensor, grad_mask: DTensor) -> tuple[DTensor, None, None, None, None]: + # 1. Unpack + (argsort_mask_flat_local, valid_counts_all_ranks) = ctx.saved_tensors + + grad_input = _distributed_unpad_and_unpack( + grad_output, + ctx.axis, + argsort_mask_flat_local, + valid_counts_all_ranks, + shape_input_expected=ctx.shape_output_fwd, + device_mesh_expected=ctx.device_mesh, + placements_expected=ctx.placements, + ) + + return grad_input, None, None, None, None + + +class DistributedUnpadAndUnpack(torch.autograd.Function): + """Autograd function for distributed unpad and unpack operation. + + Forward: Unpacks elements from packed/padded layout back to original positions. + Backward: Packs gradients using the forward pass of _distributed_pack_and_pad. + + This is the inverse operation of DistributedPackAndPad. + """ + + @staticmethod + def forward( + ctx, + input: DTensor, + mask: DTensor, + mask_original: DTensor, + axis: int, + keep_input_padding: bool, + ) -> DTensor: + # masks must have same ndim as input + # Non-axis trailing dimensions can be 1 (for broadcasting) or match input + if mask.ndim != input.ndim: + raise RuntimeError( + f"mask ndim {mask.ndim} must equal input ndim {input.ndim}. " + f"For 3D inputs, use 3D masks with shape (B, N, 1) for broadcasting." + ) + + if mask_original.ndim != input.ndim: + raise RuntimeError( + f"mask_original ndim {mask_original.ndim} must equal input ndim {input.ndim}. " + f"For 3D inputs, use 3D masks with shape (B, N, 1) for broadcasting." + ) + + # Masks must be boolean to ensure correct valid count computation + if mask.dtype != torch.bool: + raise TypeError( + f"mask must have dtype torch.bool to avoid precision issues in valid count computation, " + f"got {mask.dtype}. Use mask.bool() to convert." + ) + if mask_original.dtype != torch.bool: + raise TypeError( + f"mask_original must have dtype torch.bool to avoid precision issues in valid count computation, " + f"got {mask_original.dtype}. Use mask_original.bool() to convert." + ) + + # Check shapes are broadcast-compatible + try: + shape_input = torch.broadcast_shapes(input.shape, mask.shape) + except RuntimeError as e: + raise RuntimeError(f"Shapes of input {input.shape} and mask {mask.shape} are not broadcastable.") from e + + if shape_input != input.shape: + raise RuntimeError(f"Broadcasted shape {shape_input} is not equal to input shape {input.shape}") + + if input.device_mesh != mask.device_mesh: + raise RuntimeError( + f"Input and mask must have the same device mesh but got {input.device_mesh} and {mask.device_mesh}" + ) + + if input.placements != mask.placements: + raise RuntimeError( + f"Input and mask must have the same placements but got {input.placements} and {mask.placements}" + ) + + if input.device_mesh != mask_original.device_mesh: + raise RuntimeError( + f"Input and mask must have the same device mesh but got {input.device_mesh} and {mask_original.device_mesh}" + ) + + if input.placements != mask_original.placements: + raise RuntimeError( + f"Input and mask must have the same placements but got {input.placements} and {mask_original.placements}" + ) + + placements = input.placements + device_mesh = input.device_mesh + + i_dim_device_mesh_shard_axis = None + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if input.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {input.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == axis: + i_dim_device_mesh_shard_axis = i_dim_device_mesh + + if i_dim_device_mesh_shard_axis is None: + raise ValueError(f"input is not sharded along axis {axis}") + + # 0. Normalize axis + ndim = input.ndim + if axis < 0: + axis += ndim + + if not (0 <= axis < ndim): + raise ValueError(f"axis {axis} out of range for {ndim}D input") + + # 1. While this function does not concern about the 'window' size in application, + # the usage requirement of _distributed_pack_and_pad requires input padding factor + # Here we just take out the group_size from the input shape along 'axis' to get the padding factor + W = input.shape[axis] // device_mesh.shape[i_dim_device_mesh_shard_axis] + + # 2. Expand mask_original in case the mask has been broadcasted from mask_original + shape_mask_original_expanded = input.shape[:axis] + (mask_original.shape[axis],) + input.shape[axis + 1 :] + + try: + shape_mask_original_expanded_broadcast = torch.broadcast_shapes( + shape_mask_original_expanded, mask_original.shape + ) + except RuntimeError as e: + raise RuntimeError( + f"Shapes of input {input.shape} and mask {mask_original.shape} " + f"are not broadcastable excluding the sequence axis {axis}." + ) from e + + if shape_mask_original_expanded_broadcast != shape_mask_original_expanded: + raise ValueError( + f"mask_original shape {mask_original.shape} is not broadcastable to input shape {input.shape} excluding the sequence axis {axis}" + ) + + if mask_original.shape != shape_mask_original_expanded: + mask_original_local = mask_original.to_local() + input_local = input.to_local() + + target_local_shape = ( + input_local.shape[:axis] + (mask_original_local.shape[axis],) + input_local.shape[axis + 1 :] + ) + + mask_original_local_expanded = mask_original_local.expand(target_local_shape) + + strides_mask_original_expanded = tuple( + 0 if mask_original_local_expanded.stride()[i] == 0 else mask_original.stride()[i] + for i in range(mask_original.ndim) + ) + + mask_original = DTensor.from_local( + mask_original_local_expanded, + mask_original.device_mesh, + mask_original.placements, + shape=shape_mask_original_expanded, + stride=strides_mask_original_expanded, + ) + + # 3. Re-compute metadata from mask_original + # We run forward pass using mask_original as both input and mask. + # This is cheap(er) and gives us the correct metadata for the backward pass. + # We need mask_original to be a DTensor. + + _, mask_pack_and_pad, argsort_mask_flat_local, valid_counts_all_ranks = _distributed_pack_and_pad( + None, mask_original, axis, W, keep_input_padding + ) + + ctx.mark_non_differentiable(mask_original, mask_pack_and_pad, argsort_mask_flat_local, valid_counts_all_ranks) + + # 4. Check mask consistency (if mask_qw_dtensor is provided) + # We assume mask passed in corresponds to original packed and padded mask. + if not torch.equal(mask.to_local(), mask_pack_and_pad.to_local()): + raise ValueError("mask_original does not correspond to mask_qw_dtensor when sorted/packed.") + + # 5. Invert the data + output = _distributed_unpad_and_unpack(input, axis, argsort_mask_flat_local, valid_counts_all_ranks) + + ctx.save_for_backward(mask_original) + ctx.params = {"axis": axis, "W": W, "keep_input_padding": keep_input_padding} + + return output + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None, None, None, None]: + # grad_output is Square layout (gradient of our output). + # We want gradient of our input (Window layout). + # This corresponds to Forward pass of DistributedUnmaskReshape. + + (mask_original,) = ctx.saved_tensors + params = ctx.params + axis = params["axis"] + W = params["W"] + keep_input_padding = params["keep_input_padding"] + + grad_input, _, _, _ = _distributed_pack_and_pad(grad_output, mask_original, axis, W, keep_input_padding) + + assert grad_input is not None + + return grad_input, None, None, None, None + + +def distributed_pack_and_pad( + input_dtensor: DTensor, mask_dtensor: DTensor, W: int, axis: int, keep_input_padding: bool = False +) -> tuple[DTensor, DTensor]: + """Distributed left-pack and pad operation for DTensor inputs. + + Left-packs valid elements (as indicated by mask) to the front along the specified axis, + then pads to the right (end) to the next multiple of (W * size_group) where size_group + is the number of ranks sharding the tensor along axis. + + Args: + input_dtensor (DTensor): Input DTensor of shape (..., seq_len, ...) sharded along axis. + mask_dtensor (DTensor): Boolean mask DTensor with shape broadcastable to input. + True indicates valid elements, False indicates padding to ignore. + W (int): Padding factor. The output length along axis will be padded to the next + multiple of (W * size_group). + axis (int): Dimension containing the sequence to pack and pad. + keep_input_padding (bool, optional): If False (default), output length is based on + the maximum number of valid elements across all slices. If True, output length + is based on input_dtensor.shape[axis]. Defaults to False. + + Returns: + tuple[DTensor, DTensor]: + output: Packed and padded DTensor. Shape is (..., target_len, ...) where + target_len is a multiple of (W * size_group). + mask_output: Boolean mask DTensor for output, same shape as output. + """ + return DistributedPackAndPad.apply(input_dtensor, mask_dtensor, axis, W, keep_input_padding) + + +def distributed_unpad_and_unpack( + input: DTensor, + mask: DTensor, + mask_original: DTensor, + axis: int, + keep_input_padding: bool, +) -> DTensor: + """Distributed unpad and unpack operation for DTensor inputs. + + Inverse of distributed_pack_and_pad. Scatters elements from the packed/padded layout + back to their original positions. + + Args: + input (DTensor): Packed DTensor of shape (..., target_len, ...) sharded along axis. + mask (DTensor): Boolean mask DTensor for input, indicating valid (True) vs + invalid (False) elements in the packed layout. + mask_original (DTensor): Boolean mask DTensor indicating valid elements in the + original unpacked layout. Used to reconstruct the argsort indices. + axis (int): Dimension containing the packed sequence. + keep_input_padding (bool): Whether input padding was kept in the forward pass + of distributed_pack_and_pad. + + Returns: + DTensor: Unpacked DTensor with elements scattered back to their original positions. + Shape matches mask_original's shape along axis. + """ + return DistributedUnpadAndUnpack.apply(input, mask, mask_original, axis, keep_input_padding) + + +class DistributedGatherSlidingWindows(torch.autograd.Function): + @staticmethod + def forward(ctx, dense_dtensor: DTensor, window_size: int, axis: int) -> DTensor: + """ + Distributed Forward Pass using ownership-based halo exchange. + + Args: + dense_dtensor: Input DTensor sharded along axis + window_size: h = H // (W//2) in the original window batching parameters from Boltz + axis: dense_dtensor's axis to apply windowing + """ + # 0. sanity checks + if not isinstance(dense_dtensor, DTensor): + raise TypeError(f"Expected DTensor, got {type(dense_dtensor)}") + + if not isinstance(axis, int): + raise TypeError(f"Expected int for axis, got {type(axis)}") + # Normalize axis + ndim = dense_dtensor.ndim + + if ndim < 2: + raise ValueError(f"dense_dtensor must have at least 2 dimensions, got {ndim}D") + + if axis < 0: + axis += ndim + if not (0 <= axis < ndim): + raise ValueError(f"axis {axis} out of range for {ndim}D input") + if not (isinstance(window_size, int) and window_size > 0 and window_size % 2 == 0): + # h := window_size must be an even integer per the original get_indexing_matrix function from Boltz, + # i.e., h / 2 must be an integer + raise TypeError(f"Expected positive even integer for window_size, got {type(window_size)}") + + # the halo size computation from compute_query_window_ownership assumes the + # window batching parameters W, H and K and their math relationship from the + # original get_indexing_matrix function from Boltz, so the input dense_dtensor's + # shape must satisfy the requirements: + # 1. dense_dtensor.shape[axis] == 2 * K + # 2. dense_dtensor.shape[axis + 1] == W // 2 + # In addition, the window_size must be even integer and that + # window_size * dense_dtensor.shape[axis + 1] gives the resulting H + if dense_dtensor.shape[axis] % 2 != 0: + raise ValueError(f"dense_dtensor.shape[{axis}] must be even, got {dense_dtensor.shape[axis]}") + K = dense_dtensor.shape[axis] // 2 + W = dense_dtensor.shape[axis + 1] * 2 + H = window_size * dense_dtensor.shape[axis + 1] + + device_mesh = dense_dtensor.device_mesh + placements = dense_dtensor.placements + + i_dim_device_mesh_shard_axis = None + i_axes_sharded_by_mesh_dim = [None] * device_mesh.ndim + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + if dense_dtensor.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {dense_dtensor.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + if placement.dim == axis: + i_dim_device_mesh_shard_axis = i_dim_device_mesh + if placement.dim == axis + 1: + # NOTE: technically, axis + 1 of shape (W // 2) is already supported because + # there is no special handling of axis + 1 inside this function but it's only + # treated as another axes orthogonal to axis. But the upstream DistributedUnmaskReshape + # can not produce sharded placement of axis + 1 so we exclude it from execution + # to fence off upstream bugs + raise NotImplementedError(f"Sharding along axis {axis + 1} is not supported") + i_axes_sharded_by_mesh_dim[i_dim_device_mesh] = placement.dim + + if i_dim_device_mesh_shard_axis is None: + raise ValueError(f"input dense_dtensor is not sharded along axis {axis}") + + # 1. Unpack DTensor and get rank info + local_tensor = dense_dtensor.to_local() + rank_in_group = device_mesh.get_local_rank(i_dim_device_mesh_shard_axis) + size_group = device_mesh.size(i_dim_device_mesh_shard_axis) + group = device_mesh.get_group(i_dim_device_mesh_shard_axis) + ranks_global_group = torch.distributed.get_process_group_ranks(group) + rank_global = ranks_global_group[rank_in_group] + + # 2. Determine ownership from DTensor sharding + # The DTensor is sharded along axis (sequence of half-windows) + # Global shape[axis] = 2*K, local_tensor.shape[axis] = (2 * K) / size_group + # The following code, esp. the usage of compute_query_window_ownership, + # assumes the underlying sequence length owned by each rank, i.e., + # local_tensor.shape[axis] * local_tensor.shape[axis + 1], is a multiple of "W". + # This is equivalent to: (2 * K // size_group * (W // 2)) % W == 0, i.e., + # (K // size_group * W) % W == 0, i.e., K // size_group is an integer + if K % size_group != 0: + raise ValueError( + f"K {K} is not a integer multiple of the number of ranks {size_group} sharding the dense_dtensor along axis {axis}" + ) + local_hw_len = local_tensor.shape[axis] + hw_start = rank_in_group * local_hw_len + hw_end = (rank_in_group + 1) * local_hw_len + + # Query windows: each QW i owns half-windows [2i, 2i+1] + # So if we own HW [hw_start, hw_end), we own QW [hw_start//2, hw_end//2) + qw_start = hw_start // 2 + qw_end = hw_end // 2 + + # Validate: each rank must own at least one query window + assert qw_start < qw_end, ( + f"Rank {rank_global} has no query windows to process: QW[{qw_start},{qw_end}). " + f"This typically means size_group > K. Either reduce size_group or increase K." + ) + + ownership = compute_query_window_ownership(W, H, K, qw_start, qw_end) + + hw_need_start, hw_need_end = ownership["hw_needed"] + left_halo_size = ownership["left_halo_size"] + right_halo_size = ownership["right_halo_size"] + + # 3. Halo Exchange (multi-hop) + device = local_tensor.device + + # Create halo buffers + left_halo = None + right_halo = None + + if left_halo_size > 0: + # Get shape with left_halo_size at axis + left_shape = list(local_tensor.shape) + left_shape[axis] = left_halo_size + left_halo = torch.zeros(left_shape, dtype=local_tensor.dtype, device=device) + + if right_halo_size > 0: + right_shape = list(local_tensor.shape) + right_shape[axis] = right_halo_size + right_halo = torch.zeros(right_shape, dtype=local_tensor.dtype, device=device) + + recv_meta, send_meta = get_halo_from_neighbors(rank_in_group, size_group, local_hw_len, W, H, K) + + ops = [] + + # Execute Recvs + recv_temps = [] + for peer, htype, offset, length in recv_meta: + target_buffer = left_halo if htype == "left" else right_halo + # Recv needs contiguous buffer. target_buffer slice might not be. + # Create temp buffer + shape = list(target_buffer.shape) + shape[axis] = length + recv_buf = torch.empty(shape, dtype=local_tensor.dtype, device=device) + ops.append(dist.P2POp(dist.irecv, recv_buf, ranks_global_group[peer], group=group)) + recv_temps.append((recv_buf, target_buffer, offset, length)) + + # Execute Sends + for peer, offset, length in send_meta: + send_buf = local_tensor.narrow(axis, offset, length).contiguous() + ops.append(dist.P2POp(dist.isend, send_buf, ranks_global_group[peer], group=group)) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # Copy temp buffers to actual halo tensors + for buf, target, off, ln in recv_temps: + target.narrow(axis, off, ln).copy_(buf) + + # 4. Construct extended local view: [left_halo, local_data, right_halo] + parts = [] + if left_halo is not None: + parts.append(left_halo) + parts.append(local_tensor) + if right_halo is not None: + parts.append(right_halo) + + extended_local = torch.cat(parts, dim=axis) if len(parts) > 1 else local_tensor + + # 5. Apply local unfold on extended data + # Compute local offsets relative to extended_local + # In extended_local, owned half-windows start at index left_halo_size + # Needed half-windows span [hw_need_start, hw_need_end) + # Offset in extended_local = (hw_index - hw_need_start) + # This leverages the translational equivalence property of the gather_sliding_windows function: + # T(x[j_s:j_e], offsets[k_s:k_e] - j_s) = T(x, offsets)[k_s:k_e] + # where "T" is equivalent to the gather_sliding_windows function + # (or equivalently the underlying Toeplitz matrix operation). + + # Generate offsets for owned query windows + h = window_size + offset_start = 1 - h // 2 + owned_qw_offsets_global = torch.arange(offset_start + 2 * qw_start, offset_start + 2 * qw_end, 2, device=device) + + # Convert to local indices in extended_local + # Global offset points to a half-window index + # In extended_local: HW[hw_need_start] is at index 0 + local_offsets = owned_qw_offsets_global - hw_need_start + + # Apply efficient unfold on extended data + # local_result has shape (..., K_local, window_size, W/2, ...) where local_result.shape[axis] == K_local + local_result = gather_sliding_windows(extended_local, local_offsets, window_size, axis) + # This function requires the input dense_dtensor.shape[axis] is evenly sharded by the device_mesh + # and guarantees local_result is evenly sharded along the same axis + shape_output = list(local_result.shape) + for i_dim_device_mesh, i_axis in enumerate(i_axes_sharded_by_mesh_dim): + if i_axis is None: + continue + shape_output[i_axis] = shape_output[i_axis] * device_mesh.size(i_dim_device_mesh) + shape_output = tuple(shape_output) + strides_output = update_exhaustive_strides(local_result.shape, local_result.stride(), shape_output) + + # 6. Save context for backward + ctx.save_for_backward(local_offsets) + ctx.params = { + "fwd_local_input_shape": local_tensor.shape, + "fwd_input_shape": dense_dtensor.shape, + "fwd_output_shape": shape_output, + "ownership": ownership, + "axis": axis, + "mesh": device_mesh, + "placements": placements, + "i_dim_device_mesh_shard_axis": i_dim_device_mesh_shard_axis, + "window_size": window_size, + "recv_meta": recv_meta, + "send_meta": send_meta, + } + + # 7. Wrap in DTensor (sharded on query window dimension) + return DTensor.from_local( + local_result, device_mesh, placements, shape=torch.Size(shape_output), stride=tuple(strides_output) + ) + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[DTensor, None, None]: + """ + Distributed Backward Pass: Compute gradients with neighbor exchange. + + Steps: + 1. Compute gradient w.r.t. extended_local (includes halos) + 2. Split into [left_halo_grad, local_grad, right_halo_grad] + 3. Send halo grads back to neighbors who own that data + 4. Receive grads from neighbors who used our data as halos + 5. Accumulate and return + """ + # 1. Unpack context + (local_offsets,) = ctx.saved_tensors + params = ctx.params + fwd_local_input_shape = params["fwd_local_input_shape"] + fwd_input_shape = params["fwd_input_shape"] + fwd_output_shape = params["fwd_output_shape"] + ownership = params["ownership"] + axis = params["axis"] + mesh = params["mesh"] + placements = params["placements"] + i_dim_device_mesh_shard_axis = params["i_dim_device_mesh_shard_axis"] + window_size = params["window_size"] + recv_meta = params["recv_meta"] + send_meta = params["send_meta"] + + hw_start, hw_end = ownership["hw_owned"] + hw_need_start, hw_need_end = ownership["hw_needed"] + left_halo_size = ownership["left_halo_size"] + right_halo_size = ownership["right_halo_size"] + + # sanity checks + if not isinstance(grad_output, DTensor): + raise TypeError(f"Expected DTensor, got {type(grad_output)}") + if grad_output.device_mesh != mesh: + raise ValueError(f"grad_output device_mesh mismatch: expected {mesh}, got {grad_output.device_mesh}") + if grad_output.placements != placements: + raise ValueError(f"grad_output placements mismatch: expected {placements}, got {grad_output.placements}") + if grad_output.shape != fwd_output_shape: + raise ValueError(f"grad_output shape mismatch: expected {fwd_output_shape}, got {grad_output.shape}") + + local_grad_output = grad_output.to_local() + + # 2. Compute gradient w.r.t. extended_local using standalone backward function + # Build extended_local's shape (same shape as local_input but with extended_len at axis) + extended_len = hw_need_end - hw_need_start + extended_shape = list(fwd_local_input_shape) + extended_shape[axis] = extended_len + + # Call standalone backward to get gradient w.r.t. extended_local + grad_extended = gather_sliding_windows_backward( + local_grad_output, local_offsets, window_size, axis, tuple(extended_shape) + ) + + # 3. Split extended gradient + offset = 0 + grad_left_halo = grad_extended.narrow(axis, offset, left_halo_size) if left_halo_size > 0 else None + offset += left_halo_size + + local_owned_len = hw_end - hw_start + grad_local = grad_extended.narrow(axis, offset, local_owned_len) + offset += local_owned_len + + grad_right_halo = grad_extended.narrow(axis, offset, right_halo_size) if right_halo_size > 0 else None + + # 4. Exchange halo gradients + group = mesh.get_group(i_dim_device_mesh_shard_axis) + ranks_global_group = torch.distributed.get_process_group_ranks(group) + + ops = [] + + # 1. Send gradients for data I received in fwd (halo grads) + # fwd recv_meta: (peer, type, offset, length) + # I received 'length' from 'peer' into my 'type' halo at 'offset'. + # Now I send that slice of grad back to 'peer'. + for peer, h_type, offset, length in recv_meta: + if h_type == "left": + assert grad_left_halo is not None + grad_chunk = grad_left_halo.narrow(axis, offset, length) + else: + assert grad_right_halo is not None + grad_chunk = grad_right_halo.narrow(axis, offset, length) + + ops.append(dist.P2POp(dist.isend, grad_chunk.contiguous(), ranks_global_group[peer], group=group)) + + # 2. Recv gradients for data I sent in fwd (accumulate to local) + # fwd send_meta: (peer, offset, length) + # I sent 'length' from my local at 'offset' to 'peer'. + # Now I recv that grad from 'peer' and add to my local grad. + recv_grads = [] # (tensor, offset, length) + for peer, offset, length in send_meta: + grad_buf = torch.empty( + grad_local.shape[:axis] + (length,) + grad_local.shape[axis + 1 :], + dtype=grad_local.dtype, + device=grad_local.device, + ) + ops.append(dist.P2POp(dist.irecv, grad_buf, ranks_global_group[peer], group=group)) + recv_grads.append((grad_buf, offset, length)) + + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # 5. Accumulate received gradients + for grad_buf, offset, length in recv_grads: + target = grad_local.narrow(axis, offset, length) + target.add_(grad_buf) + + # 6. Return DTensor gradient + strides_grad_input = update_exhaustive_strides(grad_local.shape, grad_local.stride(), fwd_input_shape) + grad_input = DTensor.from_local( + grad_local, mesh, placements, shape=torch.Size(fwd_input_shape), stride=tuple(strides_grad_input) + ) + + return grad_input, None, None + + +def distributed_gather_sliding_windows(dense_dtensor: DTensor, window_size: int, axis: int) -> DTensor: + """ + Distributed version of gather_sliding_windows for DTensor inputs. + + Args: + dense_dtensor: Input DTensor sharded along axis dimension + window_size: h = H // (W//2) + axis: Dimension to apply windowing + + Returns: + DTensor sharded on query window dimension + """ + return DistributedGatherSlidingWindows.apply(dense_dtensor, window_size, axis) + + +def convert_single_repr_to_window_batched_key(x: DTensor, W: int, H: int) -> DTensor: + """Converts a single representation tensor to a window-batched key tensor. + + Reshapes and processes the input tensor to create overlapping windows suitable for + attention keys in windowed attention mechanisms. The input is unflattened into + half-windows and then gathered using sliding windows. + + Args: + x: Input tensor of shape (B, N, ...), where B is batch size and N is sequence length. + W: Query window size. + H: Key window size. + + Returns: + A DTensor of shape (B, K, H, ...), where K = N // W is the number of windows. + + Raises: + TypeError: If ``x`` is not a DTensor. + ValueError: If ``x`` has fewer than 2 dimensions. + ValueError: If ``x.shape[1]`` is not divisible by ``W``. + """ + # input is assumed to be in shape (B, N, ...) + if not isinstance(x, DTensor): + raise TypeError(f"x must be a DTensor, but got {type(x)}") + if x.ndim < 2: + raise ValueError(f"x must have at least 2 dimensions, but got x.ndim={x.ndim}") + + validate_window_batching_parameters(W, H, True) + + if x.shape[1] % W != 0: + raise ValueError(f"x.shape[1] must be divisible by W, but got x.shape[1]={x.shape[1]} and W={W}") + + K = x.shape[1] // W + h = H // (W // 2) + + # (B, K*W, D) -> (B, 2*K, W//2, D) + x_unflat_hw = shardwise_unflatten_sharded(x, axis=1, sizes=(2 * K, W // 2)) + # (B, 2*K, W//2, D) -> (B, K, h, W // 2, D) + x_unflat_key = distributed_gather_sliding_windows(x_unflat_hw, window_size=h, axis=1) + # (B, K, h, W // 2, D) -> (B, K, H, D) + x_key = shardwise_flatten(x_unflat_key, start_dim=2, end_dim=3) + return x_key + + +def convert_single_repr_window_batched_query_to_key(x: DTensor, W: int, H: int) -> DTensor: + """Converts a window-batched query tensor to a window-batched key tensor. + + First flattens the query-batched input and then converts it to a key tensor using + ``convert_single_repr_to_window_batched_key``. + + Args: + x: Input tensor of shape (B, K, W, ...), where B is batch size, K is number of windows, + and W is query window size. + W: Query window size. + H: Key window size. + + Returns: + A DTensor of shape (B, K, H, ...). + + Raises: + TypeError: If ``x`` is not a DTensor. + ValueError: If ``x`` has fewer than 3 dimensions. + ValueError: If ``x.shape[2]`` is not equal to ``W``. + """ + # input is assumed to be in shape (B, K, W, ...) + if not isinstance(x, DTensor): + raise TypeError(f"x must be a DTensor, but got {type(x)}") + if x.ndim < 3: + raise ValueError(f"x must have at least 3 dimensions, but got x.ndim={x.ndim}") + + if x.shape[2] != W: + raise ValueError(f"x.shape[2] must be equal to W, but got x.shape[2]={x.shape[2]} and W={W}") + + validate_window_batching_parameters(W, H, True) + + # (B, K, W, ...) -> (B, K*W, ...) + x_flat = shardwise_flatten_sharded(x, start_dim=1, end_dim=2) + x_key = convert_single_repr_to_window_batched_key(x_flat, W, H) + return x_key diff --git a/src/boltz/distributed/model/layers/where.py b/src/boltz/distributed/model/layers/where.py new file mode 100644 index 000000000..e0820b0bf --- /dev/null +++ b/src/boltz/distributed/model/layers/where.py @@ -0,0 +1,274 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import torch +from torch.distributed.tensor import DTensor, Partial, Shard + +from boltz.distributed.utils import LayoutRightMap + + +class _WhereImpl(torch.autograd.Function): + """Distributed implementation of where operation using DTensors. + + This autograd function implements distributed where operations that select + elements from two tensors based on a condition. The operation is performed + element-wise across distributed tensors while maintaining proper gradient computation. + + Supported operations: + - WHERE: output = torch.where(condition, x, y) + + Key features: + - Distributed computation across device meshes with various sharding strategies + - Memory-efficient implementation that operates on local tensor chunks + - Supports gradient computation through custom backward pass + - Handles broadcasting between condition, x, and y tensors + """ + + @staticmethod + def forward(ctx, condition: DTensor, x: DTensor, y: DTensor) -> DTensor: + """Forward pass of distributed where operation. + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object for saving information needed in backward pass. + condition : DTensor + Boolean condition tensor. Must be broadcastable with x and y. + x : DTensor + Values to select where condition is True. + y : DTensor + Values to select where condition is False. + + Returns + ------- + DTensor + Output tensor with same shape as the broadcasted shape of inputs. + Contains x where condition is True, y where condition is False. + + Raises + ------ + TypeError + If inputs are not DTensors. + ValueError + If Partial placements are used (not supported), or if tensors have + incompatible device meshes or placements. + """ + if not isinstance(condition, DTensor): + raise TypeError(f"Input 'condition' must be of type DTensor. Got type {type(condition)}.") + if not isinstance(x, DTensor): + raise TypeError(f"Input 'x' must be of type DTensor. Got type {type(x)}.") + if not isinstance(y, DTensor): + raise TypeError(f"Input 'y' must be of type DTensor. Got type {type(y)}.") + + # Validate that all tensors have same device mesh and placements + if condition.device_mesh != x.device_mesh or x.device_mesh != y.device_mesh: + raise ValueError( + f"All input tensors must have identical device mesh. " + f"Got device meshes {condition.device_mesh}, {x.device_mesh}, {y.device_mesh}." + ) + + if condition.placements != x.placements or x.placements != y.placements: + raise ValueError( + f"All input tensors must have identical placements. " + f"Got placements {condition.placements}, {x.placements}, {y.placements}." + ) + + device_mesh_input = x.device_mesh + placements_input = x.placements + + # Validate placements + for i_dim_device_mesh, placement in enumerate(placements_input): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + # Check that all tensors can be evenly sharded + for tensor_name, tensor in [("condition", condition), ("x", x), ("y", y)]: + if tensor.shape[placement.dim] % device_mesh_input.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding {tensor_name} tensor dimension {placement.dim} of size {tensor.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh_input.shape[i_dim_device_mesh]} is not supported" + ) + + # Get local tensors + condition_local = condition.to_local() + x_local = x.to_local() + y_local = y.to_local() + + # Perform the where operation + output_local = torch.where(condition_local, x_local, y_local) + + if x.requires_grad or y.requires_grad: + # Save condition for backward pass + condition_local_copy = condition_local.detach().clone() + ctx.save_for_backward(condition_local_copy) + ctx.device_mesh_input = device_mesh_input + ctx.placements_input = placements_input + ctx.x_requires_grad = x.requires_grad + ctx.y_requires_grad = y.requires_grad + ctx.shape_x = x.shape + ctx.stride_x = x.stride() + ctx.shape_y = y.shape + ctx.stride_y = y.stride() + + # x and y shapes are only constrained to be broadcastable + # without necessarily being the same shape + shape_output = torch.broadcast_shapes(x.shape, y.shape) + stride_output = LayoutRightMap(shape_output).strides + out = DTensor.from_local( + output_local, + device_mesh=device_mesh_input, + placements=placements_input, + shape=shape_output, + stride=stride_output, + ) + return out + + @staticmethod + def backward(ctx, grad_output: DTensor) -> tuple[None, DTensor | None, DTensor | None]: + """Backward pass of distributed where operation. + + Computes gradients with respect to x and y inputs. + + The gradients are: + - For x: grad_output where condition is True, 0 elsewhere + - For y: grad_output where condition is False, 0 elsewhere + - For condition: None (condition is not differentiable) + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFrame + Context object containing saved tensors and metadata from forward pass. + grad_output : DTensor + Gradients of the loss with respect to the output tensor. + + Returns + ------- + tuple[None, DTensor | None, DTensor | None] + Gradients with respect to x and y parameters. + condition gradient is always None. + """ + if not (ctx.x_requires_grad or ctx.y_requires_grad): + return None, None, None + + if not isinstance(grad_output, DTensor): + raise TypeError(f"Input 'grad_output' must be of type DTensor. Got type {type(grad_output)}.") + + if grad_output.device_mesh != ctx.device_mesh_input: + raise ValueError( + f"Input 'grad_output' must have the same device mesh as the input tensors. " + f"Got device meshes {grad_output.device_mesh} and {ctx.device_mesh_input}." + ) + + if grad_output.placements != ctx.placements_input: + raise ValueError( + f"Input 'grad_output' must have the same placements as the input tensors. " + f"Got placements {grad_output.placements} and {ctx.placements_input}." + ) + + grad_output_local = grad_output.to_local() + (condition_local,) = ctx.saved_tensors + + # Compute gradients + grad_x = None + grad_y = None + zeros_local = torch.zeros_like(grad_output_local) + + if ctx.x_requires_grad: + # Gradient flows to x where condition is True + grad_x_local = torch.where(condition_local, grad_output_local, zeros_local) + grad_x = DTensor.from_local( + grad_x_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.shape_x, + stride=ctx.stride_x, + ) + + if ctx.y_requires_grad: + # Gradient flows to y where condition is False + grad_y_local = torch.where(condition_local, zeros_local, grad_output_local) + grad_y = DTensor.from_local( + grad_y_local, + device_mesh=ctx.device_mesh_input, + placements=ctx.placements_input, + shape=ctx.shape_y, + stride=ctx.stride_y, + ) + + return None, grad_x, grad_y + + +def where(condition: DTensor, x: DTensor, y: DTensor) -> DTensor: + """Apply where operation to distributed tensors. + + This function selects elements from x or y based on condition. + Where condition is True, elements from x are selected; where condition is False, + elements from y are selected. The operation is performed efficiently using local + tensor operations while maintaining gradient computation capabilities. + + Parameters + ---------- + condition : DTensor + Boolean condition tensor. Must be broadcastable with x and y. + Should have placements compatible with x and y. + x : DTensor + Values to select where condition is True. + Can have any shape and sharding strategy compatible with condition and y. + y : DTensor + Values to select where condition is False. + Must have same shape, device mesh, and placements as x. + + Returns + ------- + DTensor + Output tensor with same shape as the broadcasted shape of inputs. + Contains x where condition is True, y where condition is False. + + Examples + -------- + >>> # Assume we have distributed tensors with shape (B, N, D) + >>> condition = x > 0.0 + >>> result = where(condition, x, y) + >>> # result = torch.where(condition, x, y), computed in distributed fashion + >>> + >>> # Clip using where (equivalent to clip operation) + >>> clipped = where(x > 5.0, torch.full_like(x, 5.0), x) + >>> # clipped = torch.where(x > 5.0, 5.0, x), computed in distributed fashion + + Notes + ----- + - All input tensors must be DTensors with compatible device meshes and placements + - Partial placements are not currently supported + - The function is differentiable and supports gradient computation for x and y + - The condition tensor is not differentiable + - The operation is performed on local tensor chunks for efficiency + - Broadcasting is handled by PyTorch's local where operation + + Raises + ------ + TypeError + If inputs are not DTensors. + ValueError + If Partial placements are used (not supported), or if tensors have + incompatible device meshes or placements. + """ + return _WhereImpl.apply(condition, x, y) # type: ignore diff --git a/src/boltz/distributed/model/loss/__init__.py b/src/boltz/distributed/model/loss/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/model/loss/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/model/loss/bfactor.py b/src/boltz/distributed/model/loss/bfactor.py new file mode 100644 index 000000000..060f10f09 --- /dev/null +++ b/src/boltz/distributed/model/loss/bfactor.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-based context-parallel B-factor loss for Boltz-2. + +Implements the B-factor loss as a single torch.autograd.Function, extracting +local tensors once and performing all computation locally with explicit +all_reduce calls at the required communication points. + +The B-factor loss is per-token (no pairwise interaction), making it simpler +than the distogram loss. The only differentiable input is ``pred`` (the +predicted B-factor logits from BFactorModule). + +Equivalence to serial code (src/boltz/model/loss/bfactor.py): + 1. Map atom-level B-factors to tokens via token_to_rep_atom matrix + 2. Bin the per-token B-factors into a histogram (one-hot target) + 3. Compute cross-entropy loss between predicted logits and target + 4. Mask by valid tokens (bfactor > 1e-5), average over valid tokens + +The serial code computes a single global fraction:: + + loss = sum_{b,n}(errors * mask) / (sum_{b,n}(mask) + eps) + +This is NOT per-batch-then-averaged. The distributed version must match +this exactly by reducing both numerator and denominator globally across +all CP ranks and DP ranks before dividing. + +Because ``pred`` has single-representation placements +(Shard(0), Shard(1), Replicate()), only the dp and cp0 mesh dimensions +carry unique data. cp1 is Replicate — all cp1 ranks hold identical +local shards. We reduce over dp and cp0 only (NOT cp1) to avoid +double-counting the replicated data. + +Communication budget: + Forward (2–3 all_reduce calls): + 1. all_reduce(SUM) over dp group for packed [loss_sum, mask_sum] (1 call) + 2. all_reduce(SUM) over cp0 group for the same packed tensor (1 call) + Together equivalent to a single all_reduce over the combined dp×cp0 + group: sum_{dp×cp0}(x) = sum_{cp0}(sum_{dp}(x)). + 3. (optional) all_reduce(AVG) over cp1 group for the scalar loss, + enforcing identical values across Replicate ranks (1 call, when cp1_group + is provided). + Backward (0 collective calls): + The backward of all_reduce(SUM) is identity. +""" + +import torch +import torch.distributed as dist +from torch.autograd.function import FunctionCtx +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard +from torch.distributed.tensor.device_mesh import DeviceMesh + + +class _BFactorLossCP(torch.autograd.Function): + """Single autograd.Function for the full B-factor loss. + + Forward: to_local() → local math with explicit all_reduces → from_local() + Backward: local math only (no communication) + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx: FunctionCtx, + pred: DTensor, + token_to_rep_atom: DTensor | torch.Tensor, + bfactor: DTensor | torch.Tensor, + device_mesh: DeviceMesh, + dp_group: ProcessGroup, + cp0_group: ProcessGroup, + cp1_group: ProcessGroup | None, + ) -> DTensor: + """Forward pass. + + Parameters + ---------- + pred : DTensor + Predicted B-factor logits [B, N, bins], placements + (Shard(0), Shard(1), Replicate()). + token_to_rep_atom : DTensor | Tensor + Token-to-representative-atom mapping [B, N_tokens, max_atoms_per_shard]. + Represents the diagonal block of the global one-hot tensor; + the last axis is the local atom shard and is NOT sharded across CP. + bfactor : DTensor | Tensor + Per-atom B-factors [B, A]. + device_mesh : DeviceMesh + 3D device mesh (dp, cp0, cp1). + dp_group : ProcessGroup + Process group for the dp mesh dimension (dim 0). + cp0_group : ProcessGroup + Process group for the cp_axis_0 mesh dimension (dim 1). + cp1_group : ProcessGroup | None + Process group for the cp1 (Replicate) mesh dimension (dim 2). + When not None, a mean all-reduce is applied to the scalar loss + to enforce identical values across cp1 ranks. + + Returns + ------- + global_loss : DTensor + Scalar loss, placements (Replicate(), Replicate(), Replicate()). + """ + # --- Validate differentiable input --- + if not isinstance(pred, DTensor): + raise TypeError(f"pred must be DTensor, got {type(pred)}") + + expected_single = (Shard(0), Shard(1), Replicate()) + if pred.placements != expected_single: + raise ValueError(f"pred placements {pred.placements} must be {expected_single}") + for i_dim, placement in enumerate(pred.placements): + if isinstance(placement, Partial): + raise ValueError(f"Partial placement on pred mesh dim {i_dim} is not supported") + if isinstance(placement, Shard) and pred.shape[placement.dim] % device_mesh.shape[i_dim] != 0: + raise ValueError( + f"Uneven sharding pred tensor dim {placement.dim} of size " + f"{pred.shape[placement.dim]} along mesh dim {i_dim} of size " + f"{device_mesh.shape[i_dim]}" + ) + + # Non-differentiable feature inputs: accept DTensor or plain tensor. + # The data pipeline may provide custom per-shard atom slicing that + # differs from standard distribute_tensor placements. + for name, tensor in [("token_to_rep_atom", token_to_rep_atom), ("bfactor", bfactor)]: + if not isinstance(tensor, (DTensor, torch.Tensor)): + raise TypeError(f"{name} must be DTensor or Tensor, got {type(tensor)}") + + # --- Extract local tensors --- + compute_dtype = torch.promote_types(pred.dtype, torch.float32) + pred_local = pred.to_local().to(compute_dtype) # [B_local, N_local, bins] + t2ra_local = (token_to_rep_atom.to_local() if isinstance(token_to_rep_atom, DTensor) else token_to_rep_atom).to( + compute_dtype + ) + bf_local = (bfactor.to_local() if isinstance(bfactor, DTensor) else bfactor).to(compute_dtype) + + bins = pred_local.shape[2] + + # --- Construct target (non-differentiable) --- + # Map atom-level bfactors to tokens + bfactor_token = torch.bmm(t2ra_local, bf_local.unsqueeze(-1)) # [B_local, N_local, 1] + + # Bin into histogram + boundaries = torch.linspace(0, 100, bins - 1, device=pred_local.device, dtype=compute_dtype) + bfactor_token_bin = (bfactor_token > boundaries).sum(dim=-1).long() # [B_local, N_local] + bfactor_target = torch.nn.functional.one_hot(bfactor_token_bin, num_classes=bins).to( + compute_dtype + ) # [B_local, N_local, bins] + + # Token validity mask + token_mask = (bfactor_token > 1e-5).squeeze(-1).to(compute_dtype) # [B_local, N_local] + + # --- Cross-entropy loss --- + log_softmax_local = torch.nn.functional.log_softmax(pred_local, dim=-1) + softmax_local = log_softmax_local.exp() # save for backward + + errors = -(bfactor_target * log_softmax_local).sum(dim=-1) # [B_local, N_local] + masked_errors = errors * token_mask # [B_local, N_local] + + # --- Global reduction matching serial semantics --- + # Serial: loss = sum_{b,n}(errors * mask) / (sum_{b,n}(mask) + eps) + # We sum over ALL local dims, then all_reduce across dp and cp0 + # in two sequential calls. cp1 is Replicate for single-representation + # data, so cp1 ranks hold identical values — reducing over cp1 would + # double-count. + # + # Two sequential all_reduces are equivalent to a single all_reduce + # over the combined dp×cp0 group: + # sum_{dp×cp0}(x) = sum_{cp0}(sum_{dp}(x)) + packed = torch.stack([masked_errors.sum(), token_mask.sum()]) + dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=dp_group) + dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=cp0_group) + + global_denom = packed[1] + 1e-5 + global_loss_local = packed[0] / global_denom + + # Average across cp1 (Replicate) ranks to enforce identical scalar + # loss values. This relies on pred having Replicate() on the cp1 mesh + # dimension (validated above as expected_single[2]). If the Replicate + # axis were on a different mesh dimension, the group here would need + # to match that dimension instead. + if cp1_group is not None: + dist.all_reduce(global_loss_local, op=dist.ReduceOp.AVG, group=cp1_group) + + # --- Save for backward --- + if pred.requires_grad: + ctx.save_for_backward( + softmax_local, + bfactor_target, + token_mask, + global_denom.unsqueeze(0), # wrap scalar in 1D for save_for_backward + ) + ctx.device_mesh = device_mesh + ctx.pred_placements = pred.placements + ctx.pred_shape = pred.shape + ctx.pred_stride = pred.stride() + + # --- Wrap result as DTensor --- + global_loss_placements = (Replicate(), Replicate(), Replicate()) + return DTensor.from_local( + global_loss_local, + device_mesh, + global_loss_placements, + shape=(), + stride=(), + ) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx: FunctionCtx, d_global_loss: DTensor) -> tuple[DTensor | None, None, None, None, None, None, None]: + """Backward pass — entirely local, no collective communication. + + The gradient of all_reduce(SUM) is identity: each rank computes + its own local gradient contribution. + """ + if not ctx.needs_input_grad[0]: + return None, None, None, None, None, None, None + + softmax_local, bfactor_target, token_mask, (global_denom,) = ctx.saved_tensors + device_mesh = ctx.device_mesh + + d_gl = (d_global_loss.to_local() if isinstance(d_global_loss, DTensor) else d_global_loss).to( + softmax_local.dtype + ) + + # Chain rule: loss = sum_{b,n}(errors * mask) / global_denom + # errors = -sum(target * log_softmax(pred), dim=bins) + # d_pred[b,n,c] = d_loss * mask[b,n] / global_denom * (softmax[b,n,c] - target[b,n,c]) + scale = d_gl / global_denom + + d_pred_local = (softmax_local - bfactor_target) * token_mask.unsqueeze(-1) * scale + + d_pred = DTensor.from_local( + d_pred_local, + device_mesh=device_mesh, + placements=ctx.pred_placements, + shape=ctx.pred_shape, + stride=ctx.pred_stride, + ) + + return d_pred, None, None, None, None, None, None + + +def bfactor_loss( + output: dict[str, DTensor], + feats: dict[str, DTensor], + device_mesh: DeviceMesh, + dp_group: ProcessGroup, + cp0_group: ProcessGroup, + cp1_group: ProcessGroup | None = None, +) -> DTensor: + """Compute the B-factor loss using a single fused autograd.Function. + + Parameters + ---------- + output : dict[str, DTensor] + Model outputs containing: + - "pbfactor": [B, N, bins] predicted B-factor logits (DTensor). + feats : dict[str, DTensor] + Input features containing: + - "token_to_rep_atom": [B, N_tokens, max_atoms_per_shard] + token-to-atom mapping (DTensor). + - "bfactor": [B, A] per-atom B-factors (DTensor). + device_mesh : DeviceMesh + 3D device mesh (dp, cp0, cp1). + dp_group : ProcessGroup + Process group for the dp mesh dimension (dim 0). + cp0_group : ProcessGroup + Process group for the cp_axis_0 mesh dimension (dim 1). + cp1_group : ProcessGroup | None + Process group for the cp1 (Replicate) mesh dimension. + When provided, a mean all-reduce is applied to the scalar loss + to enforce identical values across cp1 ranks. + + Returns + ------- + DTensor + The globally averaged B-factor loss (scalar DTensor). + """ + with torch.autocast("cuda", enabled=False): + return _BFactorLossCP.apply( + output["pbfactor"], + feats["token_to_rep_atom"], + feats["bfactor"], + device_mesh, + dp_group, + cp0_group, + cp1_group, + ) diff --git a/src/boltz/distributed/model/loss/confidencev2.py b/src/boltz/distributed/model/loss/confidencev2.py new file mode 100644 index 000000000..7a07fe6bd --- /dev/null +++ b/src/boltz/distributed/model/loss/confidencev2.py @@ -0,0 +1,3383 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from copy import deepcopy +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor + +from boltz.data import const +from boltz.distributed.comm import One2OneComm, TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.atom_to_token import single_repr_token_to_atom +from boltz.distributed.model.layers.clip import clip +from boltz.distributed.model.layers.elementwise_op import ( + ElementwiseOp, + elementwise_op, + scalar_tensor_op, +) +from boltz.distributed.model.layers.redistribute_transpose import redistribute_transpose +from boltz.distributed.model.layers.repeat_interleave import shardwise_repeat_interleave +from boltz.distributed.model.layers.sharded_op import sharded_sum +from boltz.distributed.model.loss.triton.cdist_lddt import cdist_lddt +from boltz.distributed.model.loss.triton.cdist_pde import cdist_pde +from boltz.distributed.utils import LayoutMap, LayoutRightMap, get_group_rank_from_axial_shift +from boltz.model.layers.confidence_utils import compute_collinear_mask + + +class _ResolvedNegativeLogLikelihoodImpl(torch.autograd.Function): + """Shardwise computation of resolved negative log-likelihood. + + This implements the forward and backward passes for computing the binary + cross-entropy loss for predicting whether atoms are resolved. The computation + is shardwise (no communication required) due to the block-diagonal structure + of token_to_rep_atom with intersperse padding. + + The NLL computation follows: + ref_mask = bmm(token_to_rep_atom, resolved_mask) + log_probs = log_softmax(pred_resolved, dim=-1) + errors = -ref_mask * log_probs[:,:,0] - (1 - ref_mask) * log_probs[:,:,1] + + The backward pass uses PyTorch autograd on the local computation graph, + avoiding the need for manual gradient derivation. + + See Also + -------- + resolved_negative_log_likelihood : The public API function that calls this. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx: FunctionCtx, + pred_resolved: DTensor, + token_to_rep_atom: DTensor, + true_coords_resolved_mask: DTensor, + ) -> DTensor: + """Forward pass for computing resolved negative log-likelihood. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context object for saving tensors for backward. + pred_resolved : DTensor + Predicted resolved logits with shape (B*mult, N_token, 2). + Placements: (Shard(0), Shard(1), Replicate()) + token_to_rep_atom : DTensor + One-hot mapping from tokens to representative atoms (non-multiplexed). + Shape (B, N_token, N_atom) with block-diagonal structure. + Placements: (Shard(0), Shard(1), Replicate()) + true_coords_resolved_mask : DTensor + Resolved mask for atoms. Shape (B*mult, N_atom). + Placements: (Shard(0), Shard(1), Replicate()) + + Returns + ------- + DTensor + Error tensor with shape (B*mult, N_token). + Placements: (Shard(0), Shard(1), Replicate()) + + Raises + ------ + TypeError + If inputs are not DTensors. + ValueError + If Partial placements are present or placements don't match expected. + """ + # Type checking + if not isinstance(pred_resolved, DTensor): + raise TypeError(f"Expected DTensor for pred_resolved, got {type(pred_resolved)}") + if not isinstance(token_to_rep_atom, DTensor): + raise TypeError(f"Expected DTensor for token_to_rep_atom, got {type(token_to_rep_atom)}") + if not isinstance(true_coords_resolved_mask, DTensor): + raise TypeError(f"Expected DTensor for true_coords_resolved_mask, got {type(true_coords_resolved_mask)}") + + device_mesh = pred_resolved.device_mesh + pred_placements = pred_resolved.placements + + # Validate placements - check no Partial and expected structure + for i_dim_device_mesh, placement in enumerate(pred_placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported for pred_resolved") + elif isinstance(placement, Shard): + # Check that sharded dimensions are evenly divided + if pred_resolved.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding of tensor dimension {placement.dim} of size " + f"{pred_resolved.shape[placement.dim]} along device mesh dimension " + f"{i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + + # Validate device_mesh consistency across all inputs + if token_to_rep_atom.device_mesh != device_mesh: + raise ValueError( + f"Device mesh mismatch: pred_resolved has {device_mesh}, " + f"token_to_rep_atom has {token_to_rep_atom.device_mesh}" + ) + if true_coords_resolved_mask.device_mesh != device_mesh: + raise ValueError( + f"Device mesh mismatch: pred_resolved has {device_mesh}, " + f"true_coords_resolved_mask has {true_coords_resolved_mask.device_mesh}" + ) + + # Validate placements consistency across all inputs + # All inputs should have the same placements on the 3D mesh: (Shard(0), Shard(1), Replicate()) + if token_to_rep_atom.placements != pred_placements: + raise ValueError( + f"Placements mismatch: pred_resolved has {pred_placements}, " + f"token_to_rep_atom has {token_to_rep_atom.placements}" + ) + if true_coords_resolved_mask.placements != pred_placements: + raise ValueError( + f"Placements mismatch: pred_resolved has {pred_placements}, " + f"true_coords_resolved_mask has {true_coords_resolved_mask.placements}" + ) + + # Validate shape consistency across all inputs + # Extract dimensions from token_to_rep_atom: (B, N_token, N_atom_padded) + if len(token_to_rep_atom.shape) != 3: + raise ValueError(f"token_to_rep_atom must be 3D, got shape {token_to_rep_atom.shape}") + batch_size = token_to_rep_atom.shape[0] + n_token = token_to_rep_atom.shape[1] + + # Validate true_coords_resolved_mask shape: (B*mult, N_atom_padded) + if len(true_coords_resolved_mask.shape) != 2 or true_coords_resolved_mask.shape[0] % batch_size != 0: + raise ValueError( + f"true_coords_resolved_mask must be 2D with shape[0] divisible by batch_size ({batch_size}), " + f"got shape {tuple(true_coords_resolved_mask.shape)}" + ) + multiplicity = true_coords_resolved_mask.shape[0] // batch_size + + # Validate pred_resolved shape: (B*mult, N_token, 2) + expected_pred_shape = (batch_size * multiplicity, n_token, 2) + if tuple(pred_resolved.shape) != expected_pred_shape: + raise ValueError( + f"Shape mismatch: pred_resolved has shape {tuple(pred_resolved.shape)}, expected {expected_pred_shape}" + ) + + # Detach and set requires_grad to build a local computation graph. + # Use promote_types to match serial resolved_loss which casts via .float(); + # promote_types promotes to at least float32 while preserving float64. + compute_dtype = torch.promote_types(pred_resolved.dtype, torch.float32) + pred_local = ( + pred_resolved.to_local().detach().to(dtype=compute_dtype).requires_grad_(pred_resolved.requires_grad) + ) + token_to_rep_atom_local = token_to_rep_atom.to_local().to(dtype=compute_dtype) + resolved_mask_local = true_coords_resolved_mask.to_local().to(dtype=compute_dtype) + + # Validate n_atom consistency on local shards (both padded to max_atoms_per_shard) + if token_to_rep_atom_local.shape[-1] != resolved_mask_local.shape[-1]: + raise ValueError( + f"Local shard atom dimension mismatch: token_to_rep_atom has {token_to_rep_atom_local.shape[-1]}, " + f"true_coords_resolved_mask has {resolved_mask_local.shape[-1]}" + ) + + # Infer multiplicity from local shapes + # token_to_rep_atom_local: (B_local, N_token_local, N_atom_padded) + # resolved_mask_local: (B_local*mult, N_atom_padded) + b_local = token_to_rep_atom_local.shape[0] + multiplicity = resolved_mask_local.shape[0] // b_local + + with torch.enable_grad(): + # Build a local computation graph for the shardwise operations + # Reshape resolved_mask to (B_local, mult, N_atom_padded) for einsum + resolved_mask_reshaped = resolved_mask_local.view(b_local, multiplicity, -1) + + # Use einsum to compute ref_mask without repeat_interleave on token_to_rep_atom + # token_to_rep_atom_local: (B_local, N_token_local, N_atom_padded) -> "btj" + # resolved_mask_reshaped: (B_local, mult, N_atom_padded) -> "bmj" + # ref_mask: (B_local, mult, N_token_local) -> "bmt" + ref_mask = torch.einsum("btj,bmj->bmt", token_to_rep_atom_local, resolved_mask_reshaped) + # Flatten to (B_local*mult, N_token_local) + ref_mask = ref_mask.flatten(0, 1) + + # Compute log softmax probabilities + log_softmax_resolved = F.log_softmax(pred_local, dim=-1) + + # Compute binary cross-entropy errors + # errors = -ref_mask * log_probs[resolved] - (1 - ref_mask) * log_probs[unresolved] + errors = -ref_mask * log_softmax_resolved[:, :, 0] - (1 - ref_mask) * log_softmax_resolved[:, :, 1] + + # Compute output shape and stride (same batch and token dims, remove bins dim) + output_shape = tuple(pred_resolved.shape[:-1]) + output_stride = LayoutRightMap(output_shape).strides + + # Output placements: same as input but without the last (bins) dimension + # Since pred_resolved has (Shard(0), Shard(1), Replicate()), output is (Shard(0), Shard(1), Replicate()) + # but we need to handle the 2D case + output_placements = pred_placements + + # Save tensors for backward pass + ctx.save_for_backward(pred_local, errors) + ctx.device_mesh = device_mesh + ctx.pred_placements = pred_placements + ctx.pred_shape = pred_resolved.shape + ctx.pred_stride = pred_resolved.stride() + ctx.output_shape = output_shape + ctx.output_stride = output_stride + ctx.output_placements = output_placements + + # Create output DTensor + result = DTensor.from_local( + errors.detach(), + device_mesh=device_mesh, + placements=output_placements, + shape=output_shape, + stride=output_stride, + ) + + return result + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor | None, None, None]: + """Backward pass for computing resolved negative log-likelihood. + + Computes gradients by backpropagating through the local computation graph + that was built during the forward pass. This leverages PyTorch's autograd + rather than manual gradient computation. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context containing saved tensors from forward. + grad_output : DTensor + Gradient of loss with respect to output errors. + + Returns + ------- + tuple[DTensor | None, None, None] + Gradients for each forward input in order: + - d_pred_resolved: DTensor or None, gradient for pred_resolved + - None: token_to_rep_atom (non-differentiable, one-hot) + - None: true_coords_resolved_mask (non-differentiable, ground truth) + """ + pred_local, errors_local = ctx.saved_tensors + + if not pred_local.requires_grad: + return None, None, None + + grad_output_local = grad_output.to_local() + + # Backprop via the local graph + (d_pred_local,) = torch.autograd.grad( + outputs=[errors_local], + inputs=[pred_local], + grad_outputs=[grad_output_local], + retain_graph=False, # Frees the local graph immediately + ) + + # Wrap gradient in DTensor + d_pred = DTensor.from_local( + d_pred_local, + device_mesh=ctx.device_mesh, + placements=ctx.pred_placements, + shape=ctx.pred_shape, + stride=ctx.pred_stride, + ) + + return d_pred, None, None + + +def resolved_negative_log_likelihood( + pred_resolved: DTensor, + token_to_rep_atom: DTensor, + true_coords_resolved_mask: DTensor, +) -> DTensor: + """Compute shardwise negative log-likelihood for resolved prediction. + + This is the DTensor-compatible version of the resolved loss NLL computation. + All operations are shardwise (no inter-rank communication required) due to + the block-diagonal structure of token_to_rep_atom with intersperse padding. + + The multiplicity is inferred from the shapes: + multiplicity = true_coords_resolved_mask.shape[0] // token_to_rep_atom.shape[0] + + Parameters + ---------- + pred_resolved : DTensor + Predicted resolved logits with shape (B*mult, N_token, 2). + Placements: (Shard(0), Shard(1), Replicate()) + token_to_rep_atom : DTensor + One-hot mapping from tokens to representative atoms (non-multiplexed). + Shape (B, N_token, N_atom) with block-diagonal structure. + Placements: (Shard(0), Shard(1), Replicate()) + true_coords_resolved_mask : DTensor + Resolved mask for atoms. Shape (B*mult, N_atom). + Placements: (Shard(0), Shard(1), Replicate()) + + Returns + ------- + DTensor + Error tensor with shape (B*mult, N_token). + Placements: (Shard(0), Shard(1), Replicate()) + + See Also + -------- + resolved_loss : Full loss function that uses this for NLL computation. + boltz.model.loss.confidencev2.resolved_loss : Serial version. + """ + return _ResolvedNegativeLogLikelihoodImpl.apply(pred_resolved, token_to_rep_atom, true_coords_resolved_mask) + + +def resolved_loss( + pred_resolved: DTensor, + feats: dict[str, DTensor], + true_coords_resolved_mask: DTensor, + multiplicity: int = 1, +) -> DTensor: + """Compute resolved loss using DTensor operations. + + This is the DTensor-compatible version of the resolved_loss function. + It computes binary cross-entropy loss for predicting whether atoms are resolved + in the structure. + + The computation is split into: + 1. Part a) Shardwise NLL computation via resolved_negative_log_likelihood + 2. Part b) Distributed weighted sum along token axis, then mean over batch + + Parameters + ---------- + pred_resolved : DTensor + Predicted resolved logits with shape (B*mult, N_token, 2). + Placements: (Shard(0), Shard(1), Replicate()) + feats : dict[str, DTensor] + Feature dictionary containing: + - token_to_rep_atom: One-hot mapping (B, N_token, N_atom) + - token_pad_mask: Padding mask (B, N_token) + true_coords_resolved_mask : DTensor + Resolved mask for atoms. Shape (B*mult, N_atom). + Placements: (Shard(0), Shard(1), Replicate()) + multiplicity : int, optional + Diffusion batch multiplier, by default 1 + + Returns + ------- + DTensor + Scalar loss with placements (Replicate(), Replicate(), Replicate()) + + See Also + -------- + resolved_negative_log_likelihood : Shardwise NLL computation. + boltz.model.loss.confidencev2.resolved_loss : Serial version. + """ + # Part a) - Shardwise NLL computation (token_to_rep_atom is non-multiplexed) + errors = resolved_negative_log_likelihood(pred_resolved, feats["token_to_rep_atom"], true_coords_resolved_mask) + + # Part b) - Weighted sum along token axis (dim=-1), then mean over batch + # Expand pad_mask with multiplicity + pad_mask = shardwise_repeat_interleave(feats["token_pad_mask"], multiplicity, dim=0) + # Following diffusion.py pattern (lines 413-420) + # num = sum(errors * pad_mask, dim=-1) + num = sharded_sum(elementwise_op(errors, pad_mask, ElementwiseOp.PROD), dim=-1) + # den = sum(pad_mask, dim=-1) + den = sharded_sum(pad_mask, dim=-1) + # loss_per_sample = num / max(den, 1e-7) + loss_per_sample = elementwise_op(num, clip(den, min_val=1e-7, max_val=None), ElementwiseOp.DIV) + + # Mean over batch dimension (following diffusion.py lines 423-427) + loss = scalar_tensor_op( + 1.0 / loss_per_sample.shape[0], + sharded_sum(loss_per_sample, dim=0), + ElementwiseOp.PROD, + ) + + return loss + + +# PAE loss numerical stability constants +FRAME_NORM_EPS = 1e-5 # prevents division by zero in frame basis normalization +PAE_DIST_EPS = 1e-8 # prevents sqrt(0) in PAE target distance computation +PAE_LOSS_DENOM_EPS = 1e-7 # prevents division by zero when normalizing by mask sum + + +def _check_pae_input_consistency( + pred_pae: DTensor, + pred_atom_coords: DTensor, + true_atom_coords: DTensor, + true_coords_resolved_mask: DTensor, + feats: dict[str, DTensor], + multiplicity: int, +) -> None: + """Validate input tensors for PAE loss computation. + + Checks type, device mesh consistency, placement correctness, and shape + compatibility for all inputs to pae_loss. + + Parameters + ---------- + pred_pae : DTensor + Predicted PAE logits. Expected shape: (B*mult, N_token, N_token, bins). + Expected placements: (Shard(0), Shard(1), Shard(2)). + pred_atom_coords : DTensor + Predicted atom coordinates. Expected shape: (B*mult, N_atom, 3). + Expected placements: (Shard(0), Shard(1), Replicate()). + true_atom_coords : DTensor + True atom coordinates. Expected shape: (B*mult, N_atom, 3). + Expected placements: (Shard(0), Shard(1), Replicate()). + true_coords_resolved_mask : DTensor + Resolved mask for atoms. Expected shape: (B*mult, N_atom). + Expected placements: (Shard(0), Shard(1), Replicate()). + feats : dict[str, DTensor] + Feature dictionary containing frames_idx, frame_resolved_mask, token_pad_mask, etc. + multiplicity : int + Diffusion batch multiplier. + + Raises + ------ + TypeError + If any input is not a DTensor. + ValueError + If device meshes are inconsistent, placements are incorrect, shapes don't match, + or sharding is uneven. + """ + # --- Type checks --- + if not isinstance(pred_pae, DTensor): + raise TypeError(f"Expected DTensor for pred_pae, got {type(pred_pae)}") + if not isinstance(pred_atom_coords, DTensor): + raise TypeError(f"Expected DTensor for pred_atom_coords, got {type(pred_atom_coords)}") + if not isinstance(true_atom_coords, DTensor): + raise TypeError(f"Expected DTensor for true_atom_coords, got {type(true_atom_coords)}") + if not isinstance(true_coords_resolved_mask, DTensor): + raise TypeError(f"Expected DTensor for true_coords_resolved_mask, got {type(true_coords_resolved_mask)}") + + # --- Device mesh consistency --- + device_mesh = pred_pae.device_mesh + if pred_atom_coords.device_mesh != device_mesh: + raise ValueError( + f"Device mesh mismatch: pred_pae has {device_mesh}, pred_atom_coords has {pred_atom_coords.device_mesh}" + ) + if true_atom_coords.device_mesh != device_mesh: + raise ValueError( + f"Device mesh mismatch: pred_pae has {device_mesh}, true_atom_coords has {true_atom_coords.device_mesh}" + ) + if true_coords_resolved_mask.device_mesh != device_mesh: + raise ValueError( + f"Device mesh mismatch: pred_pae has {device_mesh}, " + f"true_coords_resolved_mask has {true_coords_resolved_mask.device_mesh}" + ) + + # --- Placement validation for pred_pae: (Shard(0), Shard(1), Shard(2)) --- + expected_pae_placements = (Shard(0), Shard(1), Shard(2)) + if pred_pae.placements != expected_pae_placements: + raise ValueError(f"pred_pae must have placements {expected_pae_placements}, got {pred_pae.placements}") + + # Check sharding divisibility for pred_pae + # Shard(0) -> dim 0 (B*mult) sharded by mesh dim 0 (dp) + # Shard(1) -> dim 1 (N_token) sharded by mesh dim 1 (cp_axis_0) + # Shard(2) -> dim 2 (N_token) sharded by mesh dim 2 (cp_axis_1) + for mesh_dim, placement in enumerate(pred_pae.placements): + if isinstance(placement, Shard): + tensor_dim_size = pred_pae.shape[placement.dim] + mesh_dim_size = device_mesh.shape[mesh_dim] + if tensor_dim_size % mesh_dim_size != 0: + raise ValueError( + f"pred_pae dimension {placement.dim} (size {tensor_dim_size}) " + f"is not evenly divisible by mesh dimension {mesh_dim} (size {mesh_dim_size})" + ) + + # --- Placement validation for coords: (Shard(0), Shard(1), Replicate()) --- + expected_coords_placements = (Shard(0), Shard(1), Replicate()) + if pred_atom_coords.placements != expected_coords_placements: + raise ValueError( + f"pred_atom_coords must have placements {expected_coords_placements}, got {pred_atom_coords.placements}" + ) + if true_atom_coords.placements != expected_coords_placements: + raise ValueError( + f"true_atom_coords must have placements {expected_coords_placements}, got {true_atom_coords.placements}" + ) + + # Check sharding divisibility for coords + for mesh_dim, placement in enumerate(pred_atom_coords.placements): + if isinstance(placement, Shard): + tensor_dim_size = pred_atom_coords.shape[placement.dim] + mesh_dim_size = device_mesh.shape[mesh_dim] + if tensor_dim_size % mesh_dim_size != 0: + raise ValueError( + f"pred_atom_coords dimension {placement.dim} (size {tensor_dim_size}) " + f"is not evenly divisible by mesh dimension {mesh_dim} (size {mesh_dim_size})" + ) + + # --- Shape validation --- + # pred_pae: (B*mult, N_token, N_token, bins) + if pred_pae.ndim != 4: + raise ValueError(f"pred_pae must be 4D, got {pred_pae.ndim}D with shape {pred_pae.shape}") + + batch_mult_size = pred_pae.shape[0] + + if pred_pae.shape[1] != pred_pae.shape[2]: + raise ValueError(f"pred_pae must have equal N_token dimensions (dims 1 and 2), got shape {pred_pae.shape}") + + if batch_mult_size % multiplicity != 0: + raise ValueError( + f"pred_pae batch dimension (shape[0]={batch_mult_size}) must be divisible by multiplicity ({multiplicity})" + ) + + # pred_atom_coords / true_atom_coords: (B*mult, N_atom, 3) + if pred_atom_coords.ndim != 3: + raise ValueError( + f"pred_atom_coords must be 3D, got {pred_atom_coords.ndim}D with shape {pred_atom_coords.shape}" + ) + if pred_atom_coords.shape[0] != batch_mult_size: + raise ValueError( + f"pred_atom_coords batch dimension (shape[0]={pred_atom_coords.shape[0]}) " + f"does not match pred_pae batch dimension ({batch_mult_size})" + ) + if pred_atom_coords.shape[2] != 3: + raise ValueError(f"pred_atom_coords must have 3 coordinates, got shape {pred_atom_coords.shape}") + + if true_atom_coords.shape != pred_atom_coords.shape: + raise ValueError( + f"true_atom_coords shape {true_atom_coords.shape} " + f"does not match pred_atom_coords shape {pred_atom_coords.shape}" + ) + + # true_coords_resolved_mask: (B*mult, N_atom) + if true_coords_resolved_mask.ndim != 2: + raise ValueError( + f"true_coords_resolved_mask must be 2D, got {true_coords_resolved_mask.ndim}D " + f"with shape {true_coords_resolved_mask.shape}" + ) + if true_coords_resolved_mask.shape[0] != batch_mult_size: + raise ValueError( + f"true_coords_resolved_mask batch dimension (shape[0]={true_coords_resolved_mask.shape[0]}) " + f"does not match pred_pae batch dimension ({batch_mult_size})" + ) + + n_atom = pred_atom_coords.shape[1] + if true_coords_resolved_mask.shape[1] != n_atom: + raise ValueError( + f"true_coords_resolved_mask N_atom dimension ({true_coords_resolved_mask.shape[1]}) " + f"does not match pred_atom_coords ({n_atom})" + ) + + if "frames_idx" in feats and feats["frames_idx"].ndim == 4: + raise ValueError( + f"frames_idx has unsqueezed ensemble dim (ndim=4, shape={feats['frames_idx'].shape}). " + "Only E=1 is supported; squeeze the ensemble dim before calling pae_loss." + ) + if "frame_resolved_mask" in feats and feats["frame_resolved_mask"].ndim == 3: + raise ValueError( + f"frame_resolved_mask has unsqueezed ensemble dim (ndim=3, shape={feats['frame_resolved_mask'].shape}). " + "Only E=1 is supported; squeeze the ensemble dim before calling pae_loss." + ) + + +def pae_loss( + pred_pae: DTensor, + pred_atom_coords: DTensor, + true_atom_coords: DTensor, + true_coords_resolved_mask: DTensor, + feats: dict[str, DTensor], + comm: One2OneComm, + dist_manager: DistributedManager, + group_layout: LayoutMap, + multiplicity: int = 1, + max_dist: float = 32.0, +) -> DTensor: + """Compute PAE (Predicted Aligned Error) loss using DTensor operations. + + Tensor-compatible version of the pae_loss function. + It computes cross-entropy loss for predicting the alignment error between + predicted and true atom coordinates when expressed in local reference frames. + + Sharding Strategy + ----------------- + The 3D device mesh has shape (dp, cp_axis_0, cp_axis_1). Placements specify + which tensor dimension to shard across each mesh dimension: + + :: + + pred_pae shape: (B*mult, N_token, N_token, bins) + Placements: (Shard(0), Shard(1), Shard(2)) + │ │ │ + ▼ ▼ ▼ + Mesh dims: dp cp_axis_0 cp_axis_1 + + Example with B=2, mult=2, N=32, bins=64 on dp=2, cp=(2,2): + + Global pred_pae: (4, 32, 32, 64) + + Device Mesh (cp_axis_0 × cp_axis_1): + cp_axis_1 + ┌──────┬──────┐ + cp_axis_0 │ R0 │ R1 │ R0: (2, 16, 16, 64) tokens[0:16, 0:16] + ├──────┼──────┤ R1: (2, 16, 16, 64) tokens[0:16, 16:32] + │ R2 │ R3 │ R2: (2, 16, 16, 64) tokens[16:32, 0:16] + └──────┴──────┘ R3: (2, 16, 16, 64) tokens[16:32, 16:32] + + Parameters + ---------- + pred_pae : DTensor + Predicted PAE logits with shape (B*mult, N_token, N_token, num_bins). + Placements: (Shard(0), Shard(1), Shard(2)) + pred_atom_coords : DTensor + Predicted atom coordinates with shape (B*mult, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()) + true_atom_coords : DTensor + True atom coordinates with shape (B*mult, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()) + true_coords_resolved_mask : DTensor + Resolved mask for atoms. Shape (B*mult, N_atom). + Placements: (Shard(0), Shard(1), Replicate()) + feats : dict[str, DTensor] + Feature dictionary containing: + + - frames_idx: Frame atom indices (B, N_token, 3) + Placements: (Shard(0), Shard(1), Replicate()) + - frame_resolved_mask: Frame validity mask (B, N_token) + Placements: (Shard(0), Shard(1), Replicate()) + - asym_id: Asymmetric unit IDs (B, N_token) + Placements: (Shard(0), Shard(1), Replicate()) + - atom_to_token: Atom to token mapping (B, N_atom, N_token) + Placements: (Shard(0), Shard(1), Replicate()) with intersperse padding + - atom_pad_mask: Atom padding mask (B, N_atom) + Placements: (Shard(0), Shard(1), Replicate()) with intersperse padding + - mol_type: Molecule type (B, N_token) + Placements: (Shard(0), Shard(1), Replicate()) + - token_pad_mask: Token padding mask (B, N_token) + Placements: (Shard(0), Shard(1), Replicate()) + - atom_resolved_mask: Atom resolved mask (B, N_atom) + Placements: (Shard(0), Shard(1), Replicate()) with intersperse padding + - is_nonpolymer_with_frame: Non-polymer frame indicator (B, N_token) + Placements: (Shard(0), Shard(1), Replicate()) + + comm : One2OneComm + Communication object for coordinate transpose operations in frame + computation for non-polymers. + dist_manager : DistributedManager + Distributed manager for process group information. + group_layout : LayoutMap + Layout map for the 2D CP grid. + multiplicity : int, optional + Diffusion batch multiplier, by default 1 + max_dist : float, optional + Maximum distance for PAE binning, by default 32.0 + + Returns + ------- + DTensor + Scalar loss with placements (Replicate(), Replicate(), Replicate()) + + See Also + -------- + compute_frame_pred : Distributed frame computation for non-polymers. + boltz.model.loss.confidencev2.pae_loss : Serial version. + """ + _check_pae_input_consistency( + pred_pae, + pred_atom_coords, + true_atom_coords, + true_coords_resolved_mask, + feats, + multiplicity, + ) + + return _PAELossImpl.apply( + pred_pae, + pred_atom_coords, + true_atom_coords, + true_coords_resolved_mask, + feats, + comm, + dist_manager, + group_layout, + multiplicity, + max_dist, + ) + + +def _express_coordinate_in_frame_distributed( + atom_coords: Tensor, + frame_atom_a: Tensor, + frame_atom_b: Tensor, + frame_atom_c: Tensor, + dist_manager: DistributedManager, + group_layout: LayoutMap, + transpose_comm: TransposeComm, +) -> Tensor: + """Distributed express_coordinate_in_frame for 2D-sharded output. + + For rank (shard_i, shard_j) in the 2D CP mesh, computes the block + [i_start:i_end, j_start:j_end] of the full N_token × N_token output. + + Args: + atom_coords: [B, mult, N_atom_local, 3] local atom coordinates + frame_atom_a/b/c: [B, mult, N_token_local] local frame indices (already + converted to local atom indices) + dist_manager: DistributedManager for communication + group_layout: LayoutMap for the CP 2D mesh + + Returns: + x_transformed: [B, mult, N_token_local_i, N_token_local_j, 3] transformed coordinates + mask_collinear: [B, mult, N_token_local_i] collinear mask for row tokens + """ + n_atoms_local = atom_coords.shape[2] + n_tokens_local = frame_atom_a.shape[-1] + + cp_group = dist_manager.group["cp"] + rank_coords = group_layout.unravel(dist.get_rank(cp_group)) + atom_offset = rank_coords[0] * n_atoms_local + global_min = atom_offset + global_max = atom_offset + n_atoms_local + + # Identify if any rank will rely on nonlocal frame coordinates. If so, all ranks will use the global gather path + # to avoid deadlock in cooperative operations. + frame_requires_global_coords = torch.any( + (frame_atom_a < global_min) + | (frame_atom_a >= global_max) + | (frame_atom_b < global_min) + | (frame_atom_b >= global_max) + | (frame_atom_c < global_min) + | (frame_atom_c >= global_max) + ).to(torch.int32) + dist.all_reduce(frame_requires_global_coords, op=dist.ReduceOp.MAX, group=cp_group) + + if frame_requires_global_coords.item() == 1: + a = _gather_frame_coords( + atom_coords, + frame_atom_a, + local_only=False, + atom_offset=atom_offset, + n_tokens_local=n_tokens_local, + dist_manager=dist_manager, + group_layout=group_layout, + ) + b = _gather_frame_coords( + atom_coords, + frame_atom_b, + local_only=False, + atom_offset=atom_offset, + n_tokens_local=n_tokens_local, + dist_manager=dist_manager, + group_layout=group_layout, + ) + c = _gather_frame_coords( + atom_coords, + frame_atom_c, + local_only=False, + atom_offset=atom_offset, + n_tokens_local=n_tokens_local, + dist_manager=dist_manager, + group_layout=group_layout, + ) + else: + a = _gather_frame_coords( + atom_coords, + frame_atom_a, + local_only=True, + atom_offset=atom_offset, + n_tokens_local=n_tokens_local, + dist_manager=dist_manager, + group_layout=group_layout, + ) + b = _gather_frame_coords( + atom_coords, + frame_atom_b, + local_only=True, + atom_offset=atom_offset, + n_tokens_local=n_tokens_local, + dist_manager=dist_manager, + group_layout=group_layout, + ) + c = _gather_frame_coords( + atom_coords, + frame_atom_c, + local_only=True, + atom_offset=atom_offset, + n_tokens_local=n_tokens_local, + dist_manager=dist_manager, + group_layout=group_layout, + ) + + # Exchange b coordinates with transpose peer early to overlap with local frame basis work. + b_j = transpose_comm.enqueue_to_dispatch(b.contiguous()) + + # Build orthonormal frame from local a, b, c + # a, b, c: [B, mult, N_token_local, 3] + ab = a - b + cb = c - b + w1 = ab / (torch.norm(ab, dim=-1, keepdim=True) + FRAME_NORM_EPS) + w2 = cb / (torch.norm(cb, dim=-1, keepdim=True) + FRAME_NORM_EPS) + e1 = (w1 + w2) / (torch.norm(w1 + w2, dim=-1, keepdim=True) + FRAME_NORM_EPS) + e2 = (w2 - w1) / (torch.norm(w2 - w1, dim=-1, keepdim=True) + FRAME_NORM_EPS) + e3 = torch.linalg.cross(e1, e2) + + # Collinear mask from correctly gathered frame coordinates + # Flatten (B, mult, N_token_local) to (...) for compute_collinear_mask + orig_shape = ab.shape[:-1] # (B, mult, N_token_local) + mask_collinear = compute_collinear_mask( + ab.reshape(-1, 3), + cb.reshape(-1, 3), + ).reshape(orig_shape) + + # Ensure transpose exchange completed before using b_j. + transpose_comm.wait_until_finished() + + # Pairwise displacement: d[i,j] = b_j[j] - b[i] + # b: [B, mult, N_token_local_i, 3] (local row tokens) + # b_j: [B, mult, N_token_local_j, 3] (gathered column tokens) + d = b_j[:, :, None, :, :] - b[:, :, :, None, :] # [B, mult, N_i, N_j, 3] + + # Project onto local frame basis via batched matmul + basis = torch.stack([e1, e2, e3], dim=-1) # [B, mult, N_i, 3, 3] + x_transformed = torch.matmul(d.unsqueeze(-2), basis[:, :, :, None, :, :]).squeeze(-2) + + return x_transformed, mask_collinear + + +class _PAELossImpl(torch.autograd.Function): + """Distributed PAE loss computation with autograd support. + + Handles the forward pass with local autograd for gradient computation, + and the backward pass with proper gradient routing for DTensors. + + The gradient w.r.t. pred_pae flows through log_softmax → gather → errors. + The gradient w.r.t. pred_atom_coords is zero because bin_index computation + (torch.floor) is non-differentiable. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + pred_pae: DTensor, + pred_atom_coords: DTensor, + true_atom_coords: DTensor, + true_coords_resolved_mask: DTensor, + feats: dict[str, DTensor], + comm: One2OneComm, + dist_manager: DistributedManager, + group_layout: LayoutMap, + multiplicity: int, + max_dist: float, + ) -> DTensor: + device_mesh = pred_pae.device_mesh + num_bins = pred_pae.shape[-1] + dp_size = device_mesh.shape[0] + + # Extract local tensors from DTensors + # pred_pae is (B*mult, N_token, N_token, bins); unflatten to (B_batch, mult, N_i, N_j, bins) + pred_pae_local = pred_pae.to_local().unflatten(0, (-1, multiplicity)) + pred_atom_coords_local = pred_atom_coords.to_local() + true_atom_coords_local = true_atom_coords.to_local() + true_coords_resolved_mask_local = true_coords_resolved_mask.to_local() + + frame_resolved_mask_local = feats["frame_resolved_mask"].to_local() + token_pad_mask_local = feats["token_pad_mask"].to_local() + + # Get rank info for 2D sharding + B_local, N_atom_local, _ = true_atom_coords_local.shape + group_axis_0 = dist_manager.subgroups["cp"][0] + group_rank_0 = dist.get_rank(group_axis_0) + atom_offset = N_atom_local * group_rank_0 + if B_local % multiplicity != 0: + raise ValueError( + f"true_atom_coords local batch dim ({B_local}) must be " f"divisible by multiplicity ({multiplicity})" + ) + B_batch = B_local // multiplicity + if pred_atom_coords_local.shape[0] != true_atom_coords_local.shape[0]: + raise ValueError( + f"pred_atom_coords batch dim ({pred_atom_coords_local.shape[0]}) " + f"!= true_atom_coords batch dim ({true_atom_coords_local.shape[0]})" + ) + if true_coords_resolved_mask_local.shape != true_atom_coords_local.shape[:2]: + raise ValueError( + f"true_coords_resolved_mask shape {tuple(true_coords_resolved_mask_local.shape)} " + f"must match true_atom_coords[:2] {tuple(true_atom_coords_local.shape[:2])}" + ) + + # --- Step 1: Compute target values and masks --- + with torch.no_grad(): + mask_frame_true = frame_resolved_mask_local + + # Compute frames for true coords (use DTensor wrapper for serial-consistent results) + frames_idx_true_dt, _ = compute_frame_pred( + true_atom_coords, + feats["frames_idx"], + feats, + multiplicity, + resolved_mask=true_coords_resolved_mask, + ) + frames_idx_true = frames_idx_true_dt.to_local() + + true_atom_coords_reshaped = true_atom_coords_local.reshape(B_batch, multiplicity, -1, 3) + transpose_comm = TransposeComm(dist_manager.group["cp"], group_layout) + true_coords_transformed, mask_collinear_true = _express_coordinate_in_frame_distributed( + true_atom_coords_reshaped, + frames_idx_true[:, :, :, 0], + frames_idx_true[:, :, :, 1], + frames_idx_true[:, :, :, 2], + dist_manager, + group_layout, + transpose_comm=transpose_comm, + ) + mask_collinear_true = mask_collinear_true * token_pad_mask_local[:, None, :] + + # Compute frames for pred coords (use DTensor wrapper for serial-consistent results) + frames_idx_pred_dt, _ = compute_frame_pred( + pred_atom_coords, + feats["frames_idx"], + feats, + multiplicity, + ) + frames_idx_pred = frames_idx_pred_dt.to_local() + + pred_atom_coords_reshaped = pred_atom_coords_local.reshape(B_batch, multiplicity, -1, 3) + pred_coords_transformed, mask_collinear_pred = _express_coordinate_in_frame_distributed( + pred_atom_coords_reshaped, + frames_idx_pred[:, :, :, 0], + frames_idx_pred[:, :, :, 1], + frames_idx_pred[:, :, :, 2], + dist_manager, + group_layout, + transpose_comm=transpose_comm, + ) + mask_collinear_pred = mask_collinear_pred * token_pad_mask_local[:, None, :] + + # Compute target PAE distances + target_pae = torch.sqrt(((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + PAE_DIST_EPS) + + # Compute bin indices for cross-entropy + bin_index = torch.clamp(torch.floor(target_pae * num_bins / max_dist).long(), max=(num_bins - 1)) + + # Build pair mask: gather resolved status for all 3 frame atoms. + # Each diffusion sample has its own resolved mask (symmetry_correction + # can produce different masks per sample), so preserve the per-sample + # variation rather than collapsing to sample 0. + resolved_reshaped = true_coords_resolved_mask_local.reshape(B_batch, multiplicity, -1) + token_pad_mask_bool = token_pad_mask_local[:, None, :].bool() + N_token_local = frames_idx_true.shape[-2] + + frames_masked = frames_idx_true.masked_fill(~token_pad_mask_bool.unsqueeze(-1), atom_offset) + requires_global_gather = torch.any( + (frames_masked < atom_offset) | (frames_masked >= atom_offset + N_atom_local) + ).to(dtype=torch.int32) + if N_token_local != N_atom_local: + requires_global_gather = torch.ones_like(requires_global_gather) + dist.all_reduce(requires_global_gather, op=dist.ReduceOp.MAX, group=dist_manager.group["cp"]) + + if requires_global_gather.item() == 1: + resolved_flat = resolved_reshaped.reshape(B_batch * multiplicity, N_atom_local, 1) + index_flat = frames_idx_true.reshape(B_batch * multiplicity, N_token_local, 3) + gathered = ring_gather_coordinate(resolved_flat, index_flat, dist_manager, group_layout) + # gathered: (B*mult, N_token_local, 1, 3) → squeeze → (B_batch, mult, N_token_local, 3) + frame_resolved_abc = gathered.squeeze(-2).reshape(B_batch, multiplicity, N_token_local, 3) + else: + frames_local = frames_idx_true - atom_offset + frame_resolved_abc = torch.stack( + [torch.gather(resolved_reshaped, dim=2, index=frames_local[:, :, :, k]) for k in range(3)], + dim=-1, + ) + + b_true_resolved_mask_local = frame_resolved_abc[:, :, :, 1] + + # Exchange masks with transpose peer for column (j) dimension + b_true_resolved_mask_j = transpose_comm.enqueue_to_dispatch(b_true_resolved_mask_local.contiguous()) + transpose_comm.wait_until_finished() + token_pad_mask_j = transpose_comm.enqueue_to_dispatch(token_pad_mask_local.contiguous()) + transpose_comm.wait_until_finished() + + pair_mask = ( + mask_frame_true[:, None, :, None] + * mask_collinear_true[:, :, :, None] + * mask_collinear_pred[:, :, :, None] + * b_true_resolved_mask_j[:, :, None, :] + * token_pad_mask_local[:, None, :, None] + * token_pad_mask_j[:, None, None, :] + ) + + # Compute local mask sum and reduce across CP ranks. + sum_mask_local = pair_mask.sum(dim=(-2, -1)) + sum_mask_global_cp = sum_mask_local.clone() + # Reduce across CP ranks first + dist.all_reduce(sum_mask_global_cp, op=dist.ReduceOp.SUM, group=dist_manager.group["cp"]) + + # --- Step 2: Compute loss with gradients using global normalization --- + with torch.enable_grad(): + pred_pae_local_grad = pred_pae_local.detach().requires_grad_(pred_pae.requires_grad) + log_softmax_pae = F.log_softmax(pred_pae_local_grad, dim=-1) + target_log_prob = torch.gather(log_softmax_pae, dim=-1, index=bin_index.unsqueeze(-1)).squeeze(-1) + errors = -target_log_prob + + # Local sum of masked errors + masked_errors = errors * pair_mask + sum_errors_local = masked_errors.sum(dim=(-2, -1)) + + # Normalize by CP-global mask sum per sample for gradient computation. + loss_per_sample = sum_errors_local / (sum_mask_global_cp + PAE_LOSS_DENOM_EPS) + loss_local = loss_per_sample.mean() / dp_size + + # --- Step 3: Compute global loss--- + # Strategy: Reduce errors across CP, normalize per sample, then average across DP ranks. + # + # This matches the serial definition: + # loss = mean_bm(sum_ij(errors) / sum_ij(mask)) + # + # Reduce errors across CP ranks + sum_errors_global_cp = sum_errors_local.detach().clone() + dist.all_reduce(sum_errors_global_cp, op=dist.ReduceOp.SUM, group=dist_manager.group["cp"]) + + # Normalize per sample, then average across local samples + loss_per_sample_final = sum_errors_global_cp / (sum_mask_global_cp + PAE_LOSS_DENOM_EPS) + loss_scalar = loss_per_sample_final.mean() + + # Average across DP ranks to match global batch mean + dist.all_reduce(loss_scalar, op=dist.ReduceOp.SUM, group=dist_manager.group["dp"]) + loss_scalar = loss_scalar / dp_size + + # Save for backward + ctx.save_for_backward(pred_pae_local_grad, loss_local) + ctx.device_mesh = device_mesh + ctx.pred_pae_shape = pred_pae.shape + ctx.pred_pae_stride = pred_pae.stride() + ctx.pred_pae_placements = pred_pae.placements + # Create replicated DTensor for loss + loss_dtensor = DTensor.from_local( + loss_scalar, + device_mesh=device_mesh, + placements=(Replicate(), Replicate(), Replicate()), + shape=loss_scalar.shape, + stride=loss_scalar.stride(), + ) + + return loss_dtensor + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_loss: DTensor): + pred_pae_local_grad, loss_local = ctx.saved_tensors + if not pred_pae_local_grad.requires_grad: + return ( + None, # pred_pae + None, # pred_atom_coords (non-differentiable through floor) + None, # true_atom_coords + None, # true_coords_resolved_mask + None, # feats + None, # comm + None, # dist_manager + None, # group_layout + None, # multiplicity + None, # max_dist + ) + + grad_loss_scalar = grad_loss.to_local() + (grad_pred_pae_local,) = torch.autograd.grad( + outputs=[loss_local], + inputs=[pred_pae_local_grad], + grad_outputs=[grad_loss_scalar], + retain_graph=False, + ) + + # Flatten (B_batch, mult, ...) back to (B_batch*mult, ...) to match DTensor layout + grad_pred_pae_local = grad_pred_pae_local.flatten(0, 1) + + # Create DTensor gradient for pred_pae + grad_pred_pae = DTensor.from_local( + grad_pred_pae_local, + device_mesh=ctx.device_mesh, + placements=ctx.pred_pae_placements, + shape=ctx.pred_pae_shape, + stride=ctx.pred_pae_stride, + ) + + return ( + grad_pred_pae, # pred_pae + None, # pred_atom_coords (non-differentiable through floor) + None, # true_atom_coords + None, # true_coords_resolved_mask + None, # feats + None, # comm + None, # dist_manager + None, # group_layout + None, # multiplicity + None, # max_dist + ) + + +def all_reduce_dist_mat_argmin( + dist_mat: Tensor, + group_reduce_dist_mat_row: dist.ProcessGroup, + order: int = 3, + inplace: bool = False, +) -> Tensor: + """Perform distributed argmin operation along given dimension up to the k-th smallest element. + + This function expects dist_mat to be sharded across the process group, such that: + device mesh = [[ 0, 1], + [ 2, 3]] + dist_mat = [[ d00, d01], + [ d10, d11]] + + The returned global_argmin will use a different sharding strategy, where it is sharded column-wise and replicated row-wise. + global_argmin = [[ a0, a0], + [ a1, a1]] + + Args: + dist_mat (torch.Tensor): Local shard of distance matrix to perform argmin operation on. Shape = (..., n, n). + group_reduce_dist_mat_row (dist.ProcessGroup): cp axis 0 process group to reduce the argmin operation to. + order (int): Number of the smallest elements to find iteratively. Default is 3 to find the closest 3 atoms to construct a frame. + inplace (bool): If True, perform the operation on dist_mat in place. Default is False. + Returns: + torch.Tensor: global argmin indices of the first n smallest elements in shape = (..., order) + """ + if order > dist_mat.shape[-1] * group_reduce_dist_mat_row.size(): + raise ValueError( + "order must be less than or equal to the number of elements in the group but got order = {}, dist_mat.shape[dim] = {}, group_reduce_dist_mat_row.size() = {}".format( + order, dist_mat.shape[-1], group_reduce_dist_mat_row.size() + ) + ) + if dist_mat.shape[-2] != dist_mat.shape[-1]: + raise ValueError("distance matrix must be square but got shape = {}".format(dist_mat.shape)) + + if not inplace: # save memory through inplace operation + dist_mat = dist_mat.clone() + + out_tensor = [] + for _ in range(order): + # find local min and argmin + local_min, local_argmin = torch.min(dist_mat, dim=-1) # shape = (B, N) + + # reduce to global min and check if it is the same as the local min + global_min = local_min + dist.all_reduce(global_min, op=dist.ReduceOp.MIN, group=group_reduce_dist_mat_row) + + # locate global min locally + is_global_min = global_min.unsqueeze(-1) == dist_mat # shape = (B, N, N) + has_global_min = is_global_min.any(dim=-1) # shape = (B, N) + + # offset local argmin by group rank to get global argmin + global_argmin = local_argmin.clone() + dist.get_rank(group_reduce_dist_mat_row) * dist_mat.shape[-1] + + # mask non-global-min argmin from broadcasting + global_argmin[~has_global_min] = 0 + + # broadcast the output tensor to the process group + dist.all_reduce(global_argmin, op=dist.ReduceOp.SUM, group=group_reduce_dist_mat_row) # shape = (B, N) + + # aggregate indices + out_tensor.append(global_argmin) + + # mask the global min with inf for the next order + dist_mat[is_global_min] = torch.inf + + out_tensor = torch.stack(out_tensor, dim=-1) + return out_tensor + + +def ring_dist_mat_argmin( + dist_mat: Tensor, + group_cp: dist.ProcessGroup, + group_reduce_dist_mat_row: dist.ProcessGroup, + group_layout: LayoutMap, + order: int = 3, +) -> Tensor: + """Perform distributed argmin operation up to the k-th smallest element. The distance matrix is assumed to be square and symmetric. + + This function expects dist_mat to be sharded across the process group, such that: + device mesh = [[ 0, 1], + [ 2, 3]] + dist_mat = [[ d00, d01], + [ d10, d11]] + + The returned global_argmin will use a different sharding strategy, where it is sharded column-wise and replicated row-wise. + global_argmin = [[ a0, a0], + [ a1, a1]] + + Args: + dist_mat (torch.Tensor): Local shard of distance matrix to perform argmin operation on. Shape = (..., n, n). + group_cp (dist.ProcessGroup): Context parallelism process group + group_reduce_dist_mat_row (dist.ProcessGroup): Row-wise reduction process group + group_layout (LayoutMap): Layout map of the process group + order (int): Number of the smallest elements to find iteratively. Default is 3 to find the closest 3 atoms to construct a frame. + Returns: + torch.Tensor: global argmin indices of the first k smallest elements in shape = (..., n, k) + """ + cp_axis_0_group = group_reduce_dist_mat_row + if order > dist_mat.shape[-1] * cp_axis_0_group.size(): + raise ValueError( + "order must be less than or equal to the number of elements in the group but got order = {}, dist_mat.shape[dim] = {}, cp_axis_0_group.size() = {}".format( + order, dist_mat.shape[-1], cp_axis_0_group.size() + ) + ) + + if dist_mat.shape[-2] != dist_mat.shape[-1]: + raise ValueError("distance matrix must be square but got shape = {}".format(dist_mat.shape)) + + # setup for communication + rank_coords = group_layout.unravel(dist.get_rank(group_cp)) + topk_comm = One2OneComm( + group_cp, + rank_send_to=get_group_rank_from_axial_shift(rank_coords, 1, -1, group_layout), + rank_recv_from=get_group_rank_from_axial_shift(rank_coords, 1, 1, group_layout), + ) + topk_idx_comm = deepcopy(topk_comm) + + # find local topk and topk_idx + max_order = min(order, dist_mat.shape[-1]) + local_topk, local_topk_idx = torch.topk(dist_mat, k=max_order, dim=-1, largest=False) # shape = (..., n, k) + + # offset local topk_idx by group rank to get global topk_idx + global_topk = local_topk + global_topk_idx = local_topk_idx + dist.get_rank(group_reduce_dist_mat_row) * dist_mat.shape[-1] + + # send out to the next rank for the first time + current_topk = topk_comm.enqueue_to_dispatch(global_topk) + current_topk_idx = topk_idx_comm.enqueue_to_dispatch(global_topk_idx) + topk_comm.wait_until_finished() + topk_idx_comm.wait_until_finished() + + for step in range(cp_axis_0_group.size() - 1): + # overlap communication with computation by sending out to the next rank + if step != cp_axis_0_group.size() - 2: + next_topk = topk_comm.enqueue_to_dispatch(current_topk) + next_topk_idx = topk_idx_comm.enqueue_to_dispatch(current_topk_idx) + + # concatenate the received topk and topk_idx + global_topk = torch.cat([global_topk, current_topk], dim=-1) # shape = (..., n, 2k) + global_topk_idx = torch.cat([global_topk_idx, current_topk_idx], dim=-1) # shape = (..., n, 2k) + + # find the largest k values in global_topk and the corresponding indices in global_topk_idx + max_order = min(order, global_topk.shape[-1]) + topk_values, topk_indices = torch.topk(global_topk, k=max_order, dim=-1, largest=False) # shape = (..., n, k) + + # select the topk_indices by value from global_topk_idx + global_topk = topk_values + global_topk_idx = global_topk_idx.gather(dim=-1, index=topk_indices) # shape = (..., n, k) + + # receive from the previous rank + if step != cp_axis_0_group.size() - 2: + topk_comm.wait_until_finished() + topk_idx_comm.wait_until_finished() + current_topk = next_topk + current_topk_idx = next_topk_idx + + return global_topk_idx + + +def ring_gather_coordinate( + coordinate: Tensor, + global_argmin: Tensor, + dist_manager: DistributedManager, + group_layout: LayoutMap, +) -> Tensor: + """Distributed version of torch.gather in ring topology for coordinate gathering. + + example sharding strategy on N_atoms=2, world_size=4 + device mesh = [[ 0, 1], + [ 2, 3]] + + the gather operation is aggregated through a ring topology by rolling up the coordinates while offsetting the global_argmin by the axial shift. + + step = 0 + coordinate = [[ c0, c0], global_argmin = [[ g0, g0], + [ c1, c1]] [ g1, g1]] + + step = 1 (roll up) + coordinate = [[ c1, c1], global_argmin = [[ g0, g0], + [ c0, c0]] [ g1, g1]] + + outputs follow same sharding: + gathered_coords = [[ c0, c0 ], + [ c1, c1 ]] + + Args: + coordinate (torch.Tensor): Sharded coordinates of shape = (B, n_atoms_local, D) to be gathered from. + global_argmin (torch.Tensor): Sharded global argmin atom indices of shape = (B, n_tokens_local, order). + dist_manager (DistributedManager): Distributed manager + group_layout (LayoutMap): Layout map of the process group + + Returns: + torch.Tensor: gathered coordinates; shape = (B, n_tokens_local, D, order) + """ + ring_size = group_layout.shape[0] + bs, n_atoms_local, coord_dim = coordinate.shape + _, n_tokens_local, order = global_argmin.shape + + batch_idx = ( + torch.arange(bs, device=coordinate.device).view(bs, 1, 1).expand(-1, n_tokens_local, order) + ) # shape = (B, n_tokens_local, order) + + rank_coords = group_layout.unravel(dist.get_rank(dist_manager.group["cp"])) + comm = One2OneComm( + dist_manager.group["cp"], + rank_send_to=get_group_rank_from_axial_shift(rank_coords, 0, -1, group_layout), + rank_recv_from=get_group_rank_from_axial_shift(rank_coords, 0, 1, group_layout), + ) + + gathered_coords = torch.zeros( + bs, n_tokens_local, order, coord_dim, device=coordinate.device, dtype=coordinate.dtype + ) + for step in range(ring_size): + if step + 1 != ring_size: + next_coordinate = comm.enqueue_to_dispatch(coordinate.contiguous()) + + idx_range = (rank_coords[0] + step) % ring_size + shard_start = idx_range * n_atoms_local + is_argmin_local = (shard_start <= global_argmin) & (global_argmin < (idx_range + 1) * n_atoms_local) + + local_argmin = global_argmin - shard_start + local_argmin = local_argmin.masked_fill(~is_argmin_local, 0) + + gathered = coordinate[batch_idx, local_argmin] + gathered = gathered.masked_fill(~is_argmin_local[..., None], 0) + gathered_coords += gathered + + if step + 1 != ring_size: + comm.wait_until_finished() + coordinate = next_coordinate + + return gathered_coords.permute(0, 1, 3, 2) # shape = (B, n_tokens_local, D, order) + + +def _gather_frame_coords( + atom_coords: Tensor, + frame_atoms: Tensor, + *, + local_only: bool, + atom_offset: int, + n_tokens_local: int, + dist_manager: DistributedManager, + group_layout: LayoutMap, +) -> Tensor: + """Gather frame atom coordinates for local or global frame indices. + + Args: + atom_coords: Local atom coordinates with shape [B, mult, N_atom_local, 3]. + frame_atoms: Frame atom indices with shape [B, mult, N_token_local, order]. + local_only: If True, gather directly from local shard using atom_offset. + atom_offset: Global atom index offset for this shard. + n_tokens_local: Number of local tokens for this shard. + dist_manager: DistributedManager for collective communications. + group_layout: LayoutMap for CP mesh ring gather. + + Returns: + Gathered frame coordinates with shape [B, mult, N_token_local, 3]. + """ + batch, multiplicity = atom_coords.shape[0], atom_coords.shape[1] + n_atoms_local = atom_coords.shape[2] + device = atom_coords.device + + if local_only: + local_idx = frame_atoms - atom_offset + batch_indices0 = torch.arange(batch, device=device)[:, None, None] + batch_indices1 = torch.arange(multiplicity, device=device)[None, :, None] + return atom_coords[batch_indices0, batch_indices1, local_idx] + + # Flatten (batch, multiplicity) into a single batch dimension for ring gather. + # ring_gather_coordinate expects (B, n_atoms_local, D) and (B, n_tokens_local, order). + coords_flat = atom_coords.reshape(batch * multiplicity, n_atoms_local, 3) + idx_flat = frame_atoms.reshape(batch * multiplicity, n_tokens_local, 1) + gathered = ring_gather_coordinate(coords_flat, idx_flat, dist_manager, group_layout) + # Unflatten back to (batch, multiplicity) to match the caller's expectations. + return gathered.squeeze(-1).reshape(batch, multiplicity, n_tokens_local, 3) + + +def _fully_distributed_compute_frame_pred( + pred_atom_coords: Tensor, + frames_idx_true: Tensor, + feats: dict[str, Tensor], + multiplicity: int, + comm: One2OneComm, + dist_manager: DistributedManager, + group_layout: LayoutMap, + resolved_mask: Optional[Tensor] = None, + inference: bool = False, + return_frames_expanded: bool = False, +) -> tuple[Tensor, Tensor, Tensor] | tuple[Tensor, Tensor]: + """Recompute the frames for non-polymer over 3 atoms given the predicted atom coordinates. + + .. deprecated:: + Unused — superseded by ``compute_frame_pred`` (the DTensor wrapper that + gathers inputs and delegates to ``_compute_frame_pred``). Retained for + reference only. Contains known resolved_mask indexing bugs that are + NOT fixed here; see _compute_frame_pred for the corrected version. + + example sharding strategy on N_atoms=2, world_size=4 + device mesh = [[ 0, 1], + [ 2, 3]] + + inputs should follow a sharding strategy below. + pred_coords = [[ c0, c0], frames_idx = [[ f0, f0], resolved_mask = [[m0, m0], + [ c1, c1]] [ f1, f1]] [m1, m1]] + + outputs follow same sharding: + frames_idx_pred = [[ f0, f0], mask_collinear = [[m0, m0], frames_expanded = [[c0, c0], + [ f1, f1]] [m1, m1]] [c1, c1]] + + Args: + pred_atom_coords (torch.Tensor): Predicted atom coordinates of shape = (B, n_atoms_per_shard, 3). + frames_idx_true (torch.Tensor): True frames indices of shape = (B, n_atoms_per_shard, order). + feats (dict): Dictionary of feature tensors + multiplicity (int): Multiplicity of the predicted atom coordinates + comm (One2OneComm): Communication class for sending and receiving tensors + dist_manager (DistributedManager): Distributed manager + group_layout (LayoutMap): Layout map for context parallelism + resolved_mask (torch.Tensor, optional): Resolved mask; shape = (B, n_atoms_per_shard). Defaults to None to use atom_resolved_mask and atom_pad_mask in feats. + inference (bool, optional): Whether to use inference mode which skips resolved_mask. Defaults to False. + return_frames_expanded (bool, optional): Whether to return the expanded frames for unittest purposes. Defaults to False. + + Returns: + torch.Tensor: Updated frames indices; shape = (B, N, order) + torch.Tensor: Mask for collinear or overlapping atoms in the frame; shape = (B, N, order) + optional torch.Tensor: The closest 3 atom coordinates for the frames; shape = (B, N, order, 3). Returned only if return_frames_expanded is True. + """ + # group settings + group_axis_1 = dist_manager.subgroups["cp"][0] + group_rank_1 = dist.get_rank(group_axis_1) + + # extract necessary features + asym_id_token = feats["asym_id"] + asym_id_atom = torch.bmm(feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()).squeeze(-1) + B, N, _ = pred_atom_coords.shape + + pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3) + frames_idx_pred = ( + frames_idx_true.clone().repeat_interleave(multiplicity, 0).reshape(B // multiplicity, multiplicity, -1, 3) + ) + + frames_expanded = [] + + # Iterate through the batch and update the frames for non-polymers + for i, pred_atom_coord in enumerate(pred_atom_coords): + # Gather reference atom coordinates per token per order, i.e. frame per token + pred_atom_coords_sample = pred_atom_coords[i, :, :, :] # atom coordinates; shape = (multiplicity, N_atoms, 3) + idx = ( + frames_idx_pred[i] - pred_atom_coords.shape[-2] * group_rank_1 + ) # token to **local** atom indices; shape = (multiplicity, N_tokens) + idx = idx.masked_fill( + feats["is_nonpolymer_with_frame"][i][:, None], 0 + ) # reset frames from non-polymer chains with fewer than 3 atoms to 0 + + N_tokens = idx.shape[-2] + batch_idx = ( + torch.arange(multiplicity).view(multiplicity, 1, 1).expand(-1, N_tokens, 3) + ) # shape = (multiplicity, N_tokens, 3) + frame_expanded_sample = pred_atom_coords_sample[batch_idx, idx].transpose( + -1, -2 + ) # shape = (multiplicity, N_atoms, 3, orders) + + assert frame_expanded_sample.shape == (multiplicity, N_tokens, 3, 3) + frame_expanded_sample = frame_expanded_sample.masked_fill( + feats["is_nonpolymer_with_frame"][i][None, :, None, None], 0 + ) # reset frames from non-polymer chains with fewer than 3 atoms to 0 + + # Gather unique asym_ids from all ranks + asym_ids_unique = [set() for _ in range(group_axis_1.size())] + dist.all_gather_object(asym_ids_unique, set(asym_id_token[i].tolist()), group=group_axis_1) + asym_ids_unique = sorted(set.union(*asym_ids_unique)) + + for id in asym_ids_unique: + mask_chain_token = (asym_id_token[i] == id) * feats["token_pad_mask"][i] + mask_chain_atom = (asym_id_atom[i] == id) * feats["atom_pad_mask"][i] + mask_chain_token = mask_chain_token.bool() + mask_chain_atom = mask_chain_atom.bool() + + # Check if the chain satisfies the criteria for frame recomputation + # 1. is a non-polymer + # 2. has at least 3 atoms + # TODO: streamline this with is_nonpolymer_with_frame + num_tokens = mask_chain_token.sum() + num_atoms = mask_chain_atom.sum() + dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM, group=group_axis_1) + dist.all_reduce(num_atoms, op=dist.ReduceOp.SUM, group=group_axis_1) + + mol_type = feats["mol_type"][i, mask_chain_token].unique() + assert len(mol_type) <= 1, "all chains in the batch must have the same mol_type" + + is_target_mol_type = ( + (mol_type.item() == const.chain_type_ids["NONPOLYMER"]) and (num_atoms.item() > 3) + if len(mol_type) > 0 + else False + ) + is_target_mol_type = torch.tensor(is_target_mol_type, device=mol_type.device) + dist.all_reduce(is_target_mol_type, op=dist.ReduceOp.SUM, group=group_axis_1) + if ~is_target_mol_type: + continue + assert ( + num_tokens.item() == num_atoms.item() + ), "num_tokens and num_atoms must be the same for non-polymers, got num_tokens = {}, num_atoms = {}".format( + num_tokens.item(), num_atoms.item() + ) + + # Compute all-to-all atom distance matrix, including those that are not part of the chain + pred_atom_coord_i = pred_atom_coord + pred_atom_coord_j = comm.enqueue_to_dispatch(pred_atom_coord_i) + comm.wait_until_finished() + dist_mat = torch.cdist(pred_atom_coord_i, pred_atom_coord_j) # shape = (multiplicity, N, N) + + # Restrict neighborhood frame atom search to + # 1. atoms that are not padding/are resolved, and + # 2. atoms that are part of the chain + if inference: + resolved_mask_i = feats["atom_pad_mask"][i] + elif resolved_mask is None: + resolved_mask_i = feats["atom_resolved_mask"][i] + resolved_mask_i = ( + resolved_mask_i * feats["atom_pad_mask"][i] + ) # apply atom_pad_mask for padding in context parallelism + else: + resolved_mask_i = resolved_mask[i] + + resolved_mask_j = comm.enqueue_to_dispatch(resolved_mask_i) + comm.wait_until_finished() + resolved_pair = (1 - resolved_mask_i[:, None] * resolved_mask_j[None, :]).to( + torch.float32 + ) # shape = (N, N) + resolved_pair[resolved_pair == 1] = torch.inf + + mask_chain_atom_i = mask_chain_atom + mask_chain_atom_j = comm.enqueue_to_dispatch(mask_chain_atom_i) + comm.wait_until_finished() + mask_chain_atom_pair = 1 - (mask_chain_atom_i[:, None] * mask_chain_atom_j[None, :]).to(torch.float) + mask_chain_atom_pair[mask_chain_atom_pair == 1] = torch.inf + + # Sort the atoms by distance + masked_dist_mat = dist_mat + resolved_pair + mask_chain_atom_pair + global_argmin = ring_dist_mat_argmin( + masked_dist_mat, + order=3, + group_cp=dist_manager.group["cp"], + group_reduce_dist_mat_row=dist_manager.subgroups["cp"][1], + group_layout=group_layout, + ) # shape = (multiplicity, N_token, order) + atom_to_ref_coords = ring_gather_coordinate( + pred_atom_coord_i, + global_argmin, + dist_manager, + group_layout=group_layout, + ) # shape = (multiplicity, N_token, 3, order) + + # Map reference atom repr to token repr + # non-polymer has one token per atom so this is a mapping instead of aggregation + atom_to_token = feats["atom_to_token"][i].float() + token_to_ref_coords = torch.einsum( + "miab,ij->mjab", atom_to_ref_coords, atom_to_token + ) # shape = (multiplicity, N_token, 3, order) + frames = torch.einsum( + "mib,ij->mjb", global_argmin.float(), atom_to_token + ).long() # shape = (multiplicity, N_token, order) + + # pass reference frames for non-polymer chains + frame_reorder_index = torch.tensor([1, 0, 2], device=token_to_ref_coords.device) + frames_idx_pred[i, :, :, :] = ( + frames[:, :, frame_reorder_index] * mask_chain_token[None, :, None] + + frames_idx_pred[i, :, :, :] * ~mask_chain_token[None, :, None] + ) # shape = (multiplicity, N_token, order) + + # NOTE: discrepancy in padding will result in different indexing compared to + # single-device implementation; compare frames_expanded instead of frames_idx_pred + + # pass reference atom coordinates for non-polymer chains + frame_expanded_sample = ( + token_to_ref_coords[:, :, :, frame_reorder_index] * mask_chain_token[None, :, None, None] + + frame_expanded_sample * ~mask_chain_token[None, :, None, None] + ) # shape = (multiplicity, N_token, 3, order) + + # append per sample + frames_expanded.append(frame_expanded_sample) + + # concatenate per sample + frames_expanded = torch.cat(frames_expanded, dim=0).reshape(-1, 3, 3) + frames_expanded = frames_expanded.transpose(1, 2) # shape = (..., order, 3) + + # Compute masks for collinear or overlapping atoms in the frame + mask_collinear_pred = compute_collinear_mask( + frames_expanded[:, 1] - frames_expanded[:, 0], + frames_expanded[:, 1] - frames_expanded[:, 2], + ).reshape(B // multiplicity, multiplicity, -1) + + if return_frames_expanded: + return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :], frames_expanded + else: + return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :] + + +def _compute_frame_pred( + pred_atom_coords: Tensor, + frames_idx_true: Tensor, + feats: dict[str, Tensor], + asym_id_atom: Tensor, + multiplicity: int, + resolved_mask: Optional[Tensor] = None, + inference: bool = False, +) -> tuple[Tensor, Tensor]: + """Private tensor implementation compatible with sparse/padded atom indexing of the reference compute_frame_pred in src/boltz/model/layers/confidence_utils.py.""" + # Disable autocast to match serial compute_frame_pred which runs inside + # torch.amp.autocast("cuda", enabled=False). Without this, bmm and pow + # are affected by autocast, producing different intermediate precision + # than serial and potentially different frame assignments for borderline + # non-polymer ligand chains. + with torch.amp.autocast("cuda", enabled=False): + return _compute_frame_pred_impl( + pred_atom_coords, + frames_idx_true, + feats, + asym_id_atom, + multiplicity, + resolved_mask, + inference, + ) + + +def _compute_frame_pred_impl( + pred_atom_coords: Tensor, + frames_idx_true: Tensor, + feats: dict[str, Tensor], + asym_id_atom: Tensor, + multiplicity: int, + resolved_mask: Optional[Tensor] = None, + inference: bool = False, +) -> tuple[Tensor, Tensor]: + """Implementation of _compute_frame_pred, called with autocast disabled.""" + asym_id_token = feats["asym_id"] + B, _, _ = pred_atom_coords.shape + if B % multiplicity != 0: + raise ValueError(f"pred_atom_coords batch dim ({B}) must be divisible by multiplicity ({multiplicity})") + if resolved_mask is not None and resolved_mask.shape != pred_atom_coords.shape[:2]: + raise ValueError( + f"resolved_mask shape {tuple(resolved_mask.shape)} must match " + f"pred_atom_coords[:2] {tuple(pred_atom_coords.shape[:2])}" + ) + pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3) + # resolved_mask arrives as (B*mult, N_atom). Reshape to (B_batch, mult, + # N_atom) so that each diffusion sample's per-sample resolved mask is + # preserved (symmetry_correction can produce different masks per sample). + if resolved_mask is not None: + resolved_mask = resolved_mask.reshape(B // multiplicity, multiplicity, -1) + frames_idx_pred = ( + frames_idx_true.clone().repeat_interleave(multiplicity, 0).reshape(B // multiplicity, multiplicity, -1, 3) + ) + + for i, pred_atom_coord in enumerate(pred_atom_coords): + for id in torch.unique(asym_id_token[i]): + mask_chain_token = (asym_id_token[i] == id) * feats["token_pad_mask"][i] + mask_chain_atom = (asym_id_atom[i] == id) * feats["atom_pad_mask"][i] + num_tokens = int(mask_chain_token.sum().item()) + num_atoms = int(mask_chain_atom.sum().item()) + + mol_types = feats["mol_type"][i, mask_chain_token.bool()] + + # sanity check: all chains in the batch must have the same mol_type + mol_type_unique = mol_types.unique() + assert ( + mol_type_unique.numel() <= 1 + ), f"all chains in the batch must have the same mol_type but got {mol_type_unique}" # sanity check + + # skip frame reassignment if the chain is not a non-polymer or has fewer than 3 atoms + if mol_type_unique.item() != const.chain_type_ids["NONPOLYMER"] or num_atoms < 3: + continue + + # sanity check: num_atoms = num_tokens for non-polymers + assert ( + num_atoms == num_tokens + ), "num_atoms and num_tokens must be the same for non-polymers, got num_atoms = {}, num_tokens = {}".format( + num_atoms, num_tokens + ) + + chain_atom_indices = torch.nonzero(mask_chain_atom.bool(), as_tuple=False).squeeze(-1) + chain_atom_coords = pred_atom_coord[:, chain_atom_indices] + dist_mat = ((chain_atom_coords[:, None, :, :] - chain_atom_coords[:, :, None, :]) ** 2).sum(-1) ** 0.5 + + if inference: + resolved_pair = 1 - ( + feats["atom_pad_mask"][i][chain_atom_indices][None, :] + * feats["atom_pad_mask"][i][chain_atom_indices][:, None] + ).to(torch.float32) + resolved_pair[resolved_pair == 1] = torch.inf + indices = torch.sort(dist_mat + resolved_pair, axis=2).indices + else: + if resolved_mask is None: + # atom_resolved_mask is (B_batch, N_atom); expand to + # (B_batch, mult, N_atom) so indexing is uniform. + resolved_mask = feats["atom_resolved_mask"][:, None, :].expand(-1, multiplicity, -1) + # resolved_mask[i]: (mult, N_atom) + rm_chain = resolved_mask[i][:, chain_atom_indices] # (mult, N_chain) + resolved_pair = 1 - (rm_chain[:, None, :] * rm_chain[:, :, None]).to(torch.float32) + resolved_pair[resolved_pair == 1] = torch.inf + indices = torch.sort(dist_mat + resolved_pair, axis=2).indices + + frames_local = torch.cat( + [ + indices[:, :, 1:2], + indices[:, :, 0:1], + indices[:, :, 2:3], + ], + dim=2, + ) + frames = chain_atom_indices[frames_local] + frames_idx_pred[i, :, mask_chain_token.bool(), :] = frames + + frames_expanded = pred_atom_coords[ + torch.arange(0, B // multiplicity, 1)[:, None, None, None].to(frames_idx_pred.device), + torch.arange(0, multiplicity, 1)[None, :, None, None].to(frames_idx_pred.device), + frames_idx_pred, + ].reshape(-1, 3, 3) + + mask_collinear_pred = compute_collinear_mask( + frames_expanded[:, 1] - frames_expanded[:, 0], + frames_expanded[:, 1] - frames_expanded[:, 2], + ).reshape(B // multiplicity, multiplicity, -1) + return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :] + + +def compute_frame_pred( + pred_atom_coords: DTensor, + frames_idx_true: DTensor, + feats: dict[str, DTensor], + multiplicity: int, + resolved_mask: Optional[DTensor] = None, + inference: bool = False, +) -> tuple[DTensor, DTensor]: + """Distributed wrapper around `compute_frame_pred` for DTensors. + + Gathers DTensor inputs to local tensors, runs the serial implementation, + then reshapes and redistributes outputs to match DTensor placements. + + Args: + pred_atom_coords: DTensor of shape (batch * multiplicity, num_atoms, 3). + frames_idx_true: DTensor of shape (batch, num_tokens, 3). + feats: DTensor feature dict with required keys (`asym_id`, `atom_to_token`, + `atom_pad_mask`, `atom_resolved_mask`, `mol_type`, `token_pad_mask`). + multiplicity: Number of copies per sample in the batch dimension. + resolved_mask: Optional DTensor (batch, num_atoms) to prefer resolved atoms. + inference: If True, uses pad mask instead of resolved mask. + + Returns: + DTensor frame indices and DTensor collinearity mask. + """ + feats_keys = { + "asym_id", + "atom_to_token", + "atom_pad_mask", + "atom_resolved_mask", + "mol_type", + "token_pad_mask", + } + if any(k not in feats for k in feats_keys): + raise ValueError(f"feats must contain the following keys: {feats_keys}, got {feats.keys()}") + + if frames_idx_true.ndim == 4: + raise ValueError( + f"frames_idx_true has unsqueezed ensemble dim (ndim=4, shape={frames_idx_true.shape}). " + "Only E=1 is supported; squeeze the ensemble dim before calling compute_frame_pred." + ) + + # Check device mesh, placements, and shapes + device_mesh = pred_atom_coords.device_mesh + single_repr_placements = (Shard(0), Shard(1), Replicate()) + replicate_placements = (Shard(0), Replicate(), Replicate()) + + global_batch_size, num_atoms = feats["atom_pad_mask"].shape + _, num_tokens = feats["asym_id"].shape + assert ( + pred_atom_coords.shape[0] == global_batch_size * multiplicity + ), f"pred_atom_coords must have shape {global_batch_size * multiplicity}, got {pred_atom_coords.shape[0]}" + + expected_placements = { + "pred_atom_coords": single_repr_placements, + "frames_idx_true": single_repr_placements, + "asym_id": single_repr_placements, + "atom_to_token": single_repr_placements, + "atom_pad_mask": single_repr_placements, + "atom_resolved_mask": single_repr_placements, + "mol_type": single_repr_placements, + "token_pad_mask": single_repr_placements, # context parallelism specific + } + expected_shape = { + "pred_atom_coords": (global_batch_size * multiplicity, num_atoms, 3), # 3D coordinates of the atoms + "frames_idx_true": (global_batch_size, num_tokens, 3), # 3 atoms to form a frame per token + "asym_id": (global_batch_size, num_tokens), # asym_id of the tokens + "atom_to_token": ( + global_batch_size, + num_atoms, + num_tokens // device_mesh.size(1), + ), # mapping from atoms to tokens + "atom_pad_mask": (global_batch_size, num_atoms), # padding mask of the atoms + "atom_resolved_mask": (global_batch_size, num_atoms), # resolved mask of the atoms + "mol_type": (global_batch_size, num_tokens), # mol_type of the tokens + "token_pad_mask": (global_batch_size, num_tokens), # padding mask of the tokens (context parallelism specific) + } + + for k in expected_placements: + match k: + case "pred_atom_coords": + if pred_atom_coords.placements != expected_placements[k]: + raise ValueError( + f"pred_atom_coords must have placements {expected_placements[k]}, got {pred_atom_coords.placements}" + ) + if pred_atom_coords.shape != expected_shape[k]: + raise ValueError( + f"pred_atom_coords must have shape {expected_shape[k]}, got {pred_atom_coords.shape}" + ) + case "frames_idx_true": + if frames_idx_true.device_mesh != device_mesh: + raise ValueError( + f"frames_idx_true must be on the same device mesh as pred_atom_coords, got {frames_idx_true.device_mesh} and {device_mesh}" + ) + if frames_idx_true.placements != expected_placements[k]: + raise ValueError( + f"frames_idx_true must have placements {expected_placements[k]}, got {frames_idx_true.placements}" + ) + if frames_idx_true.shape != expected_shape[k]: + raise ValueError( + f"frames_idx_true must have shape {expected_shape[k]}, got {frames_idx_true.shape}" + ) + case "resolved_mask": + if resolved_mask is not None and resolved_mask.device_mesh != device_mesh: + raise ValueError( + f"resolved_mask must be on the same device mesh as pred_atom_coords, got {resolved_mask.device_mesh} and {device_mesh}" + ) + if resolved_mask is not None and resolved_mask.placements != expected_placements[k]: + raise ValueError( + f"resolved_mask must have placements {expected_placements[k]}, got {resolved_mask.placements}" + ) + if resolved_mask is not None and resolved_mask.shape != expected_shape[k]: + raise ValueError(f"resolved_mask must have shape {expected_shape[k]}, got {resolved_mask.shape}") + case _: + if feats[k].device_mesh != device_mesh: + raise ValueError( + f"feats[{k}] must be on the same device mesh as pred_atom_coords, got {feats[k].device_mesh} and {device_mesh}" + ) + if feats[k].placements != expected_placements[k]: + raise ValueError( + f"feats[{k}] must have placements {expected_placements[k]}, got {feats[k].placements}" + ) + if feats[k].shape != expected_shape[k]: + raise ValueError(f"feats[{k}] must have shape {expected_shape[k]}, got {feats[k].shape}") + + # All-gather all inputs/features + pred_atom_coords_gathered = pred_atom_coords.redistribute(device_mesh, placements=replicate_placements).to_local() + frames_idx_true_gathered = frames_idx_true.redistribute(device_mesh, placements=replicate_placements).to_local() + asym_id_gathered = feats["asym_id"].redistribute(device_mesh, placements=replicate_placements).to_local() + asym_id_atom_gathered = ( + single_repr_token_to_atom(feats["asym_id"].float(), feats["atom_to_token"]) + .redistribute(device_mesh, placements=replicate_placements) + .to_local() + .to(torch.int64) + ) + atom_pad_mask_gathered = ( + feats["atom_pad_mask"].redistribute(device_mesh, placements=replicate_placements).to_local() + ) + atom_resolved_mask_gathered = ( + feats["atom_resolved_mask"].redistribute(device_mesh, placements=replicate_placements).to_local() + ) + mol_type_gathered = feats["mol_type"].redistribute(device_mesh, placements=replicate_placements).to_local() + if resolved_mask is not None: + resolved_mask_gathered = resolved_mask.redistribute(device_mesh, placements=replicate_placements).to_local() + else: + resolved_mask_gathered = None + token_pad_mask_gathered = ( + feats["token_pad_mask"].redistribute(device_mesh, placements=replicate_placements).to_local() + ) + + feats_gathered = { + "asym_id": asym_id_gathered, + "atom_pad_mask": atom_pad_mask_gathered, + "atom_resolved_mask": atom_resolved_mask_gathered, + "mol_type": mol_type_gathered, + "token_pad_mask": token_pad_mask_gathered, + } + + frames_idx_pred_local, mask_collinear_pred_local = _compute_frame_pred( + pred_atom_coords_gathered, + frames_idx_true_gathered, + feats_gathered, + asym_id_atom_gathered, + multiplicity, + resolved_mask=resolved_mask_gathered, + inference=inference, + ) + + # Redistribute frames and mask + shape = torch.Size([global_batch_size, multiplicity, num_tokens, 3]) + cp_submesh = device_mesh["cp_axis_0", "cp_axis_1"] + frames_idx_pred_cp = distribute_tensor( + frames_idx_pred_local, + cp_submesh, + (Shard(2), Replicate()), + src_data_rank=0, # group rank not global rank + ) # broadcast to rest of cp group to consistency amid potential numerical discrepancies + frames_idx_pred = DTensor.from_local( + frames_idx_pred_cp.to_local(), + device_mesh=device_mesh, + placements=(Shard(0), Shard(2), Replicate()), + shape=shape, + stride=LayoutRightMap(shape).strides, + ) + + shape = torch.Size([global_batch_size, multiplicity, num_tokens]) + mask_collinear_pred_cp = distribute_tensor( + mask_collinear_pred_local, + device_mesh["cp_axis_0", "cp_axis_1"], + (Shard(2), Replicate()), + src_data_rank=0, # group rank not global rank + ) # broadcast to rest of cp group to consistency amid potential numerical discrepancies + mask_collinear_pred = DTensor.from_local( + mask_collinear_pred_cp.to_local(), + device_mesh=device_mesh, + placements=(Shard(0), Shard(2), Replicate()), + shape=shape, + stride=LayoutRightMap(shape).strides, + ) + + return frames_idx_pred, mask_collinear_pred + + +def lddt_resolved_token( + pred_atom_coords: DTensor, + true_atom_coords: DTensor, + true_coords_resolved_mask: DTensor, + feats: dict[str, DTensor], + comm: TransposeComm, + multiplicity: int = 1, + cutoff: float = 15.0, + eps: float = 1e-10, +) -> tuple[DTensor, DTensor]: + """Compute per-token lDDT scores in distributed setup. + + This function computes the lDDT (local Distance Difference Test) scores for + each token, which measures the local structural accuracy by comparing pairwise + distances between predicted and true coordinates. + + The computation uses cdist_lddt with factorized masks and redistribute_transpose + to handle the distributed pairwise distance computation efficiently. + + Co-sharding Requirements + ------------------------ + This function assumes specific co-sharding of features along the N_token and N_atom axes: + + 1. **Token-Atom Co-sharding**: The N_token axis of token features (mol_type, token_to_rep_atom) + and the N_atom axis of atom features (pred_atom_coords, true_atom_coords, atom_to_token) + are co-sharded as diagonal blocks. This means: + - Atoms of shard i belong ONLY to tokens of shard i + - The atom-to-token mapping is block-diagonal in the global matrix + + 2. **R-set-Token Co-sharding**: Each R-set element corresponds to exactly one token via its + representative atom. R-set elements are placed in the same shard as their corresponding token: + - R-set elements of shard i are a SUBSET of tokens in shard i + - This enables local matmul for r_set_to_rep_atom @ atom_coords + + 3. **N_atom_max_per_shard Semantics**: The "atom" axis of token_to_rep_atom, r_set_to_rep_atom, + and atom_to_token represents the local shard's atom indices (0 to N_atom_max_per_shard-1), + NOT global atom indices. This is because these tensors are diagonal blocks of the global + reference versions. See src/boltz/distributed/data/feature/featurizer.py for details. + + These co-sharding properties enable all coordinate projections and mask computations to be + performed locally without communication (until the final all-reduce for lDDT aggregation). + + Parameters + ---------- + pred_atom_coords : DTensor + Predicted atom coordinates with shape [B*mult, N_atom, 3]. + Placements: (Shard(batch_dim), Shard(atom_dim), Replicate()) + true_atom_coords : DTensor + True atom coordinates with shape [B*mult, N_atom, 3]. + Placements: (Shard(batch_dim), Shard(atom_dim), Replicate()) + true_coords_resolved_mask : DTensor + Resolved mask for atoms with shape [B*mult, N_atom]. + Placements: (Shard(batch_dim), Shard(atom_dim), Replicate()) + feats : dict[str, DTensor] + Feature dictionary containing: + - token_to_rep_atom: One-hot mapping [B, N_token, N_atom_max_per_shard] + - r_set_to_rep_atom: One-hot mapping [B, N_R, N_atom_max_per_shard] + - atom_to_token: One-hot mapping [B, N_atom_max_per_shard, N_token] + - mol_type: Token types [B, N_token] + All with same device_mesh and placements as pred_atom_coords. + comm : TransposeComm + Communication object for redistribute_transpose operations. + multiplicity : int, optional + Diffusion batch multiplier, by default 1 + cutoff : float, optional + Base cutoff distance for lDDT computation, by default 15.0. + For nucleotide tokens, the cutoff is doubled (cutoff + cutoff * is_nucleotide). + eps : float, optional + Small epsilon for numerical stability, by default 1e-10 + + Returns + ------- + target_lddt : DTensor + Per-token lDDT scores [B*mult, N_token] with same placements as input + combined_mask : DTensor + Combined mask (token_resolved_mask * mask_no_match) [B*mult, N_token] with same placements. + - token_resolved_mask: Whether each token has a resolved representative atom + - mask_no_match: Whether each token has valid pairs for lDDT computation + """ + # === Extract device_mesh and placements from input DTensor === + device_mesh = pred_atom_coords.device_mesh + input_placements = pred_atom_coords.placements + # Validate input placements must be exactly (Shard(0), Shard(1), Replicate()) + expected_placements = (Shard(0), Shard(1), Replicate()) + if input_placements != expected_placements: + raise ValueError(f"pred_atom_coords placements {input_placements} must be {expected_placements}") + + # Extract features + token_to_rep_atom = feats["token_to_rep_atom"] # [B, N_token, N_atom_max_per_shard] + r_set_to_rep_atom = feats["r_set_to_rep_atom"] # [B, N_R, N_atom_max_per_shard] + atom_to_token = feats["atom_to_token"] # [B, N_atom_max_per_shard, N_token] + mol_type = feats["mol_type"] # [B, N_token] + + # === Sanity checks for device_mesh and placements consistency === + all_dtensors = [ + ("pred_atom_coords", pred_atom_coords), + ("true_atom_coords", true_atom_coords), + ("true_coords_resolved_mask", true_coords_resolved_mask), + ("token_to_rep_atom", token_to_rep_atom), + ("r_set_to_rep_atom", r_set_to_rep_atom), + ("atom_to_token", atom_to_token), + ("mol_type", mol_type), + ] + for name, dtensor in all_dtensors: + if dtensor.device_mesh != device_mesh: + raise ValueError(f"{name} has different device_mesh than pred_atom_coords") + # Check placements match exactly + if dtensor.placements != expected_placements: + raise ValueError(f"{name} has placements {dtensor.placements}, expected {expected_placements}") + + # === Extract and validate global shape dimensions === + # NOTE: "Global" here refers to DTensor's global semantics (i.e., DTensor.full_tensor() shape), + # which may differ from the serial equivalent dimensions due to intersperse padding applied + # during CP data processing and dataloader. For example, N_token_global and N_atom_global + # include padding to ensure even sharding across CP ranks. + # + # Note on DTensor global shape semantics (Shard dim -> global size, Replicate dim -> local size): + # All have placements (Shard(0), Shard(1), Replicate()) for 3D or (Shard(0), Shard(1), Replicate()) for 2D + # - token_to_rep_atom: [B, N_token, N_atom_max_per_shard] - dim1 Shard->global, dim2 Replicate->local + # - r_set_to_rep_atom: [B, N_R, N_atom_max_per_shard] - dim1 Shard->global, dim2 Replicate->local + # - atom_to_token: [B, N_atom, N_token_max_per_shard] - dim1 Shard->global, dim2 Replicate->local + # - mol_type: [B, N_token] - dim1 Shard->global + B_mult_global = pred_atom_coords.shape[0] # B * multiplicity + N_atom_global = pred_atom_coords.shape[1] + B_global = token_to_rep_atom.shape[0] # B (without multiplicity) + N_token_global = token_to_rep_atom.shape[1] + N_R_global = r_set_to_rep_atom.shape[1] + N_atom_max_per_shard = token_to_rep_atom.shape[2] # Local atom axis (diagonal block semantics) + N_token_max_per_shard = atom_to_token.shape[2] # Local token axis (diagonal block semantics) + + # Validate multiplicity consistency + if B_mult_global != B_global * multiplicity: + raise ValueError( + f"pred_atom_coords batch dim ({B_mult_global}) != B ({B_global}) * multiplicity ({multiplicity})" + ) + + # Validate coordinate shapes + if true_atom_coords.shape != pred_atom_coords.shape: + raise ValueError( + f"true_atom_coords shape {true_atom_coords.shape} != pred_atom_coords shape {pred_atom_coords.shape}" + ) + if true_coords_resolved_mask.shape != (B_mult_global, N_atom_global): + raise ValueError( + f"true_coords_resolved_mask shape {true_coords_resolved_mask.shape} != expected ({B_mult_global}, {N_atom_global})" + ) + + # Validate feature shapes + # Note: For DTensor, .shape returns global shape. For Shard(dim), global != local. + # - token_to_rep_atom: dim1 is Shard(1) so shape[1]=N_token_global; dim2 is Replicate so shape[2]=N_atom_max_per_shard + # - atom_to_token: dim1 is Shard(1) so shape[1]=N_atom_global; dim2 is Replicate so shape[2]=N_token_max_per_shard + if token_to_rep_atom.shape != (B_global, N_token_global, N_atom_max_per_shard): + raise ValueError(f"token_to_rep_atom shape {token_to_rep_atom.shape} is invalid") + if r_set_to_rep_atom.shape != (B_global, N_R_global, N_atom_max_per_shard): + raise ValueError(f"r_set_to_rep_atom shape {r_set_to_rep_atom.shape} is invalid") + # atom_to_token: dim1 (atom) is Shard(1)->global, dim2 (token) is Replicate->local + if atom_to_token.shape != (B_global, N_atom_global, N_token_max_per_shard): + raise ValueError(f"atom_to_token shape {atom_to_token.shape} is invalid") + if mol_type.shape != (B_global, N_token_global): + raise ValueError(f"mol_type shape {mol_type.shape} != expected ({B_global}, {N_token_global})") + + # === Forward-only computation (no gradient support) === + # The lDDT metric involves step functions (thresholding distance differences at 0.5, 1, 2, 4 Å), + # which are not differentiable. Neither the original lddt_dist nor cdist_lddt is mathematically + # defined with gradients. Therefore, this function does not support backward pass by definition. + with torch.no_grad(): + # === Get local tensors and consolidate dtype casting === + # Promote to at least float32 to match serial get_target_lddt which uses + # .float() inside autocast(enabled=False). promote_types preserves + # float64 for test paths while ensuring BF16 inputs compute in float32. + pred_coords_local = pred_atom_coords.to_local() # [local_B*mult, local_N_atom, 3] + true_coords_local = true_atom_coords.to_local() # [local_B*mult, local_N_atom, 3] + coord_dtype = torch.promote_types(pred_coords_local.dtype, torch.float32) + pred_coords_local = pred_coords_local.to(dtype=coord_dtype) + true_coords_local = true_coords_local.to(dtype=coord_dtype) + + # Cast all feature tensors to coord_dtype + token_to_rep_local = token_to_rep_atom.to_local().to( + dtype=coord_dtype + ) # [local_B, local_N_token, local_N_atom] + r_set_to_rep_local = r_set_to_rep_atom.to_local().to(dtype=coord_dtype) # [local_B, local_N_R, local_N_atom] + atom_to_token_local = atom_to_token.to_local().to(dtype=coord_dtype) # [local_B, local_N_atom, local_N_token] + mol_type_local = mol_type.to_local() # [local_B, local_N_token] + resolved_mask_local = true_coords_resolved_mask.to_local().to(dtype=coord_dtype) # [local_B*mult, local_N_atom] + + # Get local batch size for einsum reshaping + local_B = token_to_rep_local.shape[0] + + # === Project atom coords to token space (row) using einsum with multiplicity broadcasting === + # This avoids repeat_interleave memory overhead by using einsum's implicit broadcasting. + # Co-sharding assumption: token_to_rep_local operates on local atoms only (diagonal block). + # + # token_to_rep_local: [local_B, local_N_token, local_N_atom] -> "bta" + # pred_coords_local reshaped: [local_B, mult, local_N_atom, 3] -> "bmac" + # Output: [local_B, mult, local_N_token, 3] -> "bmtc", then reshape to [local_B*mult, local_N_token, 3] + pred_coords_reshaped = pred_coords_local.view(local_B, multiplicity, -1, 3) # [local_B, mult, local_N_atom, 3] + true_coords_reshaped = true_coords_local.view(local_B, multiplicity, -1, 3) # [local_B, mult, local_N_atom, 3] + + pred_token_coords_row_local = torch.einsum("bta,bmac->bmtc", token_to_rep_local, pred_coords_reshaped).reshape( + -1, token_to_rep_local.shape[1], 3 + ) # [local_B*mult, local_N_token, 3] + true_token_coords_row_local = torch.einsum("bta,bmac->bmtc", token_to_rep_local, true_coords_reshaped).reshape( + -1, token_to_rep_local.shape[1], 3 + ) # [local_B*mult, local_N_token, 3] + + # === Project atom coords to R-set space (col) using einsum with multiplicity broadcasting === + # Co-sharding assumption: r_set_to_rep_local operates on local atoms only (diagonal block). + # + # r_set_to_rep_local: [local_B, local_N_R, local_N_atom] -> "bra" + # coords reshaped: [local_B, mult, local_N_atom, 3] -> "bmac" + # Output: [local_B, mult, local_N_R, 3] -> "bmrc" + pred_R_coords_col_local = torch.einsum("bra,bmac->bmrc", r_set_to_rep_local, pred_coords_reshaped).reshape( + -1, r_set_to_rep_local.shape[1], 3 + ) # [local_B*mult, local_N_R, 3] + true_R_coords_col_local = torch.einsum("bra,bmac->bmrc", r_set_to_rep_local, true_coords_reshaped).reshape( + -1, r_set_to_rep_local.shape[1], 3 + ) # [local_B*mult, local_N_R, 3] + + # === Compute factorized masks (row and col) using einsum with multiplicity broadcasting === + # Masks can vary along the multiplicity axis, so we use einsum to broadcast properly. + # Co-sharding assumption: the mapping tensors operate on local atoms only. + # + # resolved_mask_local: [local_B*mult, local_N_atom] -> reshaped to [local_B, mult, local_N_atom] -> "bma" + # token_to_rep_local: [local_B, local_N_token, local_N_atom] -> "bta" + # mask_row: [local_B, mult, local_N_token] -> "bmt", then reshape to [local_B*mult, local_N_token] + resolved_mask_reshaped = resolved_mask_local.view(local_B, multiplicity, -1) # [local_B, mult, local_N_atom] + + mask_row_local = torch.einsum("bta,bma->bmt", token_to_rep_local, resolved_mask_reshaped).reshape( + -1, token_to_rep_local.shape[1] + ) # [local_B*mult, local_N_token] + mask_col_local = torch.einsum("bra,bma->bmr", r_set_to_rep_local, resolved_mask_reshaped).reshape( + -1, r_set_to_rep_local.shape[1] + ) # [local_B*mult, local_N_R] + + # === Compute cutoff_col based on nucleotide type === + # is_nucleotide_token: [local_B, local_N_token] + # Use atom_to_token_local.dtype for consistency with subsequent bmm operations + is_nucleotide_token_local = (mol_type_local == const.chain_type_ids["DNA"]).to( + dtype=atom_to_token_local.dtype + ) + (mol_type_local == const.chain_type_ids["RNA"]).to( + dtype=atom_to_token_local.dtype + ) # [local_B, local_N_token] + + # is_nucleotide_R_element = r_set_to_rep_atom @ (atom_to_token @ is_nucleotide_token) + # Co-sharding assumption: atom_to_token operates on local atoms/tokens only (diagonal block). + is_nucleotide_atom_local = torch.bmm(atom_to_token_local, is_nucleotide_token_local.unsqueeze(-1)).squeeze( + -1 + ) # [local_B, local_N_atom] + is_nucleotide_R_element_local = torch.bmm(r_set_to_rep_local, is_nucleotide_atom_local.unsqueeze(-1)).squeeze( + -1 + ) # [local_B, local_N_R] + + # cutoff_col = cutoff + cutoff * is_nucleotide_R_element + cutoff_col_local = cutoff + cutoff * is_nucleotide_R_element_local # [local_B, local_N_R] + + # === Get rep_atom indices for diagonal masking (local indices within shard) === + # These indices use N_atom_max_per_shard semantics (local atom indices 0 to N_atom_max_per_shard-1). + # Due to co-sharding, each token/R-element's representative atom is guaranteed to be in the same shard, + # so argmax on the local diagonal block yields valid local indices for diagonal masking. + rep_atom_token_local = token_to_rep_local.argmax(dim=-1) # [local_B, local_N_token] + rep_atom_r_set_local = r_set_to_rep_local.argmax(dim=-1) # [local_B, local_N_R] + + # === Derive target placements from input (avoid hardcoding) === + # input_placements is e.g. (Shard(0), Shard(1), Replicate()) for the 3D device mesh. + # Note: DTensor placements tuple length matches mesh dimensions, not tensor dimensions. + # Target placements after redistribute_transpose: swap Shard(1) <-> Replicate() + # From (Shard(0), Shard(1), Replicate()) to (Shard(0), Replicate(), Shard(1)) + target_placements = (input_placements[0], input_placements[2], input_placements[1]) + + # === Create DTensors for column tensors that need transpose === + # Compute shapes and strides for contiguous tensors using LayoutRightMap + coords_3d_shape = (B_mult_global, N_R_global, 3) + coords_3d_stride = LayoutRightMap(coords_3d_shape).strides + + pred_R_coords_col_dtensor = DTensor.from_local( + pred_R_coords_col_local, + device_mesh, + input_placements, + shape=torch.Size(coords_3d_shape), + stride=coords_3d_stride, + ) + true_R_coords_col_dtensor = DTensor.from_local( + true_R_coords_col_local, + device_mesh, + input_placements, + shape=torch.Size(coords_3d_shape), + stride=coords_3d_stride, + ) + + # Create DTensors for mask_col: [B*mult, N_R] - masks have multiplicity + mask_col_2d_shape = (B_mult_global, N_R_global) + mask_col_2d_stride = LayoutRightMap(mask_col_2d_shape).strides + + mask_col_dtensor = DTensor.from_local( + mask_col_local, + device_mesh, + input_placements, + shape=torch.Size(mask_col_2d_shape), + stride=mask_col_2d_stride, + ) + + # Create DTensors for cutoff_col and rep_atom_r_set: [B, N_R] - no multiplicity + feat_2d_shape = (B_global, N_R_global) + feat_2d_stride = LayoutRightMap(feat_2d_shape).strides + + cutoff_col_dtensor = DTensor.from_local( + cutoff_col_local, + device_mesh, + input_placements, + shape=torch.Size(feat_2d_shape), + stride=feat_2d_stride, + ) + rep_atom_r_set_dtensor = DTensor.from_local( + rep_atom_r_set_local, + device_mesh, + input_placements, + shape=torch.Size(feat_2d_shape), + stride=feat_2d_stride, + ) + + # === redistribute_transpose for column tensors === + # Transform placements from (S(0), S(1), R) to (S(0), R, S(1)) via all-to-all communication. + # This distributes the N_R axis across the cp_axis_1 dimension of the device mesh. + pred_R_coords_col_t = redistribute_transpose( + pred_R_coords_col_dtensor, comm, target_placements, dim0=None, dim1=None + ) + true_R_coords_col_t = redistribute_transpose( + true_R_coords_col_dtensor, comm, target_placements, dim0=None, dim1=None + ) + mask_col_t = redistribute_transpose(mask_col_dtensor, comm, target_placements, dim0=None, dim1=None) + cutoff_col_t = redistribute_transpose(cutoff_col_dtensor, comm, target_placements, dim0=None, dim1=None) + rep_atom_r_set_t = redistribute_transpose(rep_atom_r_set_dtensor, comm, target_placements, dim0=None, dim1=None) + + # === Factorized Pair-Mask Algorithm with cdist_lddt === + # + # The lDDT computation requires pairwise distance comparisons between all (token, R-element) pairs. + # Instead of materializing the full [N_token, N_R] pair_mask, we use factorized masks: + # pair_mask[i,j] = mask_row[i] * mask_col[j] + # + # This factorization is valid because: + # - mask_row[i] = 1 iff token i has a resolved representative atom + # - mask_col[j] = 1 iff R-element j has a resolved representative atom + # - A pair (i,j) is valid iff BOTH atoms are resolved + # + # For diagonal masking (excluding self-pairs where a token's rep_atom equals an R-element's rep_atom): + # - atom_indices_row = rep_atom_token: local atom index of each token's representative atom + # - atom_indices_col = rep_atom_r_set: local atom index of each R-element's representative atom + # + # Why local indices work for diagonal masking despite N_atom_max_per_shard semantics: + # Due to co-sharding, diagonal device_mesh ranks (where cp_axis_0 == cp_axis_1) have: + # - Row tokens from shard i with local atom indices [0, N_atom_max_per_shard) + # - Column R-elements ALSO from shard i with the SAME local atom index range + # Thus, when rep_atom_token[t] == rep_atom_r_set[r], it genuinely means the same physical atom, + # and the diagonal mask correctly excludes self-pairs. + # + # Off-diagonal ranks (cp_axis_0 != cp_axis_1) have row/col from different shards, so their + # local atom indices never match (different index spaces), and we skip diagonal masking entirely. + + # Determine if this rank is on diagonal (for do_mask_diagonal) + # Convert to native Python bool for Triton kernel compatibility + is_diagonal_rank = bool(comm.is_self_comm) + + out_num_local, out_denom_local, mask_no_match_local = cdist_lddt( + pred_coords_row=pred_token_coords_row_local, # [local_B*mult, local_N_token, 3] + pred_coords_col=pred_R_coords_col_t.to_local(), # [local_B*mult, local_N_R_t, 3] + true_coords_row=true_token_coords_row_local, # [local_B*mult, local_N_token, 3] + true_coords_col=true_R_coords_col_t.to_local(), # [local_B*mult, local_N_R_t, 3] + mask_row=mask_row_local, # [local_B*mult, local_N_token] - factorized row mask + mask_col=mask_col_t.to_local(), # [local_B*mult, local_N_R_t] - factorized col mask (transposed) + multiplicity=multiplicity, + atom_indices_row=rep_atom_token_local if is_diagonal_rank else None, # Local indices for diagonal masking + atom_indices_col=rep_atom_r_set_t.to_local() if is_diagonal_rank else None, # Transposed local indices + cutoff_col=cutoff_col_t.to_local(), # [local_B, local_N_R_t] - per-column cutoff (transposed) + do_mask_diagonal=is_diagonal_rank, # Only diagonal ranks need self-pair exclusion + return_unnormalized_score=True, # Return partial sums for distributed aggregation + per_atom=True, # Per-token output for token-level lDDT + ) + # Output shapes: out_num, out_denom, mask_no_match are [local_B*mult, local_N_token] + + # === All-reduce across N_R axis (cp_axis_1 group) === + # Each rank computed partial lDDT contributions from its local N_R shard. + # Sum across all N_R shards to get the full lDDT numerator and denominator. + # This transforms partial sums from (S(0), S(1), partial_N_R) to (S(0), S(1), R). + group_col = device_mesh.get_group(2) # cp_axis_1 + dist.all_reduce(out_num_local, op=dist.ReduceOp.SUM, group=group_col) + dist.all_reduce(out_denom_local, op=dist.ReduceOp.SUM, group=group_col) + + # All-reduce mask_no_match with logical OR (any rank having valid pairs means token has matches) + dist.all_reduce(mask_no_match_local, op=dist.ReduceOp.MAX, group=group_col) + + # === Compute combined_mask = token_resolved_mask * mask_no_match === + # mask_row_local is already token_resolved_mask (computed above via einsum) + # Both masks don't require gradients + combined_mask_local = mask_row_local * mask_no_match_local + + # === Normalize to get final lDDT scores === + # Preserve input coordinate dtype (e.g., float64 for precision) + norm = 1.0 / (eps + out_denom_local) + target_lddt_local = norm * (eps + out_num_local) + + # === Wrap outputs as DTensors === + output_shape = (B_mult_global, N_token_global) + output_stride = LayoutRightMap(output_shape).strides + + target_lddt = DTensor.from_local( + target_lddt_local, + device_mesh, + input_placements, + shape=torch.Size(output_shape), + stride=output_stride, + ) + combined_mask = DTensor.from_local( + combined_mask_local, + device_mesh, + input_placements, + shape=torch.Size(output_shape), + stride=output_stride, + ) + + return target_lddt, combined_mask + + +class _PLDDTLossImpl(torch.autograd.Function): + """Fused pLDDT loss computation with gradient flow to pred_lddt. + + This fuses the entire pLDDT loss computation into a single autograd Function: + 1. Cross-entropy errors: errors = -sum(one_hot * log_softmax(pred_lddt), dim=-1) + 2. Masked errors: errors * combined_mask + 3. Sum over token dim (all_reduce over CP axis) + 4. Normalize: numerator / clamp(denominator, min=eps) + 5. Sum over batch dim (all_reduce over DP axis) + 6. Mean: loss_sum / batch_size + + WHY THIS WORKS: + --------------- + dist.all_reduce is an in-place op that is invisible to PyTorch autograd. + Autograd records the computation graph as if all_reduce never happened, + but the actual tensor values ARE the all_reduced values. + + This is correct because all_reduce(SUM) has IDENTITY GRADIENT: + Forward: y = sum_over_ranks(x_i), all ranks get the same y + Backward: ∂L/∂x_i = ∂L/∂y * ∂y/∂x_i = ∂L/∂y * 1 = ∂L/∂y + + So autograd "accidentally" computes the correct gradient by ignoring + all_reduce, since passing the gradient through unchanged is exactly + what all_reduce(SUM) backward should do. + + IMPORTANT: This ONLY works for ReduceOp.SUM. Other ops (MEAN, MAX, etc.) + have non-identity gradients and would produce wrong results. However, + this is guaranteed by the pLDDT loss semantics: we're computing + loss = mean_b(sum_t(errors * mask) / sum_t(mask)) + which requires SUM reduction to accumulate partial sums across ranks. + + Gradient flows back to pred_lddt through the log_softmax operation. + target_lddt and combined_mask are non-differentiable. + + See Also + -------- + plddt_loss : The DTensor wrapper API that calls this function. + lddt_resolved_token : Computes target_lddt and combined_mask. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx: FunctionCtx, + pred_lddt: DTensor, + target_lddt: DTensor, + combined_mask: DTensor, + ) -> DTensor: + """Forward pass for fused pLDDT loss computation. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context object for saving tensors for backward. + pred_lddt : DTensor + Predicted lDDT logits with shape (B * multiplicity, N_token, num_bins). + Placements: (Shard(0), Shard(1), Replicate()). + This is the only tensor that requires gradients. + target_lddt : DTensor + Target lDDT scores with shape (B * multiplicity, N_token), values in [0, 1]. + Placements: (Shard(0), Shard(1), Replicate()). + No gradients (computed from lddt_resolved_token with step functions). + combined_mask : DTensor + Combined mask (token_resolved_mask * mask_no_match) with shape (B * multiplicity, N_token). + Placements: (Shard(0), Shard(1), Replicate()). + No gradients. + + Returns + ------- + DTensor + Scalar loss with placements (Replicate(), Replicate(), Replicate()). + """ + # === Validate input dimensions === + if pred_lddt.ndim != 3: + raise ValueError(f"pred_lddt must be 3D (B*mult, N_token, num_bins), got {pred_lddt.ndim}D") + if target_lddt.ndim != 2: + raise ValueError(f"target_lddt must be 2D (B*mult, N_token), got {target_lddt.ndim}D") + if combined_mask.ndim != 2: + raise ValueError(f"combined_mask must be 2D (B*mult, N_token), got {combined_mask.ndim}D") + + # === Validate device_mesh consistency === + device_mesh = pred_lddt.device_mesh + for name, dtensor in [ + ("target_lddt", target_lddt), + ("combined_mask", combined_mask), + ]: + if dtensor.device_mesh != device_mesh: + raise ValueError(f"{name} has different device_mesh than pred_lddt") + + # === Validate placements === + # All inputs should have placements (S(0), S(1), R) for 3D or (S(0), S(1), R) conceptually for 2D + pred_placements = pred_lddt.placements + expected_pred_placements = (Shard(0), Shard(1), Replicate()) + if pred_placements != expected_pred_placements: + raise ValueError(f"pred_lddt placements {pred_placements} must be {expected_pred_placements}") + + # For 2D tensors, placements should match first two dimensions + expected_2d_placements = (Shard(0), Shard(1), Replicate()) + for name, dtensor in [ + ("target_lddt", target_lddt), + ("combined_mask", combined_mask), + ]: + if dtensor.placements != expected_2d_placements: + raise ValueError(f"{name} placements {dtensor.placements} must be {expected_2d_placements}") + + # === Validate shape consistency === + B_mult_global = pred_lddt.shape[0] + N_token_global = pred_lddt.shape[1] + + if target_lddt.shape != (B_mult_global, N_token_global): + raise ValueError(f"target_lddt shape {target_lddt.shape} != expected ({B_mult_global}, {N_token_global})") + if combined_mask.shape != (B_mult_global, N_token_global): + raise ValueError( + f"combined_mask shape {combined_mask.shape} != expected ({B_mult_global}, {N_token_global})" + ) + + # === Get process groups for all_reduce === + group_cp = device_mesh.get_group(1) # CP axis (token dimension) + group_dp = device_mesh.get_group(0) # DP axis (batch dimension) + + # === Get local tensors === + pred_lddt_local = pred_lddt.to_local().detach().requires_grad_(pred_lddt.requires_grad) + target_lddt_local = target_lddt.to_local().detach() # No gradient + combined_mask_local = combined_mask.to_local().detach() # No gradient + + # Compute bin indices from target_lddt (no gradient flow through this) + num_bins = pred_lddt_local.shape[-1] + bin_index = torch.floor(target_lddt_local * num_bins).long() + bin_index = torch.clamp(bin_index, max=(num_bins - 1)) + + # One-hot encode target bins (no gradient) + lddt_one_hot = F.one_hot(bin_index, num_classes=num_bins).to(pred_lddt_local.dtype) + + # === Fused subgraph: errors + mask + sum + normalize + sum + mean === + with torch.enable_grad(): + # Compute cross-entropy errors (gradient flows through log_softmax) + log_probs = F.log_softmax(pred_lddt_local, dim=-1) + errors_local = -torch.sum(lddt_one_hot * log_probs, dim=-1) # [local_B*mult, local_N_token] + + # Apply combined_mask + masked_errors_local = errors_local * combined_mask_local # [local_B*mult, local_N_token] + + # Sum over token dimension (local sum first) + numerator_local = masked_errors_local.sum(dim=-1) # [local_B*mult] + denominator_local = combined_mask_local.sum(dim=-1) # [local_B*mult] + + # All-reduce over CP axis (token dimension was sharded) + # Clone to protect against potential upstream saved_for_backward + # (though sum() output is not typically saved by its upstream, clone is safe and cheap) + numerator = numerator_local.clone() + denominator = denominator_local.clone() + with torch.no_grad(): + dist.all_reduce(numerator, op=dist.ReduceOp.SUM, group=group_cp) + dist.all_reduce(denominator, op=dist.ReduceOp.SUM, group=group_cp) + + # Normalize: numerator / clamp(denominator, min=eps) + eps = 1e-7 + denominator_safe = torch.clamp(denominator, min=eps) + per_sample_loss_local = numerator / denominator_safe # [local_B*mult] + + # Sum over batch dimension (local sum first) + loss_sum_local = per_sample_loss_local.sum() # scalar + + # All-reduce over DP axis (batch dimension was sharded) + # Clone needed: sum() may save input for backward + loss_sum = loss_sum_local.clone() + with torch.no_grad(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM, group=group_dp) + + # Mean over global batch size + loss_local = loss_sum / B_mult_global + + # === Save for backward === + ctx.save_for_backward(pred_lddt_local, loss_local) + ctx.device_mesh = device_mesh + ctx.pred_placements = pred_placements + ctx.pred_lddt_shape = pred_lddt.shape + ctx.pred_lddt_stride = pred_lddt.stride() + + # === Wrap output as DTensor === + # Output is a scalar, fully replicated across all mesh dimensions + output_placements = (Replicate(), Replicate(), Replicate()) + + loss = DTensor.from_local( + loss_local.detach(), + device_mesh=device_mesh, + placements=output_placements, + shape=torch.Size(()), + stride=(), + ) + + return loss + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx: FunctionCtx, + grad_loss: DTensor, + ) -> tuple[DTensor | None, None, None]: + """Backward pass using subgraph trick via autograd.grad. + + The forward pass fuses the entire pLDDT loss computation into one subgraph. + Autograd ignores all_reduce (which has identity gradient for SUM), + so we can backprop through the entire fused computation in one call. + + Parameters + ---------- + ctx : FunctionCtx + Context containing saved tensors and metadata. + grad_loss : DTensor + Gradient of upstream loss w.r.t. this loss output (scalar). + + Returns + ------- + tuple + (grad_pred_lddt, None, None) - only pred_lddt has gradients. + """ + pred_lddt_local, loss_local = ctx.saved_tensors + + if not pred_lddt_local.requires_grad: + return None, None, None + + # === Validate grad_loss === + if grad_loss.shape != torch.Size(()): + raise ValueError(f"grad_loss must be scalar, got shape {grad_loss.shape}") + + expected_grad_placements = (Replicate(), Replicate(), Replicate()) + if grad_loss.placements != expected_grad_placements: + raise ValueError(f"grad_loss placements {grad_loss.placements} must be {expected_grad_placements}") + + if grad_loss.device_mesh != ctx.device_mesh: + raise ValueError("grad_loss has different device_mesh than forward inputs") + + # Get local gradient (scalar, replicated) + grad_loss_local = grad_loss.to_local() + + # Backprop through entire fused subgraph: loss -> pred_lddt + # Autograd treats all_reduce as invisible, which is correct since + # all_reduce(SUM) has identity gradient (see forward pass comment). + (grad_pred_lddt_local,) = torch.autograd.grad( + outputs=[loss_local], + inputs=[pred_lddt_local], + grad_outputs=[grad_loss_local], + retain_graph=False, + ) + + # Wrap as DTensor + grad_pred_lddt = DTensor.from_local( + grad_pred_lddt_local, + device_mesh=ctx.device_mesh, + placements=ctx.pred_placements, + shape=ctx.pred_lddt_shape, + stride=ctx.pred_lddt_stride, + ) + + return grad_pred_lddt, None, None + + +def plddt_loss( + pred_lddt: DTensor, + pred_atom_coords: DTensor, + true_atom_coords: DTensor, + true_coords_resolved_mask: DTensor, + feats: dict[str, DTensor], + comm: TransposeComm, + multiplicity: int = 1, + cutoff: float = 15.0, +) -> DTensor: + """Compute the pLDDT loss using DTensor. + + This is the DTensor version of boltz.model.loss.confidencev2.plddt_loss. + It computes the cross-entropy loss between predicted lDDT bins and target + lDDT scores computed from coordinates. + + The entire computation is fused into _PLDDTLossImpl using native PyTorch ops: + 1. lddt_resolved_token: Compute per-token target lDDT scores (no gradient) + 2. _PLDDTLossImpl fuses: + - Cross-entropy errors: -sum(one_hot * log_softmax(pred_lddt), dim=-1) + - Masked errors: errors * combined_mask + - Sum over token dim + all_reduce(SUM) across CP + - Normalize: numerator / clamp(denominator, min=eps) + - Sum over batch dim + all_reduce(SUM) across DP + - Mean: loss_sum / batch_size + + All all_reduce ops use SUM which has identity gradient, allowing the entire + computation to be captured in a single autograd subgraph. + + Parameters + ---------- + pred_lddt : DTensor + Predicted lDDT logits with shape (B * multiplicity, N_token, num_bins). + Placements: (Shard(0), Shard(1), Replicate()). + pred_atom_coords : DTensor + Predicted atom coordinates with shape (B * multiplicity, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()). + true_atom_coords : DTensor + Ground truth atom coordinates with shape (B * multiplicity, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()). + true_coords_resolved_mask : DTensor + Mask for resolved coordinates with shape (B * multiplicity, N_atom). + Placements: (Shard(0), Shard(1), Replicate()). + feats : dict[str, DTensor] + Dictionary containing feature tensors: + - "token_to_rep_atom": [B, N_token, N_atom_max_per_shard] + - "r_set_to_rep_atom": [B, N_R, N_atom_max_per_shard] + - "atom_to_token": [B, N_atom, N_token_max_per_shard] + - "mol_type": [B, N_token] + comm : TransposeComm + Communication object for redistribute_transpose operations. + multiplicity : int, optional + Diffusion batch multiplier, by default 1. + cutoff : float, optional + Base cutoff distance for lDDT computation, by default 15.0. + + Returns + ------- + DTensor + Scalar loss with placements (Replicate(), Replicate(), Replicate()). + """ + # Compute target lDDT and combined_mask (no gradients, uses step functions) + # combined_mask = token_resolved_mask * mask_no_match (computed inside lddt_resolved_token) + target_lddt, combined_mask = lddt_resolved_token( + pred_atom_coords=pred_atom_coords, + true_atom_coords=true_atom_coords, + true_coords_resolved_mask=true_coords_resolved_mask, + feats=feats, + comm=comm, + multiplicity=multiplicity, + cutoff=cutoff, + ) + + # _PLDDTLossImpl fuses the entire loss computation using native PyTorch ops: + # errors -> mask -> sum(CP) -> normalize -> sum(DP) -> mean + loss = _PLDDTLossImpl.apply(pred_lddt, target_lddt, combined_mask) + + return loss + + +class _PDELossImpl(torch.autograd.Function): + """Shardwise computation of PDE loss with gradient flow to pred_pde. + + This computes the PDE cross-entropy loss per token row, with all-reduce across + the column dimension (cp_axis_1). The row-wise aggregation is done by sharded_sum + in the wrapper API pde_loss, which has proper autograd support. + + The computation uses cdist_pde which fuses: + - Distance computation: true_d = cdist(true_coords_row, true_coords_col) + - Distance computation: pred_d = cdist(pred_coords_row, pred_coords_col) + - Target PDE: target_pde = abs(true_d - pred_d) + - Binning: bin_index = clamp(floor(target_pde * num_bins / max_dist), max=num_bins-1) + - Cross-entropy: errors = -sum(one_hot(bin_index) * log_softmax(pred_pde), dim=-1) + - Masked sum along column: out_loss_num = sum(errors * mask, dim=-1) + + Gradient flows back to pred_pde through the log_softmax operation in cdist_pde. + Coordinates and masks are non-differentiable (no gradient flow). + + See Also + -------- + pde_loss : The DTensor wrapper API that calls this function and does sharded_sum. + cdist_pde : The fused Triton kernel for PDE cross-entropy computation. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx: FunctionCtx, + pred_pde: DTensor, + pred_atom_coords: DTensor, + true_atom_coords: DTensor, + true_coords_resolved_mask: DTensor, + token_to_rep_atom: DTensor, + comm: TransposeComm, + multiplicity: int, + max_dist: float, + ) -> tuple[DTensor, DTensor]: + """Forward pass for PDE loss computation with transpose and all-reduce. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context object for saving tensors for backward. + pred_pde : DTensor + Predicted PDE logits with shape (B * multiplicity, N_token, N_token, num_bins). + Placements: (Shard(0), Shard(1), Shard(2)). + This is the only tensor that requires gradients. + pred_atom_coords : DTensor + Predicted atom coordinates with shape (B * multiplicity, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()). + true_atom_coords : DTensor + Ground truth atom coordinates with shape (B * multiplicity, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()). + true_coords_resolved_mask : DTensor + Resolved mask with shape (B * multiplicity, N_atom). + Placements: (Shard(0), Shard(1), Replicate()). + token_to_rep_atom : DTensor + Token to representative atom mapping with shape (B, N_token, N_atom). + Placements: (Shard(0), Shard(1), Replicate()). + comm : TransposeComm + Communication object for redistribute_transpose operations. + multiplicity : int + Diffusion batch multiplier. + max_dist : float + Maximum distance for binning. + + Returns + ------- + out_loss_num : DTensor + Partial sum of cross-entropy loss per row, shape [B*mult, N_token_row]. + Placements: (Shard(0), Shard(1), Replicate()). + out_mask_denom : DTensor + Partial sum of mask per row, shape [B*mult, N_token_row]. + Placements: (Shard(0), Shard(1), Replicate()). + """ + # === Validate input dimensions === + if pred_pde.ndim != 4: + raise ValueError(f"pred_pde must be 4D (B*mult, N_token, N_token, num_bins), got {pred_pde.ndim}D") + if pred_atom_coords.ndim != 3: + raise ValueError(f"pred_atom_coords must be 3D (B*mult, N_atom, 3), got {pred_atom_coords.ndim}D") + if true_atom_coords.ndim != 3: + raise ValueError(f"true_atom_coords must be 3D (B*mult, N_atom, 3), got {true_atom_coords.ndim}D") + if true_coords_resolved_mask.ndim != 2: + raise ValueError( + f"true_coords_resolved_mask must be 2D (B*mult, N_atom), got {true_coords_resolved_mask.ndim}D" + ) + if token_to_rep_atom.ndim != 3: + raise ValueError( + f"token_to_rep_atom must be 3D (B, N_token, N_atom_max_per_shard), got {token_to_rep_atom.ndim}D" + ) + + # === Validate device_mesh consistency === + device_mesh = pred_pde.device_mesh + for name, dtensor in [ + ("pred_atom_coords", pred_atom_coords), + ("true_atom_coords", true_atom_coords), + ("true_coords_resolved_mask", true_coords_resolved_mask), + ("token_to_rep_atom", token_to_rep_atom), + ]: + if dtensor.device_mesh != device_mesh: + raise ValueError(f"{name} has different device_mesh than pred_pde") + + # === Validate placements === + # pred_pde is pair representation: (S(0), S(1), S(2)) - sharded on batch and both token axes + pred_pde_placements = pred_pde.placements + expected_pred_pde_placements = (Shard(0), Shard(1), Shard(2)) + if pred_pde_placements != expected_pred_pde_placements: + raise ValueError(f"pred_pde placements {pred_pde_placements} must be {expected_pred_pde_placements}") + + # Other inputs are single representation: (S(0), S(1), R) - sharded on batch and one axis + input_placements = pred_atom_coords.placements + expected_input_placements = (Shard(0), Shard(1), Replicate()) + for name, dtensor in [ + ("pred_atom_coords", pred_atom_coords), + ("true_atom_coords", true_atom_coords), + ("true_coords_resolved_mask", true_coords_resolved_mask), + ("token_to_rep_atom", token_to_rep_atom), + ]: + if dtensor.placements != expected_input_placements: + raise ValueError(f"{name} placements {dtensor.placements} must be {expected_input_placements}") + + # === Validate shape consistency === + B_mult_global = pred_pde.shape[0] + N_token_global = pred_pde.shape[1] + N_atom_global = pred_atom_coords.shape[1] + B_global = token_to_rep_atom.shape[0] + + # pred_pde must be square in token dimensions + if pred_pde.shape[1] != pred_pde.shape[2]: + raise ValueError(f"pred_pde token dimensions must be equal, got shape {pred_pde.shape}") + + # Validate multiplicity consistency + if B_mult_global != B_global * multiplicity: + raise ValueError(f"pred_pde batch dim ({B_mult_global}) != B ({B_global}) * multiplicity ({multiplicity})") + + # Validate coordinate shapes match + if true_atom_coords.shape != pred_atom_coords.shape: + raise ValueError( + f"true_atom_coords shape {true_atom_coords.shape} != pred_atom_coords shape {pred_atom_coords.shape}" + ) + + # Validate resolved_mask shape + if true_coords_resolved_mask.shape != (B_mult_global, N_atom_global): + raise ValueError( + f"true_coords_resolved_mask shape {true_coords_resolved_mask.shape} != expected ({B_mult_global}, {N_atom_global})" + ) + + # Validate token_to_rep_atom token dimension matches pred_pde + if token_to_rep_atom.shape[1] != N_token_global: + raise ValueError( + f"token_to_rep_atom N_token ({token_to_rep_atom.shape[1]}) != pred_pde N_token ({N_token_global})" + ) + + # Get local tensors + pred_pde_local = pred_pde.to_local().detach().requires_grad_(pred_pde.requires_grad) + pred_coords_local = pred_atom_coords.to_local().detach() # [local_B*mult, local_N_atom, 3] + true_coords_local = true_atom_coords.to_local().detach() # [local_B*mult, local_N_atom, 3] + token_to_rep_local = ( + token_to_rep_atom.to_local().detach() + ) # [local_B, local_N_token, local_N_atom_max_per_shard] + resolved_mask_local = true_coords_resolved_mask.to_local().detach() # [local_B*mult, local_N_atom] + + # Validate local atom dimension consistency (co-sharding requirement) + # token_to_rep_local.shape[2] is N_atom_max_per_shard (Replicate dim -> local) + # pred_coords_local.shape[1] is local N_atom (Shard(1) dim -> local) + # These must match for the einsum to work correctly + local_N_atom = pred_coords_local.shape[1] + N_atom_max_per_shard = token_to_rep_local.shape[2] + if N_atom_max_per_shard != local_N_atom: + raise ValueError( + f"Co-sharding violation: token_to_rep_atom local atom dim ({N_atom_max_per_shard}) " + f"!= pred_atom_coords local atom dim ({local_N_atom}). " + "These must match due to diagonal block co-sharding semantics." + ) + + # Promote to at least float32 to match serial get_target_pde which uses + # .float() inside autocast(enabled=False). + coord_dtype = torch.promote_types(pred_coords_local.dtype, torch.float32) + pred_coords_local = pred_coords_local.to(dtype=coord_dtype) + true_coords_local = true_coords_local.to(dtype=coord_dtype) + token_to_rep_local = token_to_rep_local.to(dtype=coord_dtype) + + # Get local batch size for einsum reshaping + local_B = token_to_rep_local.shape[0] + local_N_token = token_to_rep_local.shape[1] + + # === Project atom coords to token space using einsum with multiplicity broadcasting === + # Co-sharding assumption: token_to_rep_local operates on local atoms only (diagonal block). + pred_coords_reshaped = pred_coords_local.view(local_B, multiplicity, -1, 3) + true_coords_reshaped = true_coords_local.view(local_B, multiplicity, -1, 3) + + # token_to_rep_local: [local_B, local_N_token, local_N_atom] -> "bta" + # coords reshaped: [local_B, mult, local_N_atom, 3] -> "bmac" + # Output: [local_B, mult, local_N_token, 3] -> "bmtc" + pred_token_coords_row_local = torch.einsum("bta,bmac->bmtc", token_to_rep_local, pred_coords_reshaped).reshape( + -1, local_N_token, 3 + ) # [local_B*mult, local_N_token, 3] + true_token_coords_row_local = torch.einsum("bta,bmac->bmtc", token_to_rep_local, true_coords_reshaped).reshape( + -1, local_N_token, 3 + ) # [local_B*mult, local_N_token, 3] + + # === Compute factorized mask using einsum with multiplicity broadcasting === + resolved_mask_reshaped = resolved_mask_local.view(local_B, multiplicity, -1).to( + dtype=coord_dtype + ) # [local_B, mult, local_N_atom] + mask_row_local = torch.einsum("bta,bma->bmt", token_to_rep_local, resolved_mask_reshaped).reshape( + -1, local_N_token + ) # [local_B*mult, local_N_token] + + # === Get global shapes for DTensor metadata === + B_mult_global = pred_pde.shape[0] + N_token_global = pred_pde.shape[1] # Row dimension + num_bins = pred_pde.shape[-1] + + # === Derive target placements for transpose === + # From (S(0), S(1), R) to (S(0), R, S(1)) + target_placements = (input_placements[0], input_placements[2], input_placements[1]) + + # === Create DTensors for column tensors that need transpose === + coords_3d_shape = (B_mult_global, N_token_global, 3) + coords_3d_stride = LayoutRightMap(coords_3d_shape).strides + + pred_token_coords_col_dtensor = DTensor.from_local( + pred_token_coords_row_local.clone(), # Clone since we need separate row/col + device_mesh, + input_placements, + shape=torch.Size(coords_3d_shape), + stride=coords_3d_stride, + ) + true_token_coords_col_dtensor = DTensor.from_local( + true_token_coords_row_local.clone(), + device_mesh, + input_placements, + shape=torch.Size(coords_3d_shape), + stride=coords_3d_stride, + ) + + # Create DTensor for mask_col: [B*mult, N_token] + mask_2d_shape = (B_mult_global, N_token_global) + mask_2d_stride = LayoutRightMap(mask_2d_shape).strides + + mask_col_dtensor = DTensor.from_local( + mask_row_local.clone(), + device_mesh, + input_placements, + shape=torch.Size(mask_2d_shape), + stride=mask_2d_stride, + ) + + # === redistribute_transpose for column tensors === + pred_token_coords_col_t = redistribute_transpose( + pred_token_coords_col_dtensor, comm, target_placements, dim0=None, dim1=None + ) + true_token_coords_col_t = redistribute_transpose( + true_token_coords_col_dtensor, comm, target_placements, dim0=None, dim1=None + ) + mask_col_t = redistribute_transpose(mask_col_dtensor, comm, target_placements, dim0=None, dim1=None) + + # === Fused subgraph: cdist_pde + all_reduce (CP) + normalize + all_reduce (DP) + mean === + # + # WHY THIS WORKS: + # --------------- + # dist.all_reduce is an in-place op that is invisible to PyTorch autograd. + # Autograd records the computation graph as if all_reduce never happened, + # but the actual tensor values ARE the all_reduced values. + # + # This is correct because all_reduce(SUM) has IDENTITY GRADIENT: + # Forward: y = sum_over_ranks(x_i), all ranks get the same y + # Backward: ∂L/∂x_i = ∂L/∂y * ∂y/∂x_i = ∂L/∂y * 1 = ∂L/∂y + # + # So autograd "accidentally" computes the correct gradient by ignoring + # all_reduce, since passing the gradient through unchanged is exactly + # what all_reduce(SUM) backward should do. + # + # IMPORTANT: This ONLY works for ReduceOp.SUM. Other ops (MEAN, MAX, etc.) + # have non-identity gradients and would produce wrong results. However, + # this is guaranteed by the PDE loss semantics: we're computing + # loss = mean_b(sum_{i,j}(errors * mask) / sum_{i,j}(mask)) + # which requires SUM reduction to accumulate partial sums across ranks. + # + # Get process group for DP all_reduce (batch dimension) + group_dp = device_mesh.get_group(0) + + with torch.enable_grad(): + # Kernel returns fully summed outputs [B_mul] (sum over both row and col axes) + out_loss_num_local, out_mask_denom_local = cdist_pde( + pred_pde=pred_pde_local, + true_coords_row=true_token_coords_row_local, + true_coords_col=true_token_coords_col_t.to_local(), + pred_coords_row=pred_token_coords_row_local, + pred_coords_col=pred_token_coords_col_t.to_local(), + mask_row=mask_row_local, + mask_col=mask_col_t.to_local(), + multiplicity=multiplicity, + num_bins=num_bins, + max_dist=max_dist, + ) + + # All-reduce across full CP group for numerator and denominator. + # Shape: [B_mul_local] where B_mul_local = (B * multiplicity) / dp_size. + # Typically small (e.g., 1-4 elements), so clone cost is negligible. + # + # Clone prevents in-place all_reduce from modifying tensors that upstream + # ops (cdist_pde) might have saved for backward. This is not strictly + # necessary here because we know cdist_pde only saves its inputs (pred_pde, + # coords, masks), not its outputs (out_loss_num_local, out_mask_denom_local). + # We keep the clone for safety and clarity since the cost is negligible. + # + # torch.no_grad() is required to prevent autograd from recording all_reduce + # in the computation graph. Without it, autograd.grad() encounters the + # all_reduce node during backward and emits a warning because all_reduce + # has no registered autograd kernel. This is safe because all_reduce(SUM) + # has identity gradient (grad just passes through unchanged). + numerator = out_loss_num_local.clone() + denominator = out_mask_denom_local.clone() + with torch.no_grad(): + dist.all_reduce(numerator, op=dist.ReduceOp.SUM, group=comm.group) + dist.all_reduce(denominator, op=dist.ReduceOp.SUM, group=comm.group) + + # Elementwise ops: clamp and divide to get per-sample loss + eps = 1e-7 + denominator_safe = torch.clamp(denominator, min=eps) + per_sample_loss_local = numerator / denominator_safe # [B_mul_local] + + # Sum over local batch samples + loss_sum_local = per_sample_loss_local.sum() # scalar + + # All-reduce across DP group to get global sum. + # Shape: scalar (0-dim tensor), so clone cost is negligible. + # + # Clone needed: the .sum() op is a standard PyTorch operation that may + # internally save tensors for backward. Clone ensures in-place all_reduce + # won't corrupt any saved state. + # + # torch.no_grad() prevents autograd from recording all_reduce in the graph. + # Same reasoning as above: all_reduce(SUM) has identity gradient. + loss_sum = loss_sum_local.clone() + with torch.no_grad(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM, group=group_dp) + + # Mean over global batch size + loss_local = loss_sum / B_mult_global + + # === Save for backward === + ctx.save_for_backward(pred_pde_local, loss_local) + ctx.device_mesh = device_mesh + ctx.pred_pde_placements = pred_pde_placements + ctx.pred_pde_shape = pred_pde.shape + ctx.pred_pde_stride = pred_pde.stride() + + # === Wrap output as DTensor === + # Output is a scalar, fully replicated across all mesh dimensions + output_placements = (Replicate(), Replicate(), Replicate()) + + loss = DTensor.from_local( + loss_local.detach(), + device_mesh=device_mesh, + placements=output_placements, + shape=torch.Size(()), + stride=(), + ) + + return loss + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx: FunctionCtx, + grad_loss: DTensor, + ) -> tuple[DTensor | None, None, None, None, None, None, None, None]: + """Backward pass using subgraph trick via autograd.grad. + + The forward pass fuses the entire PDE loss computation into one subgraph: + cdist_pde -> all_reduce(CP) -> normalize -> sum -> all_reduce(DP) -> mean. + Autograd ignores all_reduce (which has identity gradient for SUM), + so we can backprop through the entire fused computation in one call. + + Parameters + ---------- + ctx : FunctionCtx + Context containing saved tensors and metadata. + grad_loss : DTensor + Gradient of upstream loss w.r.t. this loss output (scalar). + + Returns + ------- + tuple + (grad_pred_pde, None, None, None, None, None, None, None) + Only pred_pde has gradients. + """ + pred_pde_local, loss_local = ctx.saved_tensors + + if not pred_pde_local.requires_grad: + return None, None, None, None, None, None, None, None + + # === Validate grad_loss === + # grad_loss should be a scalar DTensor with fully replicated placements + if grad_loss.shape != torch.Size(()): + raise ValueError(f"grad_loss must be scalar, got shape {grad_loss.shape}") + + expected_grad_placements = (Replicate(), Replicate(), Replicate()) + if grad_loss.placements != expected_grad_placements: + raise ValueError(f"grad_loss placements {grad_loss.placements} must be {expected_grad_placements}") + + if grad_loss.device_mesh != ctx.device_mesh: + raise ValueError("grad_loss has different device_mesh than forward inputs") + + # Get local gradient (scalar, replicated) + grad_loss_local = grad_loss.to_local() + + # Backprop through entire fused subgraph: loss -> pred_pde + # Autograd treats all_reduce as invisible, which is correct since + # all_reduce(SUM) has identity gradient (see forward pass comment). + (grad_pred_pde_local,) = torch.autograd.grad( + outputs=[loss_local], + inputs=[pred_pde_local], + grad_outputs=[grad_loss_local], + retain_graph=False, + ) + + # Wrap as DTensor + grad_pred_pde = DTensor.from_local( + grad_pred_pde_local, + device_mesh=ctx.device_mesh, + placements=ctx.pred_pde_placements, + shape=ctx.pred_pde_shape, + stride=ctx.pred_pde_stride, + ) + + return grad_pred_pde, None, None, None, None, None, None, None + + +def pde_loss( + pred_pde: DTensor, + pred_atom_coords: DTensor, + true_atom_coords: DTensor, + true_coords_resolved_mask: DTensor, + feats: dict[str, DTensor], + comm: TransposeComm, + multiplicity: int = 1, + max_dist: float = 32.0, +) -> DTensor: + """Compute the PDE loss using DTensor. + + This is the DTensor version of boltz.model.loss.confidencev2.pde_loss. + It computes the cross-entropy loss between predicted PDE bins and target + PDE scores computed from pairwise coordinate distances. + + The entire computation is fused into _PDELossImpl using native PyTorch ops: + 1. Coordinate mapping via einsum (token_to_rep_atom @ coords) + 2. redistribute_transpose for column tensors + 3. cdist_pde Triton kernel (fused distance + binning + cross-entropy + masking + sum) + 4. all_reduce(SUM) across CP group for numerator/denominator + 5. Normalization: numerator / clamp(denominator, min=eps) + 6. Batch sum + all_reduce(SUM) across DP group + 7. Mean scaling by 1/batch_size + + All all_reduce ops use SUM which has identity gradient, allowing the entire + computation to be captured in a single autograd subgraph. + + Parameters + ---------- + pred_pde : DTensor + Predicted PDE logits with shape (B * multiplicity, N_token, N_token, num_bins). + Placements: (Shard(0), Shard(1), Shard(2)). + pred_atom_coords : DTensor + Predicted atom coordinates with shape (B * multiplicity, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()). + true_atom_coords : DTensor + Ground truth atom coordinates with shape (B * multiplicity, N_atom, 3). + Placements: (Shard(0), Shard(1), Replicate()). + true_coords_resolved_mask : DTensor + Mask for resolved coordinates with shape (B * multiplicity, N_atom). + Placements: (Shard(0), Shard(1), Replicate()). + feats : dict[str, DTensor] + Dictionary containing feature tensors: + - "token_to_rep_atom": [B, N_token, N_atom_max_per_shard] + comm : TransposeComm + Communication object for redistribute_transpose operations. + multiplicity : int, optional + Diffusion batch multiplier, by default 1. + max_dist : float, optional + Maximum distance for binning, by default 32.0. + + Returns + ------- + DTensor + Scalar loss with placements (Replicate(), Replicate(), Replicate()). + """ + # Extract token_to_rep_atom from features + token_to_rep_atom = feats["token_to_rep_atom"] + + # _PDELossImpl fuses the entire loss computation using native PyTorch ops: + # cdist_pde -> all_reduce(CP) -> normalize -> sum -> all_reduce(DP) -> mean + loss = _PDELossImpl.apply( + pred_pde, + pred_atom_coords, + true_atom_coords, + true_coords_resolved_mask, + token_to_rep_atom, + comm, + multiplicity, + max_dist, + ) + + return loss + + +def confidence_loss( + model_out: dict[str, DTensor], + feats: dict[str, DTensor], + true_coords: DTensor, + true_coords_resolved_mask: DTensor, + comm: TransposeComm, + token_level_confidence: bool = True, + multiplicity: int = 1, + alpha_pae: float = 0.0, + mask_loss: Optional[DTensor] = None, + relative_supervision_weight: float = 0.0, + dist_manager: Optional[DistributedManager] = None, + group_layout: Optional[LayoutMap] = None, +) -> dict[str, DTensor | dict[str, DTensor]]: + """Compute confidence loss using DTensor operations. + + This is the DTensor-compatible version of boltz.model.loss.confidencev2.confidence_loss. + It aggregates plddt, pde, resolved, and (optionally) pae losses. + + The sub-loss implementations (plddt_loss, pde_loss, resolved_loss) operate at + token level, matching the Boltz-2 ``token_level_confidence=True`` setting. + + Parameters + ---------- + model_out : dict[str, DTensor] + Dictionary containing the model output DTensors: + - "plddt_logits": Shape [B*mult, N_token, num_bins], Placements (Shard(0), Shard(1), Replicate()) + - "pde_logits": Shape [B*mult, N_token, N_token, num_bins], Placements (Shard(0), Shard(1), Shard(2)) + - "resolved_logits": Shape [B*mult, N_token, 2], Placements (Shard(0), Shard(1), Replicate()) + - "sample_atom_coords": Shape [B*mult, N_atom, 3], Placements (Shard(0), Shard(1), Replicate()) + - "pae_logits" (when alpha_pae > 0): Shape [B*mult, N_token, N_token, num_bins], + Placements (Shard(0), Shard(1), Shard(2)) + feats : dict[str, DTensor] + Dictionary containing the model input DTensors: + - "token_to_rep_atom": Shape [B, N_token, N_atom_max_per_shard], Placements (Shard(0), Shard(1), Replicate()) + - "token_pad_mask": Shape [B, N_token], Placements (Shard(0), Shard(1), Replicate()) + - "r_set_to_rep_atom": Shape [B, N_R, N_atom_max_per_shard], Placements (Shard(0), Shard(1), Replicate()) + - "atom_to_token": Shape [B, N_atom, N_token_max_per_shard], Placements (Shard(0), Shard(1), Replicate()) + - "mol_type": Shape [B, N_token], Placements (Shard(0), Shard(1), Replicate()) + true_coords : DTensor + The atom coordinates after symmetry correction. + Shape [B*mult, N_atom, 3], Placements (Shard(0), Shard(1), Replicate()) + true_coords_resolved_mask : DTensor + The resolved mask after symmetry correction. + Shape [B*mult, N_atom], Placements (Shard(0), Shard(1), Replicate()) + comm : TransposeComm + Communication object for redistribute_transpose operations. + token_level_confidence : bool, optional + Must be True (default). The atom-level path (False) is not implemented. + multiplicity : int, optional + The diffusion batch size, by default 1 + alpha_pae : float, optional + The weight of the pae loss, by default 0.0. + mask_loss : DTensor, optional + Per-sample loss mask. Not yet implemented; must be None. + relative_supervision_weight : float, optional + Weight for relative confidence supervision. Not yet implemented; must be 0.0. + dist_manager : DistributedManager, optional + Required when alpha_pae > 0.0 for pae_loss communication. + group_layout : LayoutMap, optional + Required when alpha_pae > 0.0 for pae_loss 2D CP grid layout. + + Returns + ------- + dict[str, DTensor | dict[str, DTensor]] + Dictionary containing: + - "loss": Scalar DTensor with total loss, Placements (Replicate(), Replicate(), Replicate()) + - "loss_breakdown": dict with individual loss DTensors: + - "plddt_loss": Scalar DTensor + - "pde_loss": Scalar DTensor + - "resolved_loss": Scalar DTensor + - "pae_loss": Scalar DTensor + + See Also + -------- + boltz.model.loss.confidencev2.confidence_loss : Serial version. + plddt_loss : DTensor pLDDT loss computation. + pde_loss : DTensor PDE loss computation. + resolved_loss : DTensor resolved loss computation. + pae_loss : DTensor PAE loss computation. + """ + if not token_level_confidence: + raise NotImplementedError( + "confidence_loss only supports token_level_confidence=True. " + "The atom-level confidence path is not implemented for DTensor." + ) + if mask_loss is not None: + raise NotImplementedError("confidence_loss does not yet support mask_loss. " "Pass mask_loss=None (default).") + if relative_supervision_weight != 0.0: + raise NotImplementedError( + "confidence_loss does not yet support relative_supervision_weight != 0.0. " + f"Got {relative_supervision_weight}." + ) + if alpha_pae > 0.0: + pae = pae_loss( + model_out["pae_logits"], + model_out["sample_atom_coords"], + true_coords, + true_coords_resolved_mask, + feats, + comm, + dist_manager, + group_layout, + multiplicity, + ) + else: + device_mesh = model_out["plddt_logits"].device_mesh + pae = DTensor.from_local( + torch.tensor(0.0, device=model_out["plddt_logits"].device), + device_mesh=device_mesh, + placements=(Replicate(), Replicate(), Replicate()), + shape=torch.Size(()), + stride=(), + ) + + # Compute plddt loss + plddt = plddt_loss( + pred_lddt=model_out["plddt_logits"], + pred_atom_coords=model_out["sample_atom_coords"], + true_atom_coords=true_coords, + true_coords_resolved_mask=true_coords_resolved_mask, + feats=feats, + comm=comm, + multiplicity=multiplicity, + ) + + # Compute pde loss + pde = pde_loss( + pred_pde=model_out["pde_logits"], + pred_atom_coords=model_out["sample_atom_coords"], + true_atom_coords=true_coords, + true_coords_resolved_mask=true_coords_resolved_mask, + feats=feats, + comm=comm, + multiplicity=multiplicity, + ) + + # Compute resolved loss + resolved = resolved_loss( + pred_resolved=model_out["resolved_logits"], + feats=feats, + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=multiplicity, + ) + + # Sum the losses: loss = plddt + pde + resolved + alpha_pae * pae + loss = elementwise_op(plddt, pde, ElementwiseOp.SUM) + loss = elementwise_op(loss, resolved, ElementwiseOp.SUM) + if alpha_pae > 0.0: + pae_scaled = scalar_tensor_op(alpha_pae, pae, ElementwiseOp.PROD) + loss = elementwise_op(loss, pae_scaled, ElementwiseOp.SUM) + + # Build output dictionary + dict_out = { + "loss": loss, + "loss_breakdown": { + "plddt_loss": plddt, + "pde_loss": pde, + "resolved_loss": resolved, + "pae_loss": pae, + }, + } + + return dict_out diff --git a/src/boltz/distributed/model/loss/diffusion.py b/src/boltz/distributed/model/loss/diffusion.py new file mode 100644 index 000000000..c66381e18 --- /dev/null +++ b/src/boltz/distributed/model/loss/diffusion.py @@ -0,0 +1,1063 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import warnings + +import torch +from einops import einsum +from torch.distributed.tensor import DTensor, Replicate, Shard + +try: + from boltz.distributed.model.loss.triton.smooth_lddt_loss import ( + grid_launch_config, + smooth_lddt_loss_bwd_kernel, + smooth_lddt_loss_fwd_kernel, + ) + + has_smooth_lddt_loss_triton_kernels = True +except ImportError: + has_smooth_lddt_loss_triton_kernels = False + + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.layers.clip import clip +from boltz.distributed.model.layers.elementwise_op import ( + ElementwiseOp, + elementwise_op, + scalar_tensor_op, + single_tensor_op, +) +from boltz.distributed.model.layers.outer_op import OuterOp, replicate_to_shard_outer_op +from boltz.distributed.model.layers.redistribute_transpose import redistribute_transpose +from boltz.distributed.model.layers.repeat_interleave import shardwise_repeat_interleave +from boltz.distributed.model.layers.replicate_op import ReplicateOp, replicate_op +from boltz.distributed.model.layers.sharded_op import sharded_sum +from boltz.distributed.model.layers.where import where +from boltz.distributed.utils import LayoutRightMap + + +def weighted_rigid_align( + true_coords: DTensor, + pred_coords: DTensor, + weights: DTensor, + mask: DTensor, +) -> DTensor: + """Compute weighted alignment and return the aligned true coordinates using DTensor. + + Implements the same algorithm as boltz.model.loss.diffusionv2.weighted_rigid_align + (Algorithm 28). Supports only 3D inputs (B, N, 3) for coords and 2D (B, N) for + weights/mask, with placements (Shard(0), Shard(1), Replicate()). + SVD is computed in float64 for numerical stability. + This function is NOT differentiable (uses torch.no_grad() internally). + + The computation is structured to ensure binary-identical results across + cp_axis_1 column groups by computing on column-0 then broadcasting. + SVD is computed on rank (:, 0, 0) then broadcast to avoid numerical + divergence from parallel SVD. + + Parameters + ---------- + true_coords : DTensor + Ground truth atom coordinates, shape (B, N, 3). + Placements: (Shard(0), Shard(1), Replicate()). + pred_coords : DTensor + Predicted atom coordinates, shape (B, N, 3). + Placements: (Shard(0), Shard(1), Replicate()). + weights : DTensor + Alignment weights, shape (B, N). + Placements: (Shard(0), Shard(1), Replicate()). + mask : DTensor + Atom mask, shape (B, N). + Placements: (Shard(0), Shard(1), Replicate()). + + Returns + ------- + DTensor + Aligned true coordinates with same placements as input true_coords. + + """ + # Ndim checks (3D coords, 2D weights/mask) + if true_coords.ndim != 3: + raise ValueError(f"true_coords must be 3D (B, N, 3), got ndim={true_coords.ndim}") + if pred_coords.ndim != 3: + raise ValueError(f"pred_coords must be 3D (B, N, 3), got ndim={pred_coords.ndim}") + if weights.ndim != 2: + raise ValueError(f"weights must be 2D (B, N), got ndim={weights.ndim}") + if mask.ndim != 2: + raise ValueError(f"mask must be 2D (B, N), got ndim={mask.ndim}") + + # Shape checks + if true_coords.shape != pred_coords.shape: + raise ValueError(f"true_coords shape {true_coords.shape} != pred_coords shape {pred_coords.shape}") + if weights.shape != mask.shape: + raise ValueError(f"weights shape {weights.shape} != mask shape {mask.shape}") + if weights.shape != true_coords.shape[:2]: + raise ValueError(f"weights shape {weights.shape} != expected {true_coords.shape[:2]}") + + # Device mesh checks + if true_coords.device_mesh != pred_coords.device_mesh: + raise ValueError("true_coords and pred_coords must be on the same device_mesh") + if true_coords.device_mesh != weights.device_mesh: + raise ValueError("true_coords and weights must be on the same device_mesh") + if true_coords.device_mesh != mask.device_mesh: + raise ValueError("true_coords and mask must be on the same device_mesh") + + # Placement checks + placements = (Shard(0), Shard(1), Replicate()) + if true_coords.placements != placements: + raise ValueError(f"true_coords placements {true_coords.placements} != expected {placements}") + if pred_coords.placements != placements: + raise ValueError(f"pred_coords placements {pred_coords.placements} != expected {placements}") + if weights.placements != placements: + raise ValueError(f"weights placements {weights.placements} != expected {placements}") + if mask.placements != placements: + raise ValueError(f"mask placements {mask.placements} != expected {placements}") + + with torch.no_grad(): + # Convert to local tensors + true_coords_local = true_coords.to_local() + pred_coords_local = pred_coords.to_local() + weights_local = weights.to_local() + mask_local = mask.to_local() + + device_mesh = true_coords.device_mesh + # mesh axis 1: reduce along atom dimension (cp_axis_0) + # mesh axis 2: broadcast along coordinate dimension (cp_axis_1, where coords are replicated) + group_reduce_atoms = device_mesh.get_group(1) + group_broadcast = device_mesh.get_group(2) + + rank_coord = device_mesh.get_coordinate() + assert rank_coord is not None + + batch_size, num_points, dim = true_coords_local.shape + weights_expanded = (mask_local * weights_local).unsqueeze(-1) + + # Scalar degenerate check (all ranks; compatible with per-batch check below) + total_num_points = num_points * device_mesh.shape[1] + if total_num_points < (dim + 1): + warnings.warn( + "The size of one of the point clouds is <= dim+1. " + "`WeightedRigidAlign` cannot return a unique rotation.", + UserWarning, + stacklevel=1, + ) + + is_first_column_group = rank_coord[-1] == 0 + rank_broadcast = torch.distributed.get_global_rank(group_broadcast, 0) + + if is_first_column_group: + # Per-batch degenerate check first (same as diffusionv2: mask.sum(dim=-1) < (dim + 1)) + mask_count_global = mask_local.sum(dim=1).clone() + torch.distributed.all_reduce( + mask_count_global, + op=torch.distributed.ReduceOp.SUM, + group=group_reduce_atoms, + ) + degenerate_batch_indices = torch.where(mask_count_global < (dim + 1))[0] + if degenerate_batch_indices.numel() > 0: + warnings.warn( + f"[rank_coord:{rank_coord}] " + "The size of one of the point clouds is <= dim+1. " + "`WeightedRigidAlign` cannot return a unique rotation. " + f"Batch indices (subset): {degenerate_batch_indices.tolist()}", + UserWarning, + stacklevel=1, + ) + + # Compute on column-0 then broadcast for binary-identical results + # Overlapped async reductions for centroids (dim=1 = points, equiv. dim=-2 in serial v2) + weights_sum_local = weights_expanded.sum(dim=1, keepdim=True) + req_reduce_weights = torch.distributed.all_reduce( + weights_sum_local, op=torch.distributed.ReduceOp.SUM, group=group_reduce_atoms, async_op=True + ) + + true_coords_weighted_sum_local = (true_coords_local * weights_expanded).sum( + dim=1, keepdim=True + ) # points dim + req_reduce_true_coords = torch.distributed.all_reduce( + true_coords_weighted_sum_local, + op=torch.distributed.ReduceOp.SUM, + group=group_reduce_atoms, + async_op=True, + ) + + pred_coords_weighted_sum_local = (pred_coords_local * weights_expanded).sum( + dim=1, keepdim=True + ) # points dim + req_reduce_pred_coords = torch.distributed.all_reduce( + pred_coords_weighted_sum_local, + op=torch.distributed.ReduceOp.SUM, + group=group_reduce_atoms, + async_op=True, + ) + + req_reduce_weights.wait() + req_reduce_true_coords.wait() + true_centroid = true_coords_weighted_sum_local / weights_sum_local + + req_reduce_pred_coords.wait() + pred_centroid = pred_coords_weighted_sum_local / weights_sum_local + torch.distributed.broadcast(true_centroid, rank_broadcast, group=group_broadcast) + torch.distributed.broadcast(pred_centroid, rank_broadcast, group=group_broadcast) + else: + true_centroid = torch.empty_like(true_coords_local[:, 0:1, :]) + pred_centroid = torch.empty_like(pred_coords_local[:, 0:1, :]) + torch.distributed.broadcast(true_centroid, rank_broadcast, group=group_broadcast) + torch.distributed.broadcast(pred_centroid, rank_broadcast, group=group_broadcast) + + # Center the coordinates + true_coords_centered = true_coords_local - true_centroid + pred_coords_centered = pred_coords_local - pred_centroid + + # Compute the weighted covariance matrix + cov_matrix_local = einsum( + weights_expanded * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j" + ) + original_dtype = cov_matrix_local.dtype + + if is_first_column_group: + # Reduce covariance matrix, compute SVD on rank (:, 0, 0), broadcast + rank_reduce_cov_matrix = torch.distributed.get_global_rank(group_reduce_atoms, 0) + torch.distributed.reduce( + cov_matrix_local, + op=torch.distributed.ReduceOp.SUM, + dst=rank_reduce_cov_matrix, + group=group_reduce_atoms, + ) + if rank_coord[1] == 0: + # SVD in float64 for numerical stability + cov_matrix_64 = cov_matrix_local.to(dtype=torch.float64) + U, S, V = torch.linalg.svd(cov_matrix_64, driver="gesvd" if cov_matrix_64.is_cuda else None) + V = V.mH + + # Same logic as diffusionv2: scalar num_points check (v2 uses num_points, not per-batch mask) + if (S.abs() <= 1e-15).any() and not (total_num_points < (dim + 1)): + warnings.warn( + f"[rank_coord:{rank_coord}] " + "Excessively low rank of " + "cross-correlation between aligned point clouds. " + "`WeightedRigidAlign` cannot return a unique rotation.", + UserWarning, + stacklevel=1, + ) + + # Rotation matrix with proper determinant + rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V) + F = torch.eye(dim, dtype=cov_matrix_64.dtype, device=cov_matrix_64.device)[None].repeat( + batch_size, 1, 1 + ) + F[:, -1, -1] = torch.det(rot_matrix) + rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l") + rot_matrix = rot_matrix.to(dtype=original_dtype).contiguous() + torch.distributed.broadcast(rot_matrix, rank_reduce_cov_matrix, group=group_reduce_atoms) + else: + rot_matrix = torch.empty( + (batch_size, dim, dim), dtype=original_dtype, device=true_coords_local.device + ).contiguous() + torch.distributed.broadcast(rot_matrix, rank_reduce_cov_matrix, group=group_reduce_atoms) + # Broadcast within each row + torch.distributed.broadcast(rot_matrix, rank_broadcast, group=group_broadcast) + else: + rot_matrix = torch.empty( + (batch_size, dim, dim), dtype=original_dtype, device=true_coords_local.device + ).contiguous() + torch.distributed.broadcast(rot_matrix, rank_broadcast, group=group_broadcast) + + # Apply rotation and translation + aligned_coords_local = einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j") + pred_centroid + + # Convert back to DTensor + aligned_coords = DTensor.from_local( + aligned_coords_local, + device_mesh=device_mesh, + placements=true_coords.placements, + shape=true_coords.shape, + stride=true_coords.stride(), + ) + + return aligned_coords + + +def smooth_lddt_loss( + pred_coords: DTensor, + true_coords: DTensor, + is_nucleotide: DTensor, + coords_mask: DTensor, + comm: TransposeComm, + nucleic_acid_cutoff: float = 30.0, + other_cutoff: float = 15.0, + multiplicity: int = 1, + v2: bool = True, +) -> DTensor: + """Compute the smooth LDDT loss using DTensor. + + NOTE There is potential memory optimization in diffusionv2.py in Boltz2. + + Parameters + ---------- + pred_coords: DTensor + The predicted atom coordinates with placements (Shard(0), Shard(1), Replicate()) + true_coords: DTensor + The ground truth atom coordinates with placements (Shard(0), Shard(1), Replicate()) + is_nucleotide: DTensor + The weights for alignment with placements (Shard(0), Shard(1), Replicate()) + coords_mask: DTensor + The atoms mask with placements (Shard(0), Shard(1), Replicate()) + comm: TransposeComm + The communication object + nucleic_acid_cutoff: float + The cutoff for nucleic acid + other_cutoff: float + The cutoff for other atoms + multiplicity: int + The multiplicity of the atoms + v2: bool + Whether to use the v2 version of the smooth LDDT loss, where the denominator is added with 1e-5. + + Returns + ------- + DTensor + The smooth LDDT loss with placement (Replicate(), Replicate(), Replicate()) + + """ + is_nucleotide = is_nucleotide.to(torch.bool) + + coords_mask_pairwise_section = redistribute_transpose( + coords_mask, + comm, + output_placements=(Shard(0), Replicate(), Shard(1)), + dim0=None, + dim1=None, + ) + true_dists = replicate_to_shard_outer_op(true_coords, OuterOp.CDIST, 1, comm) + dtype = true_dists.dtype + + is_nucleotide = shardwise_repeat_interleave( + is_nucleotide, multiplicity, 0 + ) # (batch_size * multiplicity, num_atoms) + coords_mask = shardwise_repeat_interleave(coords_mask, multiplicity, 0) + coords_mask_pairwise_section = shardwise_repeat_interleave(coords_mask_pairwise_section, multiplicity, 0) + + # broadcast is_nucleotide over the second cp axis + # serial code:is_nucleotide.unsqueeze(-1).expand(-1, -1, is_nucleotide.shape[-1]) + if is_nucleotide.placements != (Shard(0), Shard(1), Replicate()): + raise ValueError( + f"is_nucleotide placements {is_nucleotide.placements} != expected (Shard(0), Shard(1), Replicate())" + ) + + # [B, N] + is_nucleotide_local = is_nucleotide.to_local() + # [B, N, N] + is_nucleotide_pair_local = is_nucleotide_local.unsqueeze(-1).expand(-1, -1, is_nucleotide_local.shape[-1]) + shape_is_nucleotide_pair = (is_nucleotide.shape[0], is_nucleotide.shape[1], is_nucleotide.shape[1]) + # torch.Tensor.expand sets the expanded axes' stride to 0. See official doc: + # https://docs.pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch-tensor-expand + stride_is_nucleotide_pair = is_nucleotide.stride() + (0,) + is_nucleotide_pair = DTensor.from_local( + is_nucleotide_pair_local, + device_mesh=is_nucleotide.device_mesh, + placements=(Shard(0), Shard(1), Shard(2)), + shape=shape_is_nucleotide_pair, + stride=stride_is_nucleotide_pair, + ) + + mask = where( + is_nucleotide_pair, + scalar_tensor_op( + nucleic_acid_cutoff, + true_dists, + ElementwiseOp.GT, + ), + scalar_tensor_op( + other_cutoff, + true_dists, + ElementwiseOp.GT, + ), + ) + mask = mask.to(dtype=dtype) + + # Zero out the diagonal. If in CP mode, this means only diagonal ranks participate. + local_num_samples, local_num_atoms = pred_coords.to_local().shape[:2] + if comm.is_self_comm: + diag_mask_local = 1 - torch.eye(local_num_atoms, device=pred_coords.device) + else: + diag_mask_local = torch.ones(local_num_atoms, local_num_atoms, device=pred_coords.device) + diag_mask_local = diag_mask_local.unsqueeze(0).expand(local_num_samples, -1, -1) + shape_diag_mask = (pred_coords.shape[0], pred_coords.shape[1], pred_coords.shape[1]) + # diag_mask is created from scratch in LayoutRight. The expanded leading axis + # has stride 0 -- see official doc: + # https://docs.pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch-tensor-expand + stride_diag_mask = (0,) + LayoutRightMap(shape=shape_diag_mask[1:]).strides + diag_mask = DTensor.from_local( + diag_mask_local, + device_mesh=mask.device_mesh, + placements=mask.placements, + shape=shape_diag_mask, + stride=stride_diag_mask, + ) + mask = elementwise_op(mask, diag_mask, ElementwiseOp.PROD) + + # Apply coordinate mask + mask = replicate_op(mask, coords_mask, dim_to_unsqueeze_rhs=2, op=ReplicateOp.PROD) + mask = replicate_op(mask, coords_mask_pairwise_section, dim_to_unsqueeze_rhs=1, op=ReplicateOp.PROD) + + # Compute distances between all pairs of atoms + pred_dists = replicate_to_shard_outer_op(pred_coords, OuterOp.CDIST, 1, comm) + + dist_diff = single_tensor_op( + elementwise_op(true_dists, pred_dists, ElementwiseOp.SUB), + ElementwiseOp.ABS, + ) + # Compute epsilon values + eps = single_tensor_op(scalar_tensor_op(0.5, dist_diff, ElementwiseOp.SUB), ElementwiseOp.SIGMOID) + for cutoff in (1.0, 2.0, 4.0): + eps = elementwise_op( + eps, + single_tensor_op( + scalar_tensor_op(cutoff, dist_diff, ElementwiseOp.SUB), + ElementwiseOp.SIGMOID, + ), + ElementwiseOp.SUM, + ) + eps = scalar_tensor_op(0.25, eps, ElementwiseOp.PROD) + + assert mask.requires_grad is False + num = sharded_sum(elementwise_op(eps, mask, ElementwiseOp.PROD), dim=(-1, -2)) + den = sharded_sum(mask, dim=(-1, -2)) # mask have no gradient. thus no need to use torch.no_grad() + if v2: + den = scalar_tensor_op(1e-5, den, ElementwiseOp.SUM) + else: + den = clip(den, min_val=1.0, max_val=None) + + lddt = elementwise_op( + num, + den, + ElementwiseOp.DIV, + ) + lddt = scalar_tensor_op(1.0, lddt, ElementwiseOp.SUB) + + lddt = scalar_tensor_op( + 1 / lddt.shape[0], + sharded_sum(lddt, dim=0), + ElementwiseOp.PROD, + ) # mean along DP axis; placements: (Replicate(), Replicate(), Replicate()) + + return lddt + + +def _smooth_lddt_loss_forward_local( + pred_coords_local: torch.Tensor, + true_coords_local: torch.Tensor, + pred_coords_t_local: torch.Tensor, + true_coords_t_local: torch.Tensor, + is_nucleotide_local: torch.Tensor, + coords_mask_local: torch.Tensor, + coords_mask_t_local: torch.Tensor, + is_self_comm: bool, + nucleic_acid_cutoff: float, + other_cutoff: float, + multiplicity: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Local computation for smooth LDDT loss. + + This function computes the local numerator and denominator for the smooth LDDT loss + on shardwise tensors. + """ + #### Compute forward pass locally to get pair masks #### + dtype = true_coords_local.dtype + + # Expand is_nucleotide and masks locally + # (B, N) -> (B * multiplicity, N) + is_nucleotide_local = is_nucleotide_local.repeat_interleave(multiplicity, dim=0) + coords_mask_local = coords_mask_local.repeat_interleave(multiplicity, dim=0) + coords_mask_t_local = coords_mask_t_local.repeat_interleave(multiplicity, dim=0) + + # Construct pairwise nucleotide mask + num_cols = pred_coords_t_local.shape[1] + is_nucleotide_pair_local = is_nucleotide_local.unsqueeze(-1).expand(-1, -1, num_cols) # O(N^2) tensor + + # Compute true distances + true_dists = torch.cdist(true_coords_local, true_coords_t_local) # O(N^2) tensor + + # Compute mask based on cutoffs + mask = torch.where( + is_nucleotide_pair_local.bool(), + (true_dists < nucleic_acid_cutoff), + (true_dists < other_cutoff), + ).to(dtype=dtype) # O(N^2) tensor + + # Zero out diagonal + local_num_samples, local_num_atoms = pred_coords_local.shape[:2] + if is_self_comm: + diag_mask_local = 1 - torch.eye(local_num_atoms, num_cols, device=pred_coords_local.device) + else: + diag_mask_local = torch.ones(local_num_atoms, num_cols, device=pred_coords_local.device) + + diag_mask_local = diag_mask_local.unsqueeze(0).expand(local_num_samples, -1, -1) # O(N^2) tensor + mask = mask * diag_mask_local + + # Apply coordinate mask + mask = mask * coords_mask_local.unsqueeze(-1) + if coords_mask_t_local.ndim == 3: + mask = mask * coords_mask_t_local.transpose(1, 2) + else: + mask = mask * coords_mask_t_local.unsqueeze(1) + + #### Compute forward pass #### + # Compute predicted distances + pred_dists = torch.cdist(pred_coords_local, pred_coords_t_local) # O(N^2) tensor + dist_diff = (true_dists - pred_dists).abs() # O(N^2) tensor + + # Compute epsilon: O(N^2) tensors + eps = torch.sigmoid(0.5 - dist_diff) + for cutoff in (1.0, 2.0, 4.0): + eps = eps + torch.sigmoid(cutoff - dist_diff) + eps *= 0.25 + + # Compute numerators and denominators + # Sum over local atoms (rows and cols) + num_local = (eps * mask).sum(dim=(1, 2)) + den_local = mask.sum(dim=(1, 2)) + + return num_local, den_local + + +def _smooth_lddt_loss_backward_local( + grad_num_reduced, + grad_den_reduced, + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, +): + #### Recompute forward pass locally to get pair masks #### + dtype = true_coords_local.dtype + + # Expand is_nucleotide and masks locally + # (B, N) -> (B * multiplicity, N) + is_nucleotide_local = is_nucleotide_local.repeat_interleave(multiplicity, dim=0) + coords_mask_local = coords_mask_local.repeat_interleave(multiplicity, dim=0) + coords_mask_t_local = coords_mask_t_local.repeat_interleave(multiplicity, dim=0) + + # Construct pairwise nucleotide mask + num_cols = pred_coords_t_local.shape[1] + is_nucleotide_pair_local = is_nucleotide_local.unsqueeze(-1).expand(-1, -1, num_cols) # O(N^2) tensor + + # Compute true distances + true_dists = torch.cdist(true_coords_local, true_coords_t_local) # O(N^2) tensor + + # Compute mask based on cutoffs + mask = torch.where( + is_nucleotide_pair_local.bool(), + (true_dists < nucleic_acid_cutoff), + (true_dists < other_cutoff), + ).to(dtype=dtype) # O(N^2) tensor + + # Zero out diagonal + local_num_samples, local_num_atoms = pred_coords_local.shape[:2] + if is_self_comm: + diag_mask_local = 1 - torch.eye(local_num_atoms, num_cols, device=pred_coords_local.device) + else: + diag_mask_local = torch.ones(local_num_atoms, num_cols, device=pred_coords_local.device) + + diag_mask_local = diag_mask_local.unsqueeze(0).expand(local_num_samples, -1, -1) # O(N^2) tensor + mask = mask * diag_mask_local + + # Apply coordinate mask + mask = mask * coords_mask_local.unsqueeze(-1) + if coords_mask_t_local.ndim == 3: + mask = mask * coords_mask_t_local.transpose(1, 2) + else: + mask = mask * coords_mask_t_local.unsqueeze(1) + + #### Compute backward pass #### + # Compute pred diffs and dists + # diff_vec = P_row - P_col + diff_vec = pred_coords_local.unsqueeze(2) - pred_coords_t_local.unsqueeze(1) # O(N^2) tensor + pred_dists = diff_vec.norm(dim=-1) # O(N^2) tensor + dist_diff = (true_dists - pred_dists).abs() # O(N^2) tensor + + # Compute d_eps_d_diff: O(N^2) tensors + d_eps_d_diff = torch.zeros_like(dist_diff) + for cutoff in (0.5, 1.0, 2.0, 4.0): + val = cutoff - dist_diff + sig = torch.sigmoid(val) + d_eps_d_diff -= sig * (1 - sig) + d_eps_d_diff *= 0.25 + + # Compute d_L_d_pred_dists + # d(diff)/d(pred) = sign(pred - true) + grad_num_broadcast = grad_num_reduced.view_as(pred_dists[:, 0, 0]).view(-1, 1, 1) + # O(N^2) tensor + d_L_d_pred_dists = grad_num_broadcast * mask * d_eps_d_diff * torch.sign(pred_dists - true_dists) + + # Compute gradients w.r.t coords + # safe normalization + pred_dists_safe = pred_dists.unsqueeze(-1) + 1e-8 # O(N^2) tensor + diff_dir = diff_vec / pred_dists_safe # O(N^2) tensor + + d_L_d_diff_vec = d_L_d_pred_dists.unsqueeze(-1) * diff_dir # O(N^2) tensor + + grad_pred_local = d_L_d_diff_vec.sum(dim=2) + grad_pred_t_local = -d_L_d_diff_vec.sum(dim=1) + + return grad_pred_local, grad_pred_t_local + + +def _smooth_lddt_loss_local_triton_forward( + pred_coords_local: torch.Tensor, + true_coords_local: torch.Tensor, + pred_coords_t_local: torch.Tensor, + true_coords_t_local: torch.Tensor, + is_nucleotide_local: torch.Tensor, + coords_mask_local: torch.Tensor, + coords_mask_t_local: torch.Tensor, + is_self_comm: bool, + nucleic_acid_cutoff: float, + other_cutoff: float, + multiplicity: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Triton-based local computation for smooth LDDT loss forward pass.""" + if not has_smooth_lddt_loss_triton_kernels: + raise ImportError("Smooth LDDT loss Triton kernels are not available.") + + if pred_coords_local.dtype == torch.bfloat16 or pred_coords_local.dtype == torch.float16: + raise ValueError( + f"Triton kernel for smooth LDDT loss does not support {pred_coords_local.dtype} " + "due to precision issues. Please use float32." + ) + + # Handle coords_mask_t shape + if coords_mask_t_local.ndim == 3: + coords_mask_t_local = ( + coords_mask_t_local.squeeze(-1) if coords_mask_t_local.shape[-1] == 1 else coords_mask_t_local.squeeze(1) + ) + + assert pred_coords_local.shape == pred_coords_t_local.shape + assert pred_coords_local.shape == true_coords_local.shape + assert pred_coords_t_local.shape == true_coords_t_local.shape + + # different multiplicity input can share the same mask and is_nucleotide + # input so we only require B * M % B == 0 here, with pred_coords_local.shape[0] = B * M + # and is_nucleotide_local.shape[0] = B + assert pred_coords_local.shape[0] % is_nucleotide_local.shape[0] == 0 + assert pred_coords_local.shape[0] % coords_mask_local.shape[0] == 0 + assert pred_coords_local.shape[0] % coords_mask_t_local.shape[0] == 0 + + assert is_nucleotide_local.shape[1] == pred_coords_local.shape[1] + assert coords_mask_local.shape[1] == pred_coords_local.shape[1] + assert coords_mask_t_local.shape[1] == pred_coords_local.shape[1] + + # TODO Ensure inputs are contiguous where needed or handle strides + # Note: Triton handles strides, so we don't strictly need contiguous, but it helps performance + + B = pred_coords_local.shape[0] + + num_output = torch.zeros(B, device=pred_coords_local.device, dtype=pred_coords_local.dtype) + den_output = torch.zeros(B, device=pred_coords_local.device, dtype=pred_coords_local.dtype) + + # Cast inputs + is_nucleotide_local = is_nucleotide_local.to(dtype=torch.int8) # Use int8 for bool + + # Define grid lambda that autotune will use with the selected BLOCK size + + smooth_lddt_loss_fwd_kernel[grid_launch_config]( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + num_output, + den_output, + pred_coords_local.stride(0), + pred_coords_local.stride(1), + pred_coords_local.stride(2), + true_coords_local.stride(0), + true_coords_local.stride(1), + true_coords_local.stride(2), + pred_coords_t_local.stride(0), + pred_coords_t_local.stride(1), + pred_coords_t_local.stride(2), + true_coords_t_local.stride(0), + true_coords_t_local.stride(1), + true_coords_t_local.stride(2), + is_nucleotide_local.stride(0), + is_nucleotide_local.stride(1), + coords_mask_local.stride(0), + coords_mask_local.stride(1), + coords_mask_t_local.stride(0), + coords_mask_t_local.stride(1), + nucleic_acid_cutoff, + other_cutoff, + is_self_comm, + pred_coords_local.shape[0], + pred_coords_local.shape[1], + coords_mask_local.shape[0], + ) + + return num_output, den_output + + +def _smooth_lddt_loss_local_triton_backward( + grad_num_reduced: torch.Tensor, + grad_den_reduced: torch.Tensor, + pred_coords_local: torch.Tensor, + true_coords_local: torch.Tensor, + pred_coords_t_local: torch.Tensor, + true_coords_t_local: torch.Tensor, + is_nucleotide_local: torch.Tensor, + coords_mask_local: torch.Tensor, + coords_mask_t_local: torch.Tensor, + is_self_comm: bool, + nucleic_acid_cutoff: float, + other_cutoff: float, + multiplicity: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Triton-based local computation for smooth LDDT loss backward pass.""" + if not has_smooth_lddt_loss_triton_kernels: + raise ImportError("Smooth LDDT loss Triton kernels are not available.") + + if pred_coords_local.dtype == torch.bfloat16 or pred_coords_local.dtype == torch.float16: + raise ValueError( + f"Triton kernel for smooth LDDT loss does not support {pred_coords_local.dtype} " + "due to precision issues. Please use float32." + ) + + if coords_mask_t_local.ndim == 3: + coords_mask_t_local = ( + coords_mask_t_local.squeeze(-1) if coords_mask_t_local.shape[-1] == 1 else coords_mask_t_local.squeeze(1) + ) + + assert pred_coords_local.shape == pred_coords_t_local.shape + assert pred_coords_local.shape == true_coords_local.shape + assert pred_coords_t_local.shape == true_coords_t_local.shape + + assert pred_coords_local.shape[0] == grad_num_reduced.shape[0] + assert pred_coords_local.shape[0] == grad_den_reduced.shape[0] + + # different multiplicity input can share the same mask and is_nucleotide + # input so we only require B * M % B == 0 here, with pred_coords_local.shape[0] = B * M + # and is_nucleotide_local.shape[0] = B + assert pred_coords_local.shape[0] % is_nucleotide_local.shape[0] == 0 + assert pred_coords_local.shape[0] % coords_mask_local.shape[0] == 0 + assert pred_coords_local.shape[0] % coords_mask_t_local.shape[0] == 0 + + assert is_nucleotide_local.shape[1] == pred_coords_local.shape[1] + assert coords_mask_local.shape[1] == pred_coords_local.shape[1] + assert coords_mask_t_local.shape[1] == pred_coords_local.shape[1] + + # Output gradients + grad_pred_local = torch.zeros_like(pred_coords_local) + grad_pred_t_local = torch.zeros_like(pred_coords_t_local) + + is_nucleotide_local = is_nucleotide_local.to(dtype=torch.int8) + + smooth_lddt_loss_bwd_kernel[grid_launch_config]( + grad_num_reduced, + grad_den_reduced, + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + grad_pred_local, + grad_pred_t_local, + pred_coords_local.stride(0), + pred_coords_local.stride(1), + pred_coords_local.stride(2), + true_coords_local.stride(0), + true_coords_local.stride(1), + true_coords_local.stride(2), + pred_coords_t_local.stride(0), + pred_coords_t_local.stride(1), + pred_coords_t_local.stride(2), + true_coords_t_local.stride(0), + true_coords_t_local.stride(1), + true_coords_t_local.stride(2), + is_nucleotide_local.stride(0), + is_nucleotide_local.stride(1), + coords_mask_local.stride(0), + coords_mask_local.stride(1), + coords_mask_t_local.stride(0), + coords_mask_t_local.stride(1), + grad_pred_local.stride(0), + grad_pred_local.stride(1), + grad_pred_local.stride(2), + grad_pred_t_local.stride(0), + grad_pred_t_local.stride(1), + grad_pred_t_local.stride(2), + nucleic_acid_cutoff, + other_cutoff, + is_self_comm, + pred_coords_local.shape[0], + pred_coords_local.shape[1], + coords_mask_local.shape[0], + ) + + return grad_pred_local, grad_pred_t_local + + +class SmoothLDDTLossTritonFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + pred_coords: DTensor, + true_coords: DTensor, + is_nucleotide: DTensor, + coords_mask: DTensor, + comm: TransposeComm, + nucleic_acid_cutoff: float = 30.0, + other_cutoff: float = 15.0, + multiplicity: int = 1, + v2: bool = True, + ) -> DTensor: + # 1. Comm Block: Transpose inputs + # Output placement (Shard(0), Replicate(), Shard(1)) puts the atom dim on Mesh 2 (Cols) + # This matches the requirement for pairwise operations where one operand is on Mesh 1 (Rows) + # and the other on Mesh 2 (Cols). + target_placements = (Shard(0), Replicate(), Shard(1)) + + true_coords_t = redistribute_transpose(true_coords, comm, target_placements, dim0=None, dim1=None) + pred_coords_t = redistribute_transpose(pred_coords, comm, target_placements, dim0=None, dim1=None) + coords_mask_t = redistribute_transpose(coords_mask, comm, target_placements, dim0=None, dim1=None) + + # 2. Comp Block: Local computation with autograd + # We use torch.enable_grad() to allow gradients to flow through the local computation + # which will be used in backward pass. + pred_coords_local = pred_coords.to_local() + true_coords_local = true_coords.to_local() + is_nucleotide_local = is_nucleotide.to_local() + coords_mask_local = coords_mask.to_local() + + pred_coords_t_local = pred_coords_t.to_local() + true_coords_t_local = true_coords_t.to_local() + coords_mask_t_local = coords_mask_t.to_local() + + num_local, den_local = _smooth_lddt_loss_local_triton_forward( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + bool(comm.is_self_comm), + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + # 3. Comm Block: Reduction + # Reduce over atom groups (Mesh 1 and Mesh 2) + device_mesh = pred_coords.device_mesh + group_row = device_mesh.get_group(1) + group_col = device_mesh.get_group(2) + + # All reduce num and den + # Note: We can pack them into one tensor for fewer comms + metrics_local = torch.stack([num_local, den_local]) + + # Reduce over row group + torch.distributed.all_reduce(metrics_local, op=torch.distributed.ReduceOp.SUM, group=group_row) + # Reduce over col group + torch.distributed.all_reduce(metrics_local, op=torch.distributed.ReduceOp.SUM, group=group_col) + + num_reduced, den_reduced = metrics_local[0], metrics_local[1] + if v2: + den_reduced = den_reduced + 1e-5 + else: + den_reduced = torch.clamp(den_reduced, min=1.0, max=None) + + # Compute LDDT per sample + lddt_per_sample = num_reduced / den_reduced + lddt_per_sample = 1.0 - lddt_per_sample + + # Final average over batch (Mesh 0) if needed + # The original code does: lddt = scalar_tensor_op(1/N, sharded_sum(lddt, dim=0), ...) + # Here lddt_per_sample is (B*mult,). + # We need to sum over batch dimension locally and reduce over Mesh 0. + + # Note: We save tensors for backward + ctx.comm = comm + ctx.nucleic_acid_cutoff = nucleic_acid_cutoff + ctx.other_cutoff = other_cutoff + ctx.multiplicity = multiplicity + ctx.is_self_comm = bool(comm.is_self_comm) + ctx.device_mesh = device_mesh + ctx.global_batch_size = pred_coords.shape[0] + ctx.pred_coords_shape = pred_coords.shape + ctx.pred_coords_stride = pred_coords.stride() + + ctx.save_for_backward( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + num_reduced, + den_reduced, + ) + + # Final aggregation + lddt_sum = lddt_per_sample.sum() + # Reduce over batch group (Mesh 0) + group_batch = device_mesh.get_group(0) + torch.distributed.all_reduce(lddt_sum, op=torch.distributed.ReduceOp.SUM, group=group_batch) + + lddt_final = lddt_sum / pred_coords.shape[0] + + # Return as DTensor (Replicate, Replicate, Replicate) + # It's a scalar wrapped in DTensor + return DTensor.from_local( + lddt_final, + device_mesh, + (Replicate(), Replicate(), Replicate()), + shape=torch.Size(()), + stride=(), + ) + + @staticmethod + def backward(ctx, grad_output): + # grad_output is w.r.t lddt_final (scalar) + ( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + num_reduced, + den_reduced, + ) = ctx.saved_tensors + + # Backprop logic + # lddt = 1 - num / den + # Loss L = lddt_final. + # dL/d(num_reduced) = dL/d(lddt_final) * d(lddt_final)/d(lddt_per_sample) * d(lddt_per_sample)/d(num_reduced) + # d(lddt_final)/d(lddt_per_sample) = 1 / GlobalBatchSize + # d(lddt_per_sample)/d(num_reduced) = -1 / den_reduced + # d(lddt_per_sample)/d(den_reduced) = num_reduced / den_reduced^2 + + grad_output_local = grad_output.to_local().item() + scale = grad_output_local / ctx.global_batch_size + + inv_den = ( + 1.0 / den_reduced + ) # torch.clamp(den_reduced, min=1.0, max=None) or (den_reduced + 1e-5) is done in forward pass + grad_num_reduced = scale * (-inv_den) + grad_den_reduced = scale * (num_reduced * inv_den**2) + + grad_pred_local, grad_pred_t_local = _smooth_lddt_loss_local_triton_backward( + grad_num_reduced, + grad_den_reduced, + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + ctx.is_self_comm, + ctx.nucleic_acid_cutoff, + ctx.other_cutoff, + ctx.multiplicity, + ) + + # Accumulate partial gradients (reduce over missing dimensions) + # grad_pred_local is partial sum over Cols (Dim 2) + # grad_pred_t_local is partial sum over Rows (Dim 1) + group_row = ctx.device_mesh.get_group(1) + group_col = ctx.device_mesh.get_group(2) + torch.distributed.all_reduce(grad_pred_local, op=torch.distributed.ReduceOp.SUM, group=group_col) + torch.distributed.all_reduce(grad_pred_t_local, op=torch.distributed.ReduceOp.SUM, group=group_row) + + # Comm Block: Reverse transpose + # We need to move grad_pred_t_local back to pred_coords layout + # pred_coords_t was (S0, R, S1). grad has same layout. + # We want (S0, S1, R). + + grad_pred_t_local_transposed = ctx.comm.enqueue_to_dispatch(grad_pred_t_local.contiguous()) + ctx.comm.wait_until_finished() + + # Sum gradients + total_grad_pred_local = grad_pred_local + grad_pred_t_local_transposed + + # Return None for true_coords (no grad), etc. + return ( + DTensor.from_local( + total_grad_pred_local, + ctx.device_mesh, + (Shard(0), Shard(1), Replicate()), + shape=ctx.pred_coords_shape, + stride=ctx.pred_coords_stride, + ), # pred_coords + None, # true_coords + None, # is_nucleotide + None, # coords_mask + None, # comm + None, # nucleic_acid_cutoff + None, # other_cutoff + None, # multiplicity + None, # v2 + ) + + +def smooth_lddt_loss_triton( + pred_coords: DTensor, + true_coords: DTensor, + is_nucleotide: DTensor, + coords_mask: DTensor, + comm: TransposeComm, + nucleic_acid_cutoff: float = 30.0, + other_cutoff: float = 15.0, + multiplicity: int = 1, + v2: bool = True, +) -> DTensor: + """Compute the smooth LDDT loss using DTensor and Triton kernel.""" + if not has_smooth_lddt_loss_triton_kernels: + raise ImportError("Smooth LDDT loss Triton kernels are not available.") + + return SmoothLDDTLossTritonFunction.apply( + pred_coords, + true_coords, + is_nucleotide, + coords_mask, + comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + v2, + ) diff --git a/src/boltz/distributed/model/loss/distogram.py b/src/boltz/distributed/model/loss/distogram.py new file mode 100644 index 000000000..c269d40bd --- /dev/null +++ b/src/boltz/distributed/model/loss/distogram.py @@ -0,0 +1,457 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-based context-parallel distogram loss for Boltz-2. + +Implements the distogram loss as a single torch.autograd.Function, extracting +local tensors once and performing all computation locally with explicit +all_reduce calls at the required communication points. + +Design rationale for a single autograd.Function: +- Single to_local()/from_local() pair instead of one per DTensor op +- One autograd graph node instead of one per op +- All local math is plain PyTorch and could be torch.compiled +- Minimal dependencies (only TransposeComm for the mask outer product) + +Communication budget: + Forward (4 collective calls): + 1. transpose_then_redistribute for mask outer product (1 call) + 2. all_reduce over CP group for denom (async, overlaps compute) (1 call) + 3. all_reduce over CP group for total (1 call) + 4. all_reduce over DP group for batch mean (1 call) + Backward (0 collective calls): + The backward of all_reduce(SUM) is identity — each rank computes + gradients for its own local spatial chunk with no communication. + +Equivalence to serial code (src/boltz/model/loss/distogramv2.py): +- Aggregate mode: sum+normalize target over K → K_eff=1, compute loss against + each of D predictions, take min over D. For D=1 this is identical to the + serial aggregate path (min/mean over size-1 dims are identity ops). +- Non-aggregate mode: keep K conformers (K_eff=K), compute all K_eff×D + cross-entropies vectorized, min over D, mean over K. + +Named distogram.py (not distogramv2) because v2 only differs from v1 by the +extra num_distogram axis, which falls back to v1 via unsqueeze(3). +""" + +import torch +import torch.distributed as dist +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.layers.redistribute_transpose_without_dtensor import transpose_then_redistribute + + +def _build_pairwise_mask_local( + mask_local: torch.Tensor, + comm: TransposeComm, +) -> torch.Tensor: + """Build pairwise [B, N_row, N_col] boolean mask from token mask [B, N_local]. + + This performs the R2S outer-BITAND: mask_i & mask_j where mask_j comes from + the transposed rank. Also zeros the diagonal on self-comm ranks. + + No gradients needed — the mask is treated as a constant. + """ + # mask_local: [B_local, N_local] + # Expand for outer product: [B_local, N_local, 1] + mask_expanded = mask_local.unsqueeze(2) + # Get the transposed chunk: [B_local, 1, N_local_col] + mask_transposed = transpose_then_redistribute(mask_expanded, 1, 2, comm) + # Outer BITAND: [B_local, N_local_row, N_local_col] + mask_2d = mask_expanded & mask_transposed + + # Zero diagonal on self-comm ranks (where row and column ranges overlap) + if comm.is_self_comm: + local_n = mask_local.shape[1] + diag_mask = ~torch.eye(local_n, dtype=torch.bool, device=mask_local.device) + mask_2d = mask_2d & diag_mask.unsqueeze(0) + + return mask_2d + + +class _DistogramLossCP(torch.autograd.Function): + """Single autograd.Function for the full distogram loss. + + Forward: to_local() → local math with explicit all_reduces → from_local() + Backward: local math only (no communication) + """ + + @staticmethod + def forward( + ctx: FunctionCtx, + pred: DTensor, + target: DTensor, + mask_token: DTensor, + comm: TransposeComm, + aggregate_distogram: bool, + ) -> tuple[DTensor, DTensor]: + """Forward pass. + + Parameters + ---------- + pred : DTensor + Prediction logits [B, N, N, D, bins], placements (Shard(0), Shard(1), Shard(2)). + target : DTensor + Target distributions [B, N, N, K, bins], placements (Shard(0), Shard(1), Shard(2)). + mask_token : DTensor + Token validity mask [B, N], placements (Shard(0), Shard(1), Replicate()). + comm : TransposeComm + Communication handle for CP transpose operations. + aggregate_distogram : bool + Whether to aggregate target over K conformers. + + Returns + ------- + global_loss : DTensor + Scalar loss, placements (Replicate(), Replicate(), Replicate()). + batch_loss : DTensor + Per-example loss [B], placements (Shard(0), Replicate(), Replicate()). + """ + # --- Validate inputs --- + if not isinstance(pred, DTensor): + raise TypeError(f"pred must be DTensor, got {type(pred)}") + if not isinstance(target, DTensor): + raise TypeError(f"target must be DTensor, got {type(target)}") + if not isinstance(mask_token, DTensor): + raise TypeError(f"mask_token must be DTensor, got {type(mask_token)}") + + device_mesh = pred.device_mesh + + for name, dtensor in [("target", target), ("mask_token", mask_token)]: + if dtensor.device_mesh != device_mesh: + raise ValueError(f"{name} has different device_mesh than pred") + + expected_pairlike = (Shard(0), Shard(1), Shard(2)) + if pred.placements != expected_pairlike: + raise ValueError(f"pred placements {pred.placements} must be {expected_pairlike}") + if target.placements != expected_pairlike: + raise ValueError(f"target placements {target.placements} must be {expected_pairlike}") + expected_mask = (Shard(0), Shard(1), Replicate()) + if mask_token.placements != expected_mask: + raise ValueError(f"mask_token placements {mask_token.placements} must be {expected_mask}") + for name, dtensor in [("pred", pred), ("target", target), ("mask_token", mask_token)]: + # Check placements and handle sharded dimensions + for i_dim, placement in enumerate(dtensor.placements): + if isinstance(placement, Partial): + raise ValueError(f"Partial placement on {name} mesh dim {i_dim} is not supported") + elif isinstance(placement, Shard): + # Check that sharded dimensions are evenly divided + if dtensor.shape[placement.dim] % device_mesh.shape[i_dim] != 0: + raise ValueError( + f"Uneven sharding {name} tensor dimension {placement.dim} of size {dtensor.shape[placement.dim]} " + f"along device mesh dimension {i_dim} of size {device_mesh.shape[i_dim]} is not supported" + ) + if pred.ndim != 5: # noqa: PLR2004 + raise ValueError(f"pred must be 5D [B, N, N, D, bins], got {pred.ndim}D") + if target.ndim != 5: # noqa: PLR2004 + raise ValueError(f"target must be 5D [B, N, N, K, bins], got {target.ndim}D") + if mask_token.ndim != 2: # noqa: PLR2004 + raise ValueError(f"mask_token must be 2D [B, N], got {mask_token.ndim}D") + if pred.shape[0] != target.shape[0] or pred.shape[1] != target.shape[1] or pred.shape[2] != target.shape[2]: + raise ValueError(f"pred shape {pred.shape} and target shape {target.shape} must match on dims 0,1,2") + if pred.shape[4] != target.shape[4]: + raise ValueError(f"pred bins {pred.shape[4]} != target bins {target.shape[4]}") + if mask_token.shape[0] != pred.shape[0] or mask_token.shape[1] != pred.shape[1]: + raise ValueError(f"mask_token shape {mask_token.shape} inconsistent with pred shape {pred.shape}") + + # --- Extract local tensors (single to_local() per input) --- + # Compute in at least float32 for numerical stability, while respecting + # float64 if either input is float64. + compute_dtype = torch.promote_types(torch.promote_types(pred.dtype, target.dtype), torch.float32) + pred_local = pred.to_local().to(compute_dtype) # [B_local, N_row, N_col, D, bins] + target_local = target.to_local().to(compute_dtype) # [B_local, N_row, N_col, K, bins] + mask_token_local = mask_token.to_local().to(torch.bool) # [B_local, N_local] + + D = pred_local.shape[3] # noqa: N806 + K = target_local.shape[3] # noqa: N806 + + # --- Build pairwise mask (involves R2S transpose, no grad needed) --- + mask_local = _build_pairwise_mask_local(mask_token_local, comm).to(compute_dtype) + # mask_local: [B_local, N_row, N_col] + + # --- Denom: launch async all_reduce so latency overlaps with compute --- + cp_group = comm.group + denom_local = mask_local.sum(dim=(-1, -2)) # [B_local] + denom_work = dist.all_reduce(denom_local, op=dist.ReduceOp.SUM, group=cp_group, async_op=True) + + # --- Target preparation --- + if aggregate_distogram: + # Sum over K conformers, normalize → single effective conformer + P_local = target_local.sum(dim=3) # [B,N_r,N_c,bins] + P_denom = P_local.sum(dim=-1, keepdim=True).clamp(min=1) # [B,N_r,N_c,1] + P_local = P_local / P_denom # [B,N_r,N_c,bins] + P_local = P_local.unsqueeze(3) # [B,N_r,N_c,1,bins] + K_eff = 1 # noqa: N806 + else: + P_local = target_local # [B,N_r,N_c,K,bins] + K_eff = K # noqa: N806 + + # --- Vectorized cross-entropy for all (k, d) pairs --- + log_Q_local = torch.nn.functional.log_softmax(pred_local, dim=-1) # [B,N_r,N_c,D,bins] + softmax_local = log_Q_local.exp() # save for backward + + # Expand P: [B,N_r,N_c,K_eff,bins] → [B,N_r,N_c,K_eff,D,bins] + P_expanded = P_local.unsqueeze(4).expand(-1, -1, -1, -1, D, -1) + # Expand log_Q: [B,N_r,N_c,D,bins] → [B,N_r,N_c,K_eff,D,bins] + log_Q_expanded = log_Q_local.unsqueeze(3).expand(-1, -1, -1, K_eff, -1, -1) + + # Cross-entropy: -sum(P * log_Q, dim=bins) → [B,N_r,N_c,K_eff,D] + errors_local = -(P_expanded * log_Q_expanded).sum(dim=-1) + + # --- Flatten K_eff*D, apply mask, spatial reduction --- + errors_flat = errors_local.reshape( + errors_local.shape[0], errors_local.shape[1], errors_local.shape[2], K_eff * D + ) # [B,N_r,N_c,K_eff*D] + mask_exp = mask_local.unsqueeze(-1).expand(-1, -1, -1, K_eff * D) # [B,N_r,N_c,K_eff*D] + masked = errors_flat * mask_exp # [B,N_r,N_c,K_eff*D] + + # --- Separate CP reductions for total and denom --- + total_local = masked.sum(dim=(1, 2)) # [B_local, K_eff*D] + dist.all_reduce(total_local, op=dist.ReduceOp.SUM, group=cp_group) + denom_work.wait() + denom_local = denom_local + 1e-5 # [B_local] + + # --- Divide by denom --- + total_local = total_local / denom_local.unsqueeze(-1) # [B_local, K_eff*D] + + # --- Min over D, mean over K_eff --- + batch_loss_kd = total_local.reshape(total_local.shape[0], K_eff, D) # [B_local, K_eff, D] + min_result = torch.min(batch_loss_kd, dim=-1) # values: [B_local, K_eff], indices saved + batch_loss_k = min_result.values + min_indices = min_result.indices # [B_local, K_eff] — needed for backward + + batch_loss_local = batch_loss_k.sum(dim=-1) / K_eff # [B_local] + + # --- Global loss: mean over batch (all_reduce over DP dim) --- + B_global = pred.shape[0] # noqa: N806 + global_loss_local = batch_loss_local.sum(dim=0, keepdim=False) # scalar + dp_group = device_mesh.get_group(0) + dist.all_reduce(global_loss_local, op=dist.ReduceOp.SUM, group=dp_group) + global_loss_local = global_loss_local / B_global + + # --- Save for backward --- + if pred.requires_grad: + ctx.save_for_backward( + softmax_local, # [B,N_r,N_c,D,bins] + P_local, # [B,N_r,N_c,K_eff,bins] + mask_local, # [B,N_r,N_c] + denom_local, # [B_local] + min_indices, # [B_local, K_eff] + ) + ctx.D = D + ctx.K_eff = K_eff + ctx.B_global = B_global + ctx.device_mesh = device_mesh + ctx.pred_placements = pred.placements + ctx.pred_shape = pred.shape + ctx.pred_stride = pred.stride() + + # --- Wrap results as DTensors --- + # batch_loss: [B] sharded on DP dim + batch_loss_placements = (Shard(0), Replicate(), Replicate()) + bl_shape = (B_global,) + bl_stride = (1,) + batch_loss_dt = DTensor.from_local( + batch_loss_local, device_mesh, batch_loss_placements, shape=bl_shape, stride=bl_stride + ) + + # global_loss: scalar, replicated + global_loss_placements = (Replicate(), Replicate(), Replicate()) + gl_shape = () + gl_stride = () + global_loss_dt = DTensor.from_local( + global_loss_local, device_mesh, global_loss_placements, shape=gl_shape, stride=gl_stride + ) + + return global_loss_dt, batch_loss_dt + + @staticmethod + def backward( + ctx: FunctionCtx, d_global_loss: DTensor, d_batch_loss: DTensor + ) -> tuple[DTensor | None, None, None, None, None]: + """Backward pass — entirely local, no collective communication. + + The forward's all_reduces produce replicated results. Their backward is + broadcast (expand), which is local since each rank only fills its own chunk. + """ + if not ctx.needs_input_grad[0]: + return None, None, None, None, None + + softmax_local, P_local, mask_local, denom_local, min_indices = ctx.saved_tensors + D = ctx.D # noqa: N806 + K_eff = ctx.K_eff # noqa: N806 + B_global = ctx.B_global # noqa: N806 + device_mesh = ctx.device_mesh + + B_local = softmax_local.shape[0] # noqa: N806 + N_row = softmax_local.shape[1] # noqa: N806 + N_col = softmax_local.shape[2] # noqa: N806 + + # d_global_loss is Replicate scalar (DTensor), d_batch_loss may be + # DTensor, plain Tensor, or None depending on what the caller backward'd through. + compute_dtype = softmax_local.dtype + d_gl = (d_global_loss.to_local() if isinstance(d_global_loss, DTensor) else d_global_loss).to(compute_dtype) + if d_batch_loss is None: + d_bl = torch.zeros(B_local, device=softmax_local.device, dtype=compute_dtype) + elif isinstance(d_batch_loss, DTensor): + d_bl = d_batch_loss.to_local().to(compute_dtype) + else: + d_bl = d_batch_loss.to(compute_dtype) + + # ----------------------------------------------------------------- + # Chain rule from global_loss = sum(batch_loss) / B_global + # d_batch_loss_from_global = d_gl / B_global (broadcast to [B_local]) + # total d_batch_loss_local = d_bl + d_gl / B_global + # ----------------------------------------------------------------- + d_batch_local = d_bl + d_gl / B_global # [B_local] + + # ----------------------------------------------------------------- + # Backward through mean over K_eff: batch_loss = sum(batch_loss_k) / K_eff + # d_batch_loss_k = d_batch_local / K_eff → [B_local, K_eff] + # ----------------------------------------------------------------- + d_batch_loss_k = (d_batch_local / K_eff).unsqueeze(-1).expand(-1, K_eff) # [B_local, K_eff] + + # ----------------------------------------------------------------- + # Backward through min over D: scatter gradient to argmin index + # d_batch_loss_kd[b, k, d] = d_batch_loss_k[b, k] if d == min_indices[b, k] else 0 + # → [B_local, K_eff, D] + # ----------------------------------------------------------------- + d_batch_loss_kd = torch.zeros(B_local, K_eff, D, device=d_batch_local.device, dtype=d_batch_local.dtype) + d_batch_loss_kd.scatter_(-1, min_indices.unsqueeze(-1), d_batch_loss_k.unsqueeze(-1)) + + # ----------------------------------------------------------------- + # Backward through reshape [B,K_eff,D] → [B,K_eff*D] + # and division by denom: total = total_pre_div / denom + # d_total_pre_div = d_total / denom + # ----------------------------------------------------------------- + d_total = d_batch_loss_kd.reshape(B_local, K_eff * D) # [B_local, K_eff*D] + d_total = d_total / denom_local.unsqueeze(-1) # [B_local, K_eff*D] + + # ----------------------------------------------------------------- + # Backward through spatial sum + all_reduce: + # forward: total = all_reduce(sum(masked, dim=(1,2))) + # backward of sum is broadcast: d_masked = d_total expanded to [B,N_r,N_c,K_eff*D] + # backward of all_reduce(SUM) is identity (each rank gets the full gradient + # and needs to compute its local contribution — no extra comm needed) + # ----------------------------------------------------------------- + d_masked = d_total.unsqueeze(1).unsqueeze(2).expand(-1, N_row, N_col, -1) # [B,N_r,N_c,K_eff*D] + + # Backward through mask multiply: d_errors_flat = d_masked * mask_exp + mask_exp = mask_local.unsqueeze(-1).expand(-1, -1, -1, K_eff * D) + d_errors_flat = d_masked * mask_exp # [B,N_r,N_c,K_eff*D] + + # Reshape to [B,N_r,N_c,K_eff,D] + d_errors = d_errors_flat.reshape(B_local, N_row, N_col, K_eff, D) + + # ----------------------------------------------------------------- + # Backward through cross-entropy: errors = -sum(P * log_Q, dim=-1) + # d_log_Q_expanded = -P_expanded * d_errors.unsqueeze(-1) + # Sum over K_eff to get d_log_Q + # ----------------------------------------------------------------- + P_expanded = P_local.unsqueeze(4).expand(-1, -1, -1, -1, D, -1) # [B,N_r,N_c,K_eff,D,bins] + d_log_Q_expanded = -P_expanded * d_errors.unsqueeze(-1) # [B,N_r,N_c,K_eff,D,bins] + d_log_Q = d_log_Q_expanded.sum(dim=3) # [B,N_r,N_c,D,bins] + + # ----------------------------------------------------------------- + # Backward through log_softmax: + # If y = log_softmax(x), then dy/dx = diag(1) - softmax(x) + # So d_pred = d_log_Q - softmax * sum(d_log_Q, dim=-1, keepdim=True) + # ----------------------------------------------------------------- + d_pred_local = d_log_Q - softmax_local * d_log_Q.sum(dim=-1, keepdim=True) + + # --- Wrap gradient as DTensor --- + d_pred = DTensor.from_local( + d_pred_local, + device_mesh=device_mesh, + placements=ctx.pred_placements, + shape=ctx.pred_shape, + stride=ctx.pred_stride, + ) + + return d_pred, None, None, None, None + + +def distogram_loss( + output: dict[str, DTensor], + feats: dict[str, DTensor], + comm: TransposeComm, + aggregate_distogram: bool = True, +) -> tuple[DTensor, DTensor]: + """Compute the distogram loss using a single fused autograd.Function. + + Both aggregate and non-aggregate modes share a unified pipeline: + 1. Prepare target P and set K_eff (differs by mode) + 2. Expand P and log_Q to [B, N, N, K_eff, D, bins] + 3. Compute cross-entropy for all (k, d) pairs + 4. Mask + spatial reduce with async denom + total all_reduce → [B, K_eff*D] + 5. Divide by denom → [B, K_eff*D] + 6. Unflatten → [B, K_eff, D], min over D, mean over K_eff → [B] + 7. Global loss (mean over DP-sharded batch dim) + + All local math is done in a single autograd.Function with explicit + all_reduce calls, avoiding the overhead of ~15 separate DTensor + autograd.Function round-trips. + + Boltz-1 (v1) compatibility: + The v1 serial loss (src/boltz/model/loss/distogram.py) uses 4D tensors + [B, N, N, bins] with no D or K axes. To use this function as a v1 loss, + unsqueeze dim 3 of both pred and target before passing them in:: + + output["pdistogram"] = pred_4d.unsqueeze(3) # [B,N,N,bins] → [B,N,N,1,bins] + feats["disto_target"] = target_4d.unsqueeze(3) # [B,N,N,bins] → [B,N,N,1,bins] + distogram_loss(output, feats, comm, aggregate_distogram=True) + + With D=1 and K=1, the min-over-D and mean-over-K steps are identity ops, + producing results identical to v1. + + Parameters + ---------- + output : dict[str, DTensor] + Output of the model containing: + - "pdistogram": [B, N, N, D, bins] prediction logits (DTensor) + feats : dict[str, DTensor] + Input features containing: + - "disto_target": [B, N, N, K, bins] target distributions (DTensor) + - "token_disto_mask": [B, N] token validity mask (DTensor) + comm : TransposeComm + Communication object for CP transpose and reduction. + aggregate_distogram : bool + If True, aggregates target over K conformers into a single normalized + distribution (K_eff=1). Works with any D (generalizes serial D=1 constraint). + If False, computes per-conformer loss (K_eff=K) and takes min over D. + + Returns + ------- + DTensor + The globally averaged loss (scalar DTensor). + DTensor + Per-example loss [B] (DTensor). + """ + with torch.autocast("cuda", enabled=False): + return _DistogramLossCP.apply( + output["pdistogram"], + feats["disto_target"], + feats["token_disto_mask"], + comm, + aggregate_distogram, + ) diff --git a/src/boltz/distributed/model/loss/triton/__init__.py b/src/boltz/distributed/model/loss/triton/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/model/loss/triton/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/model/loss/triton/cdist_lddt.py b/src/boltz/distributed/model/loss/triton/cdist_lddt.py new file mode 100644 index 000000000..d45c5d66c --- /dev/null +++ b/src/boltz/distributed/model/loss/triton/cdist_lddt.py @@ -0,0 +1,784 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _cdist_lddt_kernel( + # Pointers to inputs + pred_coords_row, + pred_coords_col, + true_coords_row, + true_coords_col, + mask_row, + mask_col, + atom_indices_row, # Can be None (use implicit arange) + atom_indices_col, # Can be None (use implicit arange) + cutoff_col_ptr, # Can be None (use scalar cutoff) + out_num, + out_denom, + # Shape and strides for pred_coords_row [B_mul, N_row, 3] + B_mul_pred_coords_row, + N_pred_coords_row, + stride_pred_coords_row_b, + stride_pred_coords_row_n, + stride_pred_coords_row_d, + # Shape and strides for pred_coords_col [B_mul, N_col, 3] + B_mul_pred_coords_col, + N_pred_coords_col, + stride_pred_coords_col_b, + stride_pred_coords_col_n, + stride_pred_coords_col_d, + # Shape and strides for true_coords_row [B_mul, N_row, 3] + B_mul_true_coords_row, + N_true_coords_row, + stride_true_coords_row_b, + stride_true_coords_row_n, + stride_true_coords_row_d, + # Shape and strides for true_coords_col [B_mul, N_col, 3] + B_mul_true_coords_col, + N_true_coords_col, + stride_true_coords_col_b, + stride_true_coords_col_n, + stride_true_coords_col_d, + # Shape and strides for mask_row [B, N_row] + B_mask_row, + N_mask_row, + stride_mask_row_b, + stride_mask_row_n, + # Shape and strides for mask_col [B, N_col] + B_mask_col, + N_mask_col, + stride_mask_col_b, + stride_mask_col_n, + # Shape and strides for atom_indices_row [B, N_row] (only used if USE_EXPLICIT_INDICES_ROW) + B_atom_indices_row, + N_atom_indices_row, + stride_atom_indices_row_b, + stride_atom_indices_row_n, + # Shape and strides for atom_indices_col [B, N_col] (only used if USE_EXPLICIT_INDICES_COL) + B_atom_indices_col, + N_atom_indices_col, + stride_atom_indices_col_b, + stride_atom_indices_col_n, + # Shape and strides for cutoff_col [B, N_col] (only used if USE_CUTOFF_COL) + B_cutoff_col, + N_cutoff_col, + stride_cutoff_col_b, + stride_cutoff_col_n, + # Shape and strides for out_num [B_mul] or [B_mul, N_row] if PER_ATOM + B_mul_out_num, + stride_out_num_b, + stride_out_num_n, # Only used if PER_ATOM + # Shape and strides for out_denom [B_mul] or [B_mul, N_row] if PER_ATOM + B_mul_out_denom, + stride_out_denom_b, + stride_out_denom_n, # Only used if PER_ATOM + # Constants + cutoff, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + # Coordinate dimension size (e.g., 3 for 3D) + SIZE_DIM_D: tl.constexpr, + # Flags for implicit arange (avoids creating torch.arange tensors) + USE_EXPLICIT_INDICES_ROW: tl.constexpr, # If False, use offs_m as indices + USE_EXPLICIT_INDICES_COL: tl.constexpr, # If False, use offs_n as indices + # Flag for per-column cutoff + USE_CUTOFF_COL: tl.constexpr, # If True, use per-column cutoff values instead of scalar + # Flag for diagonal masking + DO_MASK_DIAGONAL: tl.constexpr, # If True, exclude self-pairs where indices match + # Flag for per-atom output + PER_ATOM: tl.constexpr, # If True, output per-row scores; otherwise total score + # Memory layout orders for make_block_ptr (computed from argsort of strides) + # For 3D coord blocks [1, N, D], order is a tuple of 3 indices + ORDER_PRED_COORDS_ROW_0: tl.constexpr, + ORDER_PRED_COORDS_ROW_1: tl.constexpr, + ORDER_PRED_COORDS_ROW_2: tl.constexpr, + ORDER_PRED_COORDS_COL_0: tl.constexpr, + ORDER_PRED_COORDS_COL_1: tl.constexpr, + ORDER_PRED_COORDS_COL_2: tl.constexpr, + ORDER_TRUE_COORDS_ROW_0: tl.constexpr, + ORDER_TRUE_COORDS_ROW_1: tl.constexpr, + ORDER_TRUE_COORDS_ROW_2: tl.constexpr, + ORDER_TRUE_COORDS_COL_0: tl.constexpr, + ORDER_TRUE_COORDS_COL_1: tl.constexpr, + ORDER_TRUE_COORDS_COL_2: tl.constexpr, + # For 2D mask blocks [1, N], order is a tuple of 2 indices + ORDER_MASK_ROW_0: tl.constexpr, + ORDER_MASK_ROW_1: tl.constexpr, + ORDER_MASK_COL_0: tl.constexpr, + ORDER_MASK_COL_1: tl.constexpr, + # For 2D atom_indices blocks [1, N], order is a tuple of 2 indices (only used if USE_EXPLICIT_INDICES_*) + ORDER_ATOM_INDICES_ROW_0: tl.constexpr, + ORDER_ATOM_INDICES_ROW_1: tl.constexpr, + ORDER_ATOM_INDICES_COL_0: tl.constexpr, + ORDER_ATOM_INDICES_COL_1: tl.constexpr, + # For 2D cutoff_col blocks [1, N], order is a tuple of 2 indices (only used if USE_CUTOFF_COL) + ORDER_CUTOFF_COL_0: tl.constexpr, + ORDER_CUTOFF_COL_1: tl.constexpr, + # Flags for mask multiplicity (whether masks have B*mul or B batch dimension) + MASK_ROW_HAS_MUL: tl.constexpr, # If True, mask_row has [B*mul, N] shape + MASK_COL_HAS_MUL: tl.constexpr, # If True, mask_col has [B*mul, N] shape + # Explicit multiplicity factor + MULTIPLICITY: tl.constexpr, +): + # Validate coordinate dimension at compile time + tl.static_assert(SIZE_DIM_D == 3, "SIZE_DIM_D must be 3 (3D coordinates)") + + # Next power of 2 for coordinate dimension (required by tl.arange) + # Since SIZE_DIM_D is always 3, BLOCK_D is 4 (next power of 2) + BLOCK_D: tl.constexpr = 4 + + # 1. Grid Identification + pid_batch = tl.program_id(0) # Handles flattened batch * multiplicity + pid_m = tl.program_id(1) # Row block index + pid_n = tl.program_id(2) # Col block index + + # 2. Multiplicity & Broadcasting + # Use explicit multiplicity parameter + # Determine which sample in the original [B] batch this corresponds to + batch_idx = pid_batch // MULTIPLICITY + # Batch indices for masks (depends on whether masks have multiplicity) + batch_idx_mask_row = pid_batch if MASK_ROW_HAS_MUL else batch_idx + batch_idx_mask_col = pid_batch if MASK_COL_HAS_MUL else batch_idx + + # 3. Create block pointers using make_block_ptr (full-dimensional to capture stride-1 axis) + # For coordinate tensors: 3D block [1, BLOCK_M, BLOCK_D] from [B_mul, N_row, 3] + # Including batch dimension ensures order tuple captures all strides including if batch has stride=1 + pred_row_block_ptr = tl.make_block_ptr( + base=pred_coords_row, + shape=(B_mul_pred_coords_row, N_pred_coords_row, SIZE_DIM_D), + strides=(stride_pred_coords_row_b, stride_pred_coords_row_n, stride_pred_coords_row_d), + offsets=(pid_batch, pid_m * BLOCK_M, 0), + block_shape=(1, BLOCK_M, BLOCK_D), + order=(ORDER_PRED_COORDS_ROW_0, ORDER_PRED_COORDS_ROW_1, ORDER_PRED_COORDS_ROW_2), + ) + pred_col_block_ptr = tl.make_block_ptr( + base=pred_coords_col, + shape=(B_mul_pred_coords_col, N_pred_coords_col, SIZE_DIM_D), + strides=(stride_pred_coords_col_b, stride_pred_coords_col_n, stride_pred_coords_col_d), + offsets=(pid_batch, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_N, BLOCK_D), + order=(ORDER_PRED_COORDS_COL_0, ORDER_PRED_COORDS_COL_1, ORDER_PRED_COORDS_COL_2), + ) + true_row_block_ptr = tl.make_block_ptr( + base=true_coords_row, + shape=(B_mul_true_coords_row, N_true_coords_row, SIZE_DIM_D), + strides=(stride_true_coords_row_b, stride_true_coords_row_n, stride_true_coords_row_d), + offsets=(pid_batch, pid_m * BLOCK_M, 0), + block_shape=(1, BLOCK_M, BLOCK_D), + order=(ORDER_TRUE_COORDS_ROW_0, ORDER_TRUE_COORDS_ROW_1, ORDER_TRUE_COORDS_ROW_2), + ) + true_col_block_ptr = tl.make_block_ptr( + base=true_coords_col, + shape=(B_mul_true_coords_col, N_true_coords_col, SIZE_DIM_D), + strides=(stride_true_coords_col_b, stride_true_coords_col_n, stride_true_coords_col_d), + offsets=(pid_batch, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_N, BLOCK_D), + order=(ORDER_TRUE_COORDS_COL_0, ORDER_TRUE_COORDS_COL_1, ORDER_TRUE_COORDS_COL_2), + ) + + # For mask tensors: 2D block [1, BLOCK_M] from [B or B*mul, N_row] + # Uses batch_idx_mask_row/col which is pid_batch if mask has multiplicity, else batch_idx + mask_row_block_ptr = tl.make_block_ptr( + base=mask_row, + shape=(B_mask_row, N_mask_row), + strides=(stride_mask_row_b, stride_mask_row_n), + offsets=(batch_idx_mask_row, pid_m * BLOCK_M), + block_shape=(1, BLOCK_M), + order=(ORDER_MASK_ROW_0, ORDER_MASK_ROW_1), + ) + mask_col_block_ptr = tl.make_block_ptr( + base=mask_col, + shape=(B_mask_col, N_mask_col), + strides=(stride_mask_col_b, stride_mask_col_n), + offsets=(batch_idx_mask_col, pid_n * BLOCK_N), + block_shape=(1, BLOCK_N), + order=(ORDER_MASK_COL_0, ORDER_MASK_COL_1), + ) + + # For atom_indices tensors: 2D block [1, BLOCK_M/N] from [B, N_row/col] + # Uses batch_idx for broadcasting (same as masks) + # Only create block pointers if explicit indices are used + if USE_EXPLICIT_INDICES_ROW: + atom_indices_row_block_ptr = tl.make_block_ptr( + base=atom_indices_row, + shape=(B_atom_indices_row, N_atom_indices_row), + strides=(stride_atom_indices_row_b, stride_atom_indices_row_n), + offsets=(batch_idx, pid_m * BLOCK_M), + block_shape=(1, BLOCK_M), + order=(ORDER_ATOM_INDICES_ROW_0, ORDER_ATOM_INDICES_ROW_1), + ) + if USE_EXPLICIT_INDICES_COL: + atom_indices_col_block_ptr = tl.make_block_ptr( + base=atom_indices_col, + shape=(B_atom_indices_col, N_atom_indices_col), + strides=(stride_atom_indices_col_b, stride_atom_indices_col_n), + offsets=(batch_idx, pid_n * BLOCK_N), + block_shape=(1, BLOCK_N), + order=(ORDER_ATOM_INDICES_COL_0, ORDER_ATOM_INDICES_COL_1), + ) + + # 5. Load Data using block pointers with boundary_check and padding_option + # boundary_check specifies which dimensions to check for out-of-bounds + # padding_option="zero" fills out-of-bounds values with 0 + # Only check dims 1 and 2 (N and D) for coords; dim 0 (batch) is always in bounds + # Reshape to squeeze the batch dimension of size 1: [1, BLOCK_M, BLOCK_D] -> [BLOCK_M, BLOCK_D] + pred_row = tl.reshape( + tl.load(pred_row_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_D), + ) + pred_col = tl.reshape( + tl.load(pred_col_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_N, BLOCK_D), + ) + true_row = tl.reshape( + tl.load(true_row_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_D), + ) + true_col = tl.reshape( + tl.load(true_col_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_N, BLOCK_D), + ) + + # Only check dim 1 (N) for masks; dim 0 (batch) is always in bounds + # Reshape to squeeze: [1, BLOCK_M] -> [BLOCK_M] + m_row = tl.reshape( + tl.load(mask_row_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_M,), + ).to(tl.int1) + m_col = tl.reshape( + tl.load(mask_col_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_N,), + ).to(tl.int1) + + # 5. Distance Computation (Pairwise) + # dist = sqrt(sum((x[:, None] - y[None, :])^2)) + # Expand dims for broadcasting: [BLOCK_M, 1, 3] - [1, BLOCK_N, 3] + delta_pred = pred_row[:, None, :] - pred_col[None, :, :] + d_pred_sq = tl.sum(delta_pred * delta_pred, axis=2) + d_pred = tl.sqrt(d_pred_sq) + + delta_true = true_row[:, None, :] - true_col[None, :, :] + d_true_sq = tl.sum(delta_true * delta_true, axis=2) + d_true = tl.sqrt(d_true_sq) + + # 6. Validity Masking + # Base validity: both atoms resolved (out-of-bounds already zeroed by masked load) + valid = m_row[:, None] & m_col[None, :] + + # Offsets for Row/Col dimensions (needed for diagonal masking and per_atom accumulation) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + boundary_mask_m = offs_m < N_pred_coords_row + boundary_mask_n = offs_n < N_pred_coords_col + + # 7. Diagonal Masking (conditional on DO_MASK_DIAGONAL) + if DO_MASK_DIAGONAL: + # Get row indices: either from explicit tensor or use implicit arange (offs_m) + # atom_indices use batch_idx (same as masks) for [B, N] broadcasting + if USE_EXPLICIT_INDICES_ROW: + # Load using block pointer with boundary_check and padding_option + # padding_option=-1 to distinguish from valid indices (which are >= 0) + idx_row = tl.reshape( + tl.load(atom_indices_row_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_M,), + ) + # Set out-of-bounds indices to -1 (won't match any valid column index) + idx_row = tl.where(boundary_mask_m, idx_row, -1) + else: + # Use offs_m directly as indices (equivalent to torch.arange(N_row)) + idx_row = offs_m + + # Get col indices: either from explicit tensor or use implicit arange (offs_n) + if USE_EXPLICIT_INDICES_COL: + # Load using block pointer with boundary_check and padding_option + # padding_option=-2 to distinguish from valid indices and row padding + idx_col = tl.reshape( + tl.load(atom_indices_col_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_N,), + ) + # Set out-of-bounds indices to -2 (won't match any valid row index) + idx_col = tl.where(boundary_mask_n, idx_col, -2) + else: + # Use offs_n directly as indices (equivalent to torch.arange(N_col)) + idx_col = offs_n + + # Apply diagonal mask (exclude self-pairs where indices match) + is_diagonal = idx_row[:, None] == idx_col[None, :] + valid = valid & (~is_diagonal) + + # 7. Scoring + # Only score pairs where true distance < cutoff + # If USE_CUTOFF_COL, use per-column cutoff values [B, N_col]; otherwise use scalar cutoff + if USE_CUTOFF_COL: + # Load per-column cutoff values for this block using block pointer + # Uses batch_idx for broadcasting (B_mul -> B), same as masks + cutoff_col_block_ptr = tl.make_block_ptr( + base=cutoff_col_ptr, + shape=(B_cutoff_col, N_cutoff_col), + strides=(stride_cutoff_col_b, stride_cutoff_col_n), + offsets=(batch_idx, pid_n * BLOCK_N), + block_shape=(1, BLOCK_N), + order=(ORDER_CUTOFF_COL_0, ORDER_CUTOFF_COL_1), + ) + cutoff_vals = tl.reshape( + tl.load(cutoff_col_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_N,), + ) + # Broadcast [BLOCK_N] over [BLOCK_M, BLOCK_N]: cutoff_vals[None, :] is [1, BLOCK_N] + within_cutoff = d_true < cutoff_vals[None, :] + else: + within_cutoff = d_true < cutoff + active = valid & within_cutoff + + diff = tl.abs(d_pred - d_true) + score_05 = (diff < 0.5).to(tl.float32) + score_10 = (diff < 1.0).to(tl.float32) + score_20 = (diff < 2.0).to(tl.float32) + score_40 = (diff < 4.0).to(tl.float32) + + score_tile = 0.25 * (score_05 + score_10 + score_20 + score_40) + + # Zero out invalid entries + score_tile = tl.where(active, score_tile, 0.0) + denom_tile = tl.where(active, 1.0, 0.0) + + # 8. Accumulation + if PER_ATOM: + # Sum over columns (BLOCK_N dimension) for per-row scores + num_per_row = tl.sum(score_tile, axis=1) # [BLOCK_M] + denom_per_row = tl.sum(denom_tile, axis=1) # [BLOCK_M] + + # Atomic Add to global buffers [B_mul, N_row] + out_ptr_num = out_num + pid_batch * stride_out_num_b + offs_m * stride_out_num_n + out_ptr_denom = out_denom + pid_batch * stride_out_denom_b + offs_m * stride_out_denom_n + + tl.atomic_add(out_ptr_num, num_per_row, mask=boundary_mask_m, sem="relaxed") + tl.atomic_add(out_ptr_denom, denom_per_row, mask=boundary_mask_m, sem="relaxed") + else: + # Sum over the entire tile + num_sum = tl.sum(score_tile) + denom_sum = tl.sum(denom_tile) + + # Atomic Add to global buffers [B_mul] + tl.atomic_add(out_num + pid_batch * stride_out_num_b, num_sum, sem="relaxed") + tl.atomic_add(out_denom + pid_batch * stride_out_denom_b, denom_sum, sem="relaxed") + + +def cdist_lddt( + pred_coords_row, # [B_mul, N_row, 3] (B_mul = B * multiplicity) + pred_coords_col, # [B_mul, N_col, 3] + true_coords_row, # [B_mul, N_row, 3] + true_coords_col, # [B_mul, N_col, 3] + mask_row, # [B, N_row] or [B_mul, N_row] + mask_col, # [B, N_col] or [B_mul, N_col] + multiplicity, # Required: explicit multiplicity factor (B_mul = B * multiplicity) + atom_indices_row=None, # [B, N_row] (optional, for diagonal masking) + atom_indices_col=None, # [B, N_col] (optional, for diagonal masking) + cutoff=15.0, + cutoff_col=None, # [B, N_col] (optional, per-column cutoff values per batch) + eps=1e-10, + do_mask_diagonal=True, # If True, exclude self-pairs where indices match + return_unnormalized_score=False, # If True, return (out_num, out_denom) before normalization + per_atom=False, # If True, return per-row lDDT scores and mask_no_match + return_denom=False, # If True, also return the denominator (pair counts) +): + """ + Computes lDDT score directly from coordinates without materializing distance matrices. + Handles rectangular inputs and broadcasting of multiplicity. + + Note: Inputs must share a dtype. Computation runs in at least float32 + (via torch.promote_types with float32), and outputs are cast back to the input dtype. + + Parameters + ---------- + pred_coords_row : torch.Tensor + Predicted coordinates for row atoms, shape [B_mul, N_row, 3] + pred_coords_col : torch.Tensor + Predicted coordinates for column atoms, shape [B_mul, N_col, 3] + true_coords_row : torch.Tensor + True coordinates for row atoms, shape [B_mul, N_row, 3] + true_coords_col : torch.Tensor + True coordinates for column atoms, shape [B_mul, N_col, 3] + mask_row : torch.Tensor + Resolved mask for row atoms, shape [B, N_row] or [B*mul, N_row]. + If [B, N_row], broadcasts to B_mul. If [B*mul, N_row], used directly. + mask_col : torch.Tensor + Resolved mask for column atoms, shape [B, N_col] or [B*mul, N_col]. + If [B, N_col], broadcasts to B_mul. If [B*mul, N_col], used directly. + multiplicity : int + Required. Explicit multiplicity factor where B_mul = B * multiplicity. + This determines the base batch size B for validating tensor shapes. + atom_indices_row : torch.Tensor, optional + Explicit indices for row atoms, shape [B, N_row]. If None, uses implicit arange. + The batch dimension B matches mask_row (not B_mul), enabling per-sample indices. + atom_indices_col : torch.Tensor, optional + Explicit indices for column atoms, shape [B, N_col]. If None, uses implicit arange. + The batch dimension B matches mask_col (not B_mul), enabling per-sample indices. + cutoff : float + Distance cutoff for lDDT scoring (default 15.0). Used as fallback if cutoff_col is None. + cutoff_col : torch.Tensor, optional + Per-column distance cutoff values, shape [B, N_col]. If provided, overrides scalar cutoff. + The batch dimension B matches mask_col (not B_mul), enabling per-sample cutoffs. + This enables nucleotide-dependent cutoffs (e.g., 15.0 for protein, 30.0 for nucleotides). + eps : float + Small epsilon for numerical stability (default 1e-10) + do_mask_diagonal : bool + If True (default), exclude self-pairs where atom indices match. + return_unnormalized_score : bool + If True, return raw (out_num, out_denom) before normalization. + This is useful for distributed aggregation where partial sums from different + shards need to be allreduced before computing the final normalized score. + Can be combined with per_atom to control output shape. + If False (default), return normalized lDDT score. + per_atom : bool + If True, return per-row lDDT scores [B_mul, N_row] and mask_no_match [B_mul, N_row]. + If False (default), return total lDDT score [B_mul]. + Can be combined with return_unnormalized_score. + return_denom : bool + If True, also return the denominator (pair counts) used for normalization. + Not valid when return_unnormalized_score=True. + + Returns + ------- + If return_unnormalized_score=True and per_atom=True: + out_num : torch.Tensor + Per-row unnormalized sum of scores, shape [B_mul, N_row]. + out_denom : torch.Tensor + Per-row sum of valid pair counts, shape [B_mul, N_row]. + mask_no_match : torch.Tensor + Mask indicating rows with valid pairs, shape [B_mul, N_row]. + Note: For distributed allreduce on out_num and out_denom, then normalize manually: + norm = 1.0 / (eps + allreduced_denom) + score = norm * (eps + allreduced_num) + # mask_no_match should be allreduced with logical OR across shards + If return_unnormalized_score=True and per_atom=False: + out_num : torch.Tensor + Unnormalized sum of scores, shape [B_mul]. + out_denom : torch.Tensor + Sum of valid pair counts, shape [B_mul]. + If return_unnormalized_score=False and per_atom=False: + score : torch.Tensor + lDDT score per batch element, shape [B_mul] + denom : torch.Tensor, optional + Pair counts per batch element, shape [B_mul] + If return_unnormalized_score=False and per_atom=True: + score : torch.Tensor + Per-row lDDT score, shape [B_mul, N_row] + mask_no_match : torch.Tensor + Boolean mask indicating rows with valid pairs, shape [B_mul, N_row] + denom : torch.Tensor, optional + Pair counts per row, shape [B_mul, N_row] + """ + # Extract reference shapes from primary tensors + B_mul, N_row, dim_d = pred_coords_row.shape + _, N_col, _ = pred_coords_col.shape + B_or_B_mul_mask_row, _ = mask_row.shape + B_or_B_mul_mask_col, _ = mask_col.shape + + # Validate coordinate dimension (must be 3D coordinates) + if dim_d != 3: + raise ValueError(f"Coordinate dimension must be 3, got {dim_d}") + + if return_unnormalized_score and return_denom: + raise ValueError("return_denom is not valid when return_unnormalized_score=True") + + # Compute B from explicit multiplicity + if B_mul % multiplicity != 0: + raise ValueError(f"Coordinate batch dimension ({B_mul}) must be divisible by multiplicity ({multiplicity})") + B = B_mul // multiplicity + + # Check if masks have multiplicity + mask_row_has_mul = B_or_B_mul_mask_row == B_mul + mask_col_has_mul = B_or_B_mul_mask_col == B_mul + + # Validate mask shapes: must be either (B, N) or (B_mul, N) + if B_or_B_mul_mask_row != B and B_or_B_mul_mask_row != B_mul: + raise ValueError(f"mask_row batch dimension ({B_or_B_mul_mask_row}) must be either B ({B}) or B_mul ({B_mul})") + if B_or_B_mul_mask_col != B and B_or_B_mul_mask_col != B_mul: + raise ValueError(f"mask_col batch dimension ({B_or_B_mul_mask_col}) must be either B ({B}) or B_mul ({B_mul})") + + # Note: return_unnormalized_score and per_atom are orthogonal options: + # - per_atom controls output shape: [B_mul, N_row] vs [B_mul] + # - return_unnormalized_score controls whether to return raw (out_num, out_denom) vs normalized score + + # Expected shapes for all tensors + coord_row_shape = (B_mul, N_row, dim_d) + coord_col_shape = (B_mul, N_col, dim_d) + idx_row_shape = (B, N_row) + idx_col_shape = (B, N_col) + + # Validate coordinate tensors + if pred_coords_row.shape != coord_row_shape: + raise ValueError(f"pred_coords_row shape {tuple(pred_coords_row.shape)} must be {coord_row_shape}") + if pred_coords_col.shape != coord_col_shape: + raise ValueError(f"pred_coords_col shape {tuple(pred_coords_col.shape)} must be {coord_col_shape}") + if true_coords_row.shape != coord_row_shape: + raise ValueError(f"true_coords_row shape {tuple(true_coords_row.shape)} must be {coord_row_shape}") + if true_coords_col.shape != coord_col_shape: + raise ValueError(f"true_coords_col shape {tuple(true_coords_col.shape)} must be {coord_col_shape}") + + # Validate mask N dimensions + if mask_row.shape[1] != N_row: + raise ValueError(f"mask_row N dimension ({mask_row.shape[1]}) must be {N_row}") + if mask_col.shape[1] != N_col: + raise ValueError(f"mask_col N dimension ({mask_col.shape[1]}) must be {N_col}") + + # Validate optional index tensors + if atom_indices_row is not None and atom_indices_row.shape != idx_row_shape: + raise ValueError(f"atom_indices_row shape {tuple(atom_indices_row.shape)} must be {idx_row_shape}") + if atom_indices_col is not None and atom_indices_col.shape != idx_col_shape: + raise ValueError(f"atom_indices_col shape {tuple(atom_indices_col.shape)} must be {idx_col_shape}") + + # Validate optional cutoff_col tensor + cutoff_col_shape = (B, N_col) + if cutoff_col is not None and cutoff_col.shape != cutoff_col_shape: + raise ValueError(f"cutoff_col shape {tuple(cutoff_col.shape)} must be {cutoff_col_shape}") + + device = pred_coords_row.device + + if ( + pred_coords_row.dtype != pred_coords_col.dtype + or pred_coords_row.dtype != true_coords_row.dtype + or pred_coords_row.dtype != true_coords_col.dtype + ): + raise ValueError( + "pred/true coords dtypes must match: " + f"row={pred_coords_row.dtype}, col={pred_coords_col.dtype}, " + f"true_row={true_coords_row.dtype}, true_col={true_coords_col.dtype}" + ) + input_dtype = pred_coords_row.dtype + compute_dtype = torch.promote_types(input_dtype, torch.float32) + + if pred_coords_row.dtype != compute_dtype: + pred_coords_row = pred_coords_row.to(compute_dtype) + if pred_coords_col.dtype != compute_dtype: + pred_coords_col = pred_coords_col.to(compute_dtype) + if true_coords_row.dtype != compute_dtype: + true_coords_row = true_coords_row.to(compute_dtype) + if true_coords_col.dtype != compute_dtype: + true_coords_col = true_coords_col.to(compute_dtype) + + # Output buffers + if per_atom: + # Per-row outputs: [B_mul, N_row] + out_num = torch.zeros(B_mul, N_row, device=device, dtype=compute_dtype) + out_denom = torch.zeros(B_mul, N_row, device=device, dtype=compute_dtype) + else: + # Total outputs: [B_mul] + out_num = torch.zeros(B_mul, device=device, dtype=compute_dtype) + out_denom = torch.zeros(B_mul, device=device, dtype=compute_dtype) + + # Block sizes + BLOCK_M = 32 + BLOCK_N = 32 + + grid = (B_mul, triton.cdiv(N_row, BLOCK_M), triton.cdiv(N_col, BLOCK_N)) + + # Determine whether to use explicit indices or implicit arange + use_explicit_indices_row = atom_indices_row is not None + use_explicit_indices_col = atom_indices_col is not None + + # Determine whether to use per-column cutoff + use_cutoff_col = cutoff_col is not None + + # Compute memory layout order for make_block_ptr + # order = argsort of strides (ascending), giving fastest-varying dim first + order_pred_coords_row = tuple(torch.tensor(pred_coords_row.stride()).argsort().tolist()) + order_pred_coords_col = tuple(torch.tensor(pred_coords_col.stride()).argsort().tolist()) + order_true_coords_row = tuple(torch.tensor(true_coords_row.stride()).argsort().tolist()) + order_true_coords_col = tuple(torch.tensor(true_coords_col.stride()).argsort().tolist()) + order_mask_row = tuple(torch.tensor(mask_row.stride()).argsort().tolist()) + order_mask_col = tuple(torch.tensor(mask_col.stride()).argsort().tolist()) + # Compute order for atom_indices if they are provided + order_atom_indices_row = ( + tuple(torch.tensor(atom_indices_row.stride()).argsort().tolist()) + if use_explicit_indices_row + else (0, 1) # Default order, not used when indices are None + ) + order_atom_indices_col = ( + tuple(torch.tensor(atom_indices_col.stride()).argsort().tolist()) + if use_explicit_indices_col + else (0, 1) # Default order, not used when indices are None + ) + # Compute order for cutoff_col if provided + order_cutoff_col = ( + tuple(torch.tensor(cutoff_col.stride()).argsort().tolist()) + if use_cutoff_col + else (0, 1) # Default order, not used when cutoff_col is None + ) + + _cdist_lddt_kernel[grid]( + # Pointers + pred_coords_row, + pred_coords_col, + true_coords_row, + true_coords_col, + mask_row, + mask_col, + atom_indices_row, # Can be None if USE_EXPLICIT_INDICES_ROW=False + atom_indices_col, # Can be None if USE_EXPLICIT_INDICES_COL=False + cutoff_col, # Can be None if USE_CUTOFF_COL=False + out_num, + out_denom, + # Shape and strides for pred_coords_row [B_mul, N_row, 3] + pred_coords_row.shape[0], + pred_coords_row.shape[1], + pred_coords_row.stride(0), + pred_coords_row.stride(1), + pred_coords_row.stride(2), + # Shape and strides for pred_coords_col [B_mul, N_col, 3] + pred_coords_col.shape[0], + pred_coords_col.shape[1], + pred_coords_col.stride(0), + pred_coords_col.stride(1), + pred_coords_col.stride(2), + # Shape and strides for true_coords_row [B_mul, N_row, 3] + true_coords_row.shape[0], + true_coords_row.shape[1], + true_coords_row.stride(0), + true_coords_row.stride(1), + true_coords_row.stride(2), + # Shape and strides for true_coords_col [B_mul, N_col, 3] + true_coords_col.shape[0], + true_coords_col.shape[1], + true_coords_col.stride(0), + true_coords_col.stride(1), + true_coords_col.stride(2), + # Shape and strides for mask_row [B, N_row] + mask_row.shape[0], + mask_row.shape[1], + mask_row.stride(0), + mask_row.stride(1), + # Shape and strides for mask_col [B, N_col] + mask_col.shape[0], + mask_col.shape[1], + mask_col.stride(0), + mask_col.stride(1), + # Shape and strides for atom_indices_row [B, N_row] (only used if USE_EXPLICIT_INDICES_ROW) + atom_indices_row.shape[0] if use_explicit_indices_row else 0, + atom_indices_row.shape[1] if use_explicit_indices_row else 0, + atom_indices_row.stride(0) if use_explicit_indices_row else 0, + atom_indices_row.stride(1) if use_explicit_indices_row else 0, + # Shape and strides for atom_indices_col [B, N_col] (only used if USE_EXPLICIT_INDICES_COL) + atom_indices_col.shape[0] if use_explicit_indices_col else 0, + atom_indices_col.shape[1] if use_explicit_indices_col else 0, + atom_indices_col.stride(0) if use_explicit_indices_col else 0, + atom_indices_col.stride(1) if use_explicit_indices_col else 0, + # Shape and strides for cutoff_col [B, N_col] (only used if USE_CUTOFF_COL is True) + cutoff_col.shape[0] if use_cutoff_col else 0, + cutoff_col.shape[1] if use_cutoff_col else 0, + cutoff_col.stride(0) if use_cutoff_col else 0, + cutoff_col.stride(1) if use_cutoff_col else 0, + # Shape and strides for out_num [B_mul] or [B_mul, N_row] if per_atom + out_num.shape[0], + out_num.stride(0), + out_num.stride(1) if per_atom else 0, + # Shape and strides for out_denom [B_mul] or [B_mul, N_row] if per_atom + out_denom.shape[0], + out_denom.stride(0), + out_denom.stride(1) if per_atom else 0, + # Constants + # When cutoff_col is provided, it overrides the scalar cutoff (pass 0.0 as dummy) + 0.0 if use_cutoff_col else float(cutoff), + # Block sizes + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + # Coordinate dimension size (BLOCK_D computed inside kernel as next power of 2) + SIZE_DIM_D=dim_d, + # Flags for implicit arange + USE_EXPLICIT_INDICES_ROW=use_explicit_indices_row, + USE_EXPLICIT_INDICES_COL=use_explicit_indices_col, + # Flag for per-column cutoff + USE_CUTOFF_COL=use_cutoff_col, + # Flag for diagonal masking + DO_MASK_DIAGONAL=do_mask_diagonal, + # Flag for per-atom output + PER_ATOM=per_atom, + # Memory layout orders for make_block_ptr (3-tuple for coords, 2-tuple for masks) + ORDER_PRED_COORDS_ROW_0=order_pred_coords_row[0], + ORDER_PRED_COORDS_ROW_1=order_pred_coords_row[1], + ORDER_PRED_COORDS_ROW_2=order_pred_coords_row[2], + ORDER_PRED_COORDS_COL_0=order_pred_coords_col[0], + ORDER_PRED_COORDS_COL_1=order_pred_coords_col[1], + ORDER_PRED_COORDS_COL_2=order_pred_coords_col[2], + ORDER_TRUE_COORDS_ROW_0=order_true_coords_row[0], + ORDER_TRUE_COORDS_ROW_1=order_true_coords_row[1], + ORDER_TRUE_COORDS_ROW_2=order_true_coords_row[2], + ORDER_TRUE_COORDS_COL_0=order_true_coords_col[0], + ORDER_TRUE_COORDS_COL_1=order_true_coords_col[1], + ORDER_TRUE_COORDS_COL_2=order_true_coords_col[2], + ORDER_MASK_ROW_0=order_mask_row[0], + ORDER_MASK_ROW_1=order_mask_row[1], + ORDER_MASK_COL_0=order_mask_col[0], + ORDER_MASK_COL_1=order_mask_col[1], + ORDER_ATOM_INDICES_ROW_0=order_atom_indices_row[0], + ORDER_ATOM_INDICES_ROW_1=order_atom_indices_row[1], + ORDER_ATOM_INDICES_COL_0=order_atom_indices_col[0], + ORDER_ATOM_INDICES_COL_1=order_atom_indices_col[1], + ORDER_CUTOFF_COL_0=order_cutoff_col[0], + ORDER_CUTOFF_COL_1=order_cutoff_col[1], + # Flags for mask multiplicity + MASK_ROW_HAS_MUL=mask_row_has_mul, + MASK_COL_HAS_MUL=mask_col_has_mul, + # Explicit multiplicity factor + MULTIPLICITY=multiplicity, + num_warps=4, + num_stages=3, + ) + + # Final reduction / normalization + if per_atom: + # Per-row outputs: out_num, out_denom have shape [B_mul, N_row] + # mask_no_match: True where there are valid pairs for this row + mask_no_match = (out_denom > 0).to(input_dtype) + + if return_unnormalized_score: + # Return raw (out_num, out_denom, mask_no_match) for distributed allreduce + return out_num.to(input_dtype), out_denom.to(input_dtype), mask_no_match + + # Per-row normalization: score = (eps + sum(dists_to_score * score)) / (eps + sum(dists_to_score)) + # This matches lddt_dist's per_atom=True behavior + norm = 1.0 / (eps + out_denom) + score = norm * (eps + out_num) + if return_denom: + return score.to(input_dtype), mask_no_match, out_denom.to(input_dtype) + return score.to(input_dtype), mask_no_match + + # Total outputs: out_num, out_denom have shape [B_mul] + if return_unnormalized_score: + # Return raw (out_num, out_denom) for distributed allreduce + return out_num.to(input_dtype), out_denom.to(input_dtype) + + # Total score normalization + # Avoid division by zero + result = out_num / (out_denom + eps) + + # If denominator is 0, result should be 0 (no valid atoms) + score = torch.where(out_denom > 0, result, torch.zeros_like(result)) + if return_denom: + return score.to(input_dtype), out_denom.to(input_dtype) + return score.to(input_dtype) diff --git a/src/boltz/distributed/model/loss/triton/cdist_pde.py b/src/boltz/distributed/model/loss/triton/cdist_pde.py new file mode 100644 index 000000000..d1c2715ec --- /dev/null +++ b/src/boltz/distributed/model/loss/triton/cdist_pde.py @@ -0,0 +1,1032 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _cdist_pde_fwd_kernel( + # Pointers to inputs (passed as tensors for make_block_ptr) + pred_pde, # [B_mul, N_row, N_col, num_bins] + true_coords_row, # [B_mul, N_row, 3] + true_coords_col, # [B_mul, N_col, 3] + pred_coords_row, # [B_mul, N_row, 3] + pred_coords_col, # [B_mul, N_col, 3] + mask_row, # [B, N_row] or [B_mul, N_row] + mask_col, # [B, N_col] or [B_mul, N_col] + # Output pointers (scalar per batch, use atomic_add) + out_loss_num_ptr, # [B_mul] + out_mask_denom_ptr, # [B_mul] + # Shape info for pred_pde [B_mul, N_row, N_col, num_bins] + B_mul, + N_row, + N_col, + NUM_BINS: tl.constexpr, + stride_pde_b, + stride_pde_i, + stride_pde_j, + stride_pde_k, + # Strides for true_coords_row [B_mul, N_row, 3] + stride_tc_row_b, + stride_tc_row_n, + stride_tc_row_d, + # Strides for true_coords_col [B_mul, N_col, 3] + stride_tc_col_b, + stride_tc_col_n, + stride_tc_col_d, + # Strides for pred_coords_row [B_mul, N_row, 3] + stride_pc_row_b, + stride_pc_row_n, + stride_pc_row_d, + # Strides for pred_coords_col [B_mul, N_col, 3] + stride_pc_col_b, + stride_pc_col_n, + stride_pc_col_d, + # Shape info for mask_row [B_mask_row, N_row] + B_mask_row, + stride_mask_row_b, + stride_mask_row_n, + # Shape info for mask_col [B_mask_col, N_col] + B_mask_col, + stride_mask_col_b, + stride_mask_col_n, + # Constants + max_dist, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + # Coordinate dimension + SIZE_DIM_D: tl.constexpr, + # Mask multiplicity flags + MASK_ROW_HAS_MUL: tl.constexpr, + MASK_COL_HAS_MUL: tl.constexpr, + MULTIPLICITY: tl.constexpr, + # Memory layout orders for make_block_ptr (computed from argsort of strides) + ORDER_TC_ROW_0: tl.constexpr, + ORDER_TC_ROW_1: tl.constexpr, + ORDER_TC_ROW_2: tl.constexpr, + ORDER_TC_COL_0: tl.constexpr, + ORDER_TC_COL_1: tl.constexpr, + ORDER_TC_COL_2: tl.constexpr, + ORDER_PC_ROW_0: tl.constexpr, + ORDER_PC_ROW_1: tl.constexpr, + ORDER_PC_ROW_2: tl.constexpr, + ORDER_PC_COL_0: tl.constexpr, + ORDER_PC_COL_1: tl.constexpr, + ORDER_PC_COL_2: tl.constexpr, + ORDER_MASK_ROW_0: tl.constexpr, + ORDER_MASK_ROW_1: tl.constexpr, + ORDER_MASK_COL_0: tl.constexpr, + ORDER_MASK_COL_1: tl.constexpr, + ORDER_PDE_0: tl.constexpr, + ORDER_PDE_1: tl.constexpr, + ORDER_PDE_2: tl.constexpr, + ORDER_PDE_3: tl.constexpr, +): + """ + Forward kernel for computing PDE cross-entropy loss. + + For each (i, j) pair in a tile: + 1. Compute true_d = ||true_coords[i] - true_coords[j]|| + 2. Compute pred_d = ||pred_coords[i] - pred_coords[j]|| + 3. Compute target_pde = |true_d - pred_d| + 4. Compute bin_index = clamp(floor(target_pde * num_bins / max_dist), max=num_bins-1) + 5. Compute log_softmax of pred_pde logits + 6. Compute cross-entropy: ce_loss = -log_softmax[bin_index] + 7. Accumulate: out_loss_num[b] += sum_{i,j}(ce_loss * mask[i,j]) + out_mask_denom[b] += sum_{i,j}(mask[i,j]) + """ + tl.static_assert(SIZE_DIM_D == 3, "SIZE_DIM_D must be 3 (3D coordinates)") + + # Block dimension for coordinates (next power of 2 of 3) + BLOCK_D: tl.constexpr = 4 + + # Grid identification + pid_batch = tl.program_id(0) # batch * multiplicity + pid_m = tl.program_id(1) # row block + pid_n = tl.program_id(2) # col block + + # Multiplicity handling for mask broadcasting + batch_idx = pid_batch // MULTIPLICITY + batch_idx_mask_row = pid_batch if MASK_ROW_HAS_MUL else batch_idx + batch_idx_mask_col = pid_batch if MASK_COL_HAS_MUL else batch_idx + + # ============================================ + # 1. Create block pointers and load coordinates + # ============================================ + # Using make_block_ptr lets the compiler handle index types automatically + # For 3D coordinate tensors: [B_mul, N, 3], load block [1, BLOCK_M/N, BLOCK_D] + tc_row_block_ptr = tl.make_block_ptr( + base=true_coords_row, + shape=(B_mul, N_row, SIZE_DIM_D), + strides=(stride_tc_row_b, stride_tc_row_n, stride_tc_row_d), + offsets=(pid_batch, pid_m * BLOCK_M, 0), + block_shape=(1, BLOCK_M, BLOCK_D), + order=(ORDER_TC_ROW_0, ORDER_TC_ROW_1, ORDER_TC_ROW_2), + ) + tc_col_block_ptr = tl.make_block_ptr( + base=true_coords_col, + shape=(B_mul, N_col, SIZE_DIM_D), + strides=(stride_tc_col_b, stride_tc_col_n, stride_tc_col_d), + offsets=(pid_batch, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_N, BLOCK_D), + order=(ORDER_TC_COL_0, ORDER_TC_COL_1, ORDER_TC_COL_2), + ) + pc_row_block_ptr = tl.make_block_ptr( + base=pred_coords_row, + shape=(B_mul, N_row, SIZE_DIM_D), + strides=(stride_pc_row_b, stride_pc_row_n, stride_pc_row_d), + offsets=(pid_batch, pid_m * BLOCK_M, 0), + block_shape=(1, BLOCK_M, BLOCK_D), + order=(ORDER_PC_ROW_0, ORDER_PC_ROW_1, ORDER_PC_ROW_2), + ) + pc_col_block_ptr = tl.make_block_ptr( + base=pred_coords_col, + shape=(B_mul, N_col, SIZE_DIM_D), + strides=(stride_pc_col_b, stride_pc_col_n, stride_pc_col_d), + offsets=(pid_batch, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_N, BLOCK_D), + order=(ORDER_PC_COL_0, ORDER_PC_COL_1, ORDER_PC_COL_2), + ) + + # Load coordinates with boundary_check and reshape to squeeze batch dim + # boundary_check=(1, 2) checks N and D dims; batch dim is always in bounds + true_row = tl.reshape( + tl.load(tc_row_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_D), + ) + true_col = tl.reshape( + tl.load(tc_col_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_N, BLOCK_D), + ) + pred_row = tl.reshape( + tl.load(pc_row_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_D), + ) + pred_col = tl.reshape( + tl.load(pc_col_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_N, BLOCK_D), + ) + + # ============================================ + # 2. Compute pairwise distances [BLOCK_M, BLOCK_N] + # ============================================ + # true_d[i,j] = ||true_row[i] - true_col[j]|| + delta_true = true_row[:, None, :] - true_col[None, :, :] # [BLOCK_M, BLOCK_N, BLOCK_D] + d_true_sq = tl.sum(delta_true * delta_true, axis=2) # [BLOCK_M, BLOCK_N] + d_true = tl.sqrt(d_true_sq) + + # pred_d[i,j] = ||pred_row[i] - pred_col[j]|| + delta_pred = pred_row[:, None, :] - pred_col[None, :, :] # [BLOCK_M, BLOCK_N, BLOCK_D] + d_pred_sq = tl.sum(delta_pred * delta_pred, axis=2) # [BLOCK_M, BLOCK_N] + d_pred = tl.sqrt(d_pred_sq) + + # ============================================ + # 3. Compute target_pde and bin_index [BLOCK_M, BLOCK_N] + # ============================================ + target_pde = tl.abs(d_true - d_pred) + + # bin_index = clamp(floor(target_pde * num_bins / max_dist), max=num_bins-1) + bin_index_float = target_pde * NUM_BINS / max_dist + bin_index = tl.minimum(tl.floor(bin_index_float).to(tl.int32), NUM_BINS - 1) + + # ============================================ + # 4. Load masks using block_ptr + # ============================================ + # For 2D mask tensors: [B or B_mul, N], load block [1, BLOCK_M/N] + mask_row_block_ptr = tl.make_block_ptr( + base=mask_row, + shape=(B_mask_row, N_row), + strides=(stride_mask_row_b, stride_mask_row_n), + offsets=(batch_idx_mask_row, pid_m * BLOCK_M), + block_shape=(1, BLOCK_M), + order=(ORDER_MASK_ROW_0, ORDER_MASK_ROW_1), + ) + mask_col_block_ptr = tl.make_block_ptr( + base=mask_col, + shape=(B_mask_col, N_col), + strides=(stride_mask_col_b, stride_mask_col_n), + offsets=(batch_idx_mask_col, pid_n * BLOCK_N), + block_shape=(1, BLOCK_N), + order=(ORDER_MASK_COL_0, ORDER_MASK_COL_1), + ) + + # Load masks and reshape to squeeze batch dim + m_row = tl.reshape( + tl.load(mask_row_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_M,), + ) + m_col = tl.reshape( + tl.load(mask_col_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_N,), + ) + + # Combined pair mask [BLOCK_M, BLOCK_N] + pair_mask = m_row[:, None] * m_col[None, :] + + # ============================================ + # 5. Load pred_pde logits using block_ptr and compute cross-entropy + # ============================================ + # For 4D pred_pde: [B_mul, N_row, N_col, num_bins], load block [1, BLOCK_M, BLOCK_N, NUM_BINS] + pde_block_ptr = tl.make_block_ptr( + base=pred_pde, + shape=(B_mul, N_row, N_col, NUM_BINS), + strides=(stride_pde_b, stride_pde_i, stride_pde_j, stride_pde_k), + offsets=(pid_batch, pid_m * BLOCK_M, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_M, BLOCK_N, NUM_BINS), + order=(ORDER_PDE_0, ORDER_PDE_1, ORDER_PDE_2, ORDER_PDE_3), + ) + + # Load and reshape to squeeze batch dim: [1, BLOCK_M, BLOCK_N, NUM_BINS] -> [BLOCK_M, BLOCK_N, NUM_BINS] + logits = tl.reshape( + tl.load(pde_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_N, NUM_BINS), + ) + + # Compute log_softmax along the last dimension + # log_softmax(x) = x - log(sum(exp(x))) + max_logits = tl.max(logits, axis=2)[:, :, None] # [BLOCK_M, BLOCK_N, 1] + logits_shifted = logits - max_logits + exp_logits = tl.exp(logits_shifted) + sum_exp = tl.sum(exp_logits, axis=2)[:, :, None] # [BLOCK_M, BLOCK_N, 1] + log_sum_exp = tl.log(sum_exp) + log_probs = logits_shifted - log_sum_exp # [BLOCK_M, BLOCK_N, NUM_BINS] + + # Gather log_probs at bin_index for each (i, j) + # We need log_probs[i, j, bin_index[i, j]] + # Use tl.gather along axis=2 (the bins dimension) + selected_log_prob = tl.gather(log_probs, bin_index[:, :, None], axis=2) # [BLOCK_M, BLOCK_N, 1] + selected_log_prob = tl.reshape(selected_log_prob, (BLOCK_M, BLOCK_N)) # [BLOCK_M, BLOCK_N] + + # Cross-entropy loss: -log_prob + ce_loss = -selected_log_prob # [BLOCK_M, BLOCK_N] + + # Apply pair mask + ce_loss_masked = ce_loss * pair_mask # [BLOCK_M, BLOCK_N] + + # ============================================ + # 6. Accumulate tile sum and atomic add (scalar per batch) + # ============================================ + # Sum over both row and column dimensions (full tile contribution) + loss_tile_sum = tl.sum(ce_loss_masked) # scalar + denom_tile_sum = tl.sum(pair_mask) # scalar + + # Atomic add to per-batch output (scalar) + out_loss_ptr = out_loss_num_ptr + pid_batch + out_denom_ptr = out_mask_denom_ptr + pid_batch + + tl.atomic_add(out_loss_ptr, loss_tile_sum, sem="relaxed") + tl.atomic_add(out_denom_ptr, denom_tile_sum, sem="relaxed") + + +@triton.jit +def _cdist_pde_bwd_kernel( + # Pointers to inputs (passed as tensors for make_block_ptr) + grad_out_loss_num, # [B_mul] - upstream gradient (scalar per batch) + pred_pde, # [B_mul, N_row, N_col, num_bins] + true_coords_row, # [B_mul, N_row, 3] + true_coords_col, # [B_mul, N_col, 3] + pred_coords_row, # [B_mul, N_row, 3] + pred_coords_col, # [B_mul, N_col, 3] + mask_row, # [B, N_row] or [B_mul, N_row] + mask_col, # [B, N_col] or [B_mul, N_col] + # Output tensor (for make_block_ptr) + grad_pred_pde, # [B_mul, N_row, N_col, num_bins] + # Shape info for pred_pde [B_mul, N_row, N_col, num_bins] + B_mul, + N_row, + N_col, + NUM_BINS: tl.constexpr, + stride_pde_b, + stride_pde_i, + stride_pde_j, + stride_pde_k, + # Strides for true_coords_row [B_mul, N_row, 3] + stride_tc_row_b, + stride_tc_row_n, + stride_tc_row_d, + # Strides for true_coords_col [B_mul, N_col, 3] + stride_tc_col_b, + stride_tc_col_n, + stride_tc_col_d, + # Strides for pred_coords_row [B_mul, N_row, 3] + stride_pc_row_b, + stride_pc_row_n, + stride_pc_row_d, + # Strides for pred_coords_col [B_mul, N_col, 3] + stride_pc_col_b, + stride_pc_col_n, + stride_pc_col_d, + # Shape info for mask_row [B_mask_row, N_row] + B_mask_row, + stride_mask_row_b, + stride_mask_row_n, + # Shape info for mask_col [B_mask_col, N_col] + B_mask_col, + stride_mask_col_b, + stride_mask_col_n, + # Stride for grad_out_loss_num [B_mul] + stride_grad_out_b, + # Constants + max_dist, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + # Coordinate dimension + SIZE_DIM_D: tl.constexpr, + # Mask multiplicity flags + MASK_ROW_HAS_MUL: tl.constexpr, + MASK_COL_HAS_MUL: tl.constexpr, + MULTIPLICITY: tl.constexpr, + # Memory layout orders for make_block_ptr + ORDER_TC_ROW_0: tl.constexpr, + ORDER_TC_ROW_1: tl.constexpr, + ORDER_TC_ROW_2: tl.constexpr, + ORDER_TC_COL_0: tl.constexpr, + ORDER_TC_COL_1: tl.constexpr, + ORDER_TC_COL_2: tl.constexpr, + ORDER_PC_ROW_0: tl.constexpr, + ORDER_PC_ROW_1: tl.constexpr, + ORDER_PC_ROW_2: tl.constexpr, + ORDER_PC_COL_0: tl.constexpr, + ORDER_PC_COL_1: tl.constexpr, + ORDER_PC_COL_2: tl.constexpr, + ORDER_MASK_ROW_0: tl.constexpr, + ORDER_MASK_ROW_1: tl.constexpr, + ORDER_MASK_COL_0: tl.constexpr, + ORDER_MASK_COL_1: tl.constexpr, + ORDER_PDE_0: tl.constexpr, + ORDER_PDE_1: tl.constexpr, + ORDER_PDE_2: tl.constexpr, + ORDER_PDE_3: tl.constexpr, +): + """ + Backward kernel for PDE cross-entropy loss. + + Computes gradient w.r.t. pred_pde only. + + For cross-entropy loss with log_softmax: + loss = -log_softmax(logits)[target] + d_loss/d_logits = softmax(logits) - one_hot(target) + + With upstream gradient (scalar per batch) and mask: + grad_pred_pde[i,j,:] = (softmax - one_hot(bin_idx)) * mask[i,j] * grad_out[b] + """ + tl.static_assert(SIZE_DIM_D == 3, "SIZE_DIM_D must be 3 (3D coordinates)") + + BLOCK_D: tl.constexpr = 4 + + # Grid identification + pid_batch = tl.program_id(0) + pid_m = tl.program_id(1) + pid_n = tl.program_id(2) + + # Multiplicity handling + batch_idx = pid_batch // MULTIPLICITY + batch_idx_mask_row = pid_batch if MASK_ROW_HAS_MUL else batch_idx + batch_idx_mask_col = pid_batch if MASK_COL_HAS_MUL else batch_idx + + # ============================================ + # 1. Create block pointers and load coordinates + # ============================================ + tc_row_block_ptr = tl.make_block_ptr( + base=true_coords_row, + shape=(B_mul, N_row, SIZE_DIM_D), + strides=(stride_tc_row_b, stride_tc_row_n, stride_tc_row_d), + offsets=(pid_batch, pid_m * BLOCK_M, 0), + block_shape=(1, BLOCK_M, BLOCK_D), + order=(ORDER_TC_ROW_0, ORDER_TC_ROW_1, ORDER_TC_ROW_2), + ) + tc_col_block_ptr = tl.make_block_ptr( + base=true_coords_col, + shape=(B_mul, N_col, SIZE_DIM_D), + strides=(stride_tc_col_b, stride_tc_col_n, stride_tc_col_d), + offsets=(pid_batch, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_N, BLOCK_D), + order=(ORDER_TC_COL_0, ORDER_TC_COL_1, ORDER_TC_COL_2), + ) + pc_row_block_ptr = tl.make_block_ptr( + base=pred_coords_row, + shape=(B_mul, N_row, SIZE_DIM_D), + strides=(stride_pc_row_b, stride_pc_row_n, stride_pc_row_d), + offsets=(pid_batch, pid_m * BLOCK_M, 0), + block_shape=(1, BLOCK_M, BLOCK_D), + order=(ORDER_PC_ROW_0, ORDER_PC_ROW_1, ORDER_PC_ROW_2), + ) + pc_col_block_ptr = tl.make_block_ptr( + base=pred_coords_col, + shape=(B_mul, N_col, SIZE_DIM_D), + strides=(stride_pc_col_b, stride_pc_col_n, stride_pc_col_d), + offsets=(pid_batch, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_N, BLOCK_D), + order=(ORDER_PC_COL_0, ORDER_PC_COL_1, ORDER_PC_COL_2), + ) + + # Load coordinates with boundary_check and reshape + true_row = tl.reshape( + tl.load(tc_row_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_D), + ) + true_col = tl.reshape( + tl.load(tc_col_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_N, BLOCK_D), + ) + pred_row = tl.reshape( + tl.load(pc_row_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_D), + ) + pred_col = tl.reshape( + tl.load(pc_col_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_N, BLOCK_D), + ) + + # ============================================ + # 2. Recompute distances and bin_index + # ============================================ + delta_true = true_row[:, None, :] - true_col[None, :, :] + d_true_sq = tl.sum(delta_true * delta_true, axis=2) + d_true = tl.sqrt(d_true_sq) + + delta_pred = pred_row[:, None, :] - pred_col[None, :, :] + d_pred_sq = tl.sum(delta_pred * delta_pred, axis=2) + d_pred = tl.sqrt(d_pred_sq) + + target_pde = tl.abs(d_true - d_pred) + bin_index_float = target_pde * NUM_BINS / max_dist + bin_index = tl.minimum(tl.floor(bin_index_float).to(tl.int32), NUM_BINS - 1) # [BLOCK_M, BLOCK_N] + + # ============================================ + # 3. Load masks using block_ptr + # ============================================ + mask_row_block_ptr = tl.make_block_ptr( + base=mask_row, + shape=(B_mask_row, N_row), + strides=(stride_mask_row_b, stride_mask_row_n), + offsets=(batch_idx_mask_row, pid_m * BLOCK_M), + block_shape=(1, BLOCK_M), + order=(ORDER_MASK_ROW_0, ORDER_MASK_ROW_1), + ) + mask_col_block_ptr = tl.make_block_ptr( + base=mask_col, + shape=(B_mask_col, N_col), + strides=(stride_mask_col_b, stride_mask_col_n), + offsets=(batch_idx_mask_col, pid_n * BLOCK_N), + block_shape=(1, BLOCK_N), + order=(ORDER_MASK_COL_0, ORDER_MASK_COL_1), + ) + + m_row = tl.reshape( + tl.load(mask_row_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_M,), + ) + m_col = tl.reshape( + tl.load(mask_col_block_ptr, boundary_check=(1,), padding_option="zero"), + (BLOCK_N,), + ) + + pair_mask = m_row[:, None] * m_col[None, :] # [BLOCK_M, BLOCK_N] + + # ============================================ + # 4. Load upstream gradient (scalar per batch) using make_block_ptr + # ============================================ + # grad_out_loss_num is [B_mul], load single scalar for this batch + grad_out_block_ptr = tl.make_block_ptr( + base=grad_out_loss_num, + shape=(B_mul,), + strides=(stride_grad_out_b,), + offsets=(pid_batch,), + block_shape=(1,), + order=(0,), + ) + grad_out = tl.reshape(tl.load(grad_out_block_ptr), ()) # scalar + + # ============================================ + # 5. Load pred_pde logits using block_ptr and compute gradient + # ============================================ + pde_block_ptr = tl.make_block_ptr( + base=pred_pde, + shape=(B_mul, N_row, N_col, NUM_BINS), + strides=(stride_pde_b, stride_pde_i, stride_pde_j, stride_pde_k), + offsets=(pid_batch, pid_m * BLOCK_M, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_M, BLOCK_N, NUM_BINS), + order=(ORDER_PDE_0, ORDER_PDE_1, ORDER_PDE_2, ORDER_PDE_3), + ) + logits = tl.reshape( + tl.load(pde_block_ptr, boundary_check=(1, 2), padding_option="zero"), + (BLOCK_M, BLOCK_N, NUM_BINS), + ) + + # Compute softmax + max_logits = tl.max(logits, axis=2)[:, :, None] # [BLOCK_M, BLOCK_N, 1] + logits_shifted = logits - max_logits + exp_logits = tl.exp(logits_shifted) + sum_exp = tl.sum(exp_logits, axis=2)[:, :, None] # [BLOCK_M, BLOCK_N, 1] + softmax_probs = exp_logits / sum_exp # [BLOCK_M, BLOCK_N, NUM_BINS] + + # Create one_hot for targets: [BLOCK_M, BLOCK_N, NUM_BINS] + offs_k = tl.arange(0, NUM_BINS) + one_hot = (offs_k[None, None, :] == bin_index[:, :, None]).to(tl.float32) + + # Gradient of cross-entropy w.r.t. logits: softmax - one_hot + grad_logits = softmax_probs - one_hot # [BLOCK_M, BLOCK_N, NUM_BINS] + + # Apply mask and upstream gradient (scalar per batch) + # grad_pred_pde[i,j,:] = grad_logits * mask[i,j] * grad_out[b] + scale = pair_mask * grad_out # [BLOCK_M, BLOCK_N] (scalar broadcast) + grad_logits_scaled = grad_logits * scale[:, :, None] # [BLOCK_M, BLOCK_N, NUM_BINS] + + # ============================================ + # 6. Store gradient using block_ptr + # ============================================ + # Create block_ptr for gradient output (same layout as pred_pde) + grad_pde_block_ptr = tl.make_block_ptr( + base=grad_pred_pde, + shape=(B_mul, N_row, N_col, NUM_BINS), + strides=(stride_pde_b, stride_pde_i, stride_pde_j, stride_pde_k), + offsets=(pid_batch, pid_m * BLOCK_M, pid_n * BLOCK_N, 0), + block_shape=(1, BLOCK_M, BLOCK_N, NUM_BINS), + order=(ORDER_PDE_0, ORDER_PDE_1, ORDER_PDE_2, ORDER_PDE_3), + ) + # Expand to 4D for store: [BLOCK_M, BLOCK_N, NUM_BINS] -> [1, BLOCK_M, BLOCK_N, NUM_BINS] + grad_logits_4d = tl.reshape(grad_logits_scaled, (1, BLOCK_M, BLOCK_N, NUM_BINS)) + tl.store(grad_pde_block_ptr, grad_logits_4d, boundary_check=(1, 2)) + + +class _CdistPDEImpl(torch.autograd.Function): + """Autograd function wrapping forward and backward Triton kernels.""" + + @staticmethod + def forward( + ctx, + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + multiplicity, + num_bins, + max_dist, + ): + """ + Forward pass: compute PDE cross-entropy fully summed per batch. + + Args: + pred_pde: [B_mul, N_row, N_col, num_bins] - predicted logits + true_coords_row: [B_mul, N_row, 3] - true coordinates for rows + true_coords_col: [B_mul, N_col, 3] - true coordinates for columns + pred_coords_row: [B_mul, N_row, 3] - predicted coordinates for rows + pred_coords_col: [B_mul, N_col, 3] - predicted coordinates for columns + mask_row: [B, N_row] or [B_mul, N_row] - row mask + mask_col: [B, N_col] or [B_mul, N_col] - column mask + multiplicity: int - B_mul = B * multiplicity + num_bins: int - number of bins + max_dist: float - maximum distance for binning + + Returns: + out_loss_num: [B_mul] - sum of CE loss over all (i,j) pairs per batch + out_mask_denom: [B_mul] - sum of mask over all (i,j) pairs per batch + """ + B_mul, N_row, N_col, num_bins_tensor = pred_pde.shape + device = pred_pde.device + + # Validate num_bins + if num_bins_tensor != num_bins: + raise ValueError(f"pred_pde num_bins mismatch: got {num_bins_tensor}, expected {num_bins}") + + # Compute B from multiplicity + if B_mul % multiplicity != 0: + raise ValueError(f"B_mul ({B_mul}) must be divisible by multiplicity ({multiplicity})") + B = B_mul // multiplicity + + # Validate that coordinates and masks don't require gradients + # (gradient flow is broken by .long() in bin_index computation) + if true_coords_row.requires_grad: + raise ValueError( + "true_coords_row should not require gradients " "(gradient flow is broken by bin_index computation)" + ) + if true_coords_col.requires_grad: + raise ValueError( + "true_coords_col should not require gradients " "(gradient flow is broken by bin_index computation)" + ) + if pred_coords_row.requires_grad: + raise ValueError( + "pred_coords_row should not require gradients " "(gradient flow is broken by bin_index computation)" + ) + if pred_coords_col.requires_grad: + raise ValueError( + "pred_coords_col should not require gradients " "(gradient flow is broken by bin_index computation)" + ) + if mask_row.requires_grad: + raise ValueError("mask_row should not require gradients") + if mask_col.requires_grad: + raise ValueError("mask_col should not require gradients") + + # Validate coordinate dimensions + if true_coords_row.shape[-1] != 3: + raise ValueError(f"Coordinate dimension must be 3, got true_coords_row shape {true_coords_row.shape}") + + # Validate coordinate shapes + if true_coords_row.shape != (B_mul, N_row, 3): + raise ValueError( + f"true_coords_row shape mismatch: got {true_coords_row.shape}, " f"expected ({B_mul}, {N_row}, 3)" + ) + if true_coords_col.shape != (B_mul, N_col, 3): + raise ValueError( + f"true_coords_col shape mismatch: got {true_coords_col.shape}, " f"expected ({B_mul}, {N_col}, 3)" + ) + if pred_coords_row.shape != (B_mul, N_row, 3): + raise ValueError( + f"pred_coords_row shape mismatch: got {pred_coords_row.shape}, " f"expected ({B_mul}, {N_row}, 3)" + ) + if pred_coords_col.shape != (B_mul, N_col, 3): + raise ValueError( + f"pred_coords_col shape mismatch: got {pred_coords_col.shape}, " f"expected ({B_mul}, {N_col}, 3)" + ) + + # Check mask dimensions + mask_row_has_mul = mask_row.shape[0] == B_mul + mask_col_has_mul = mask_col.shape[0] == B_mul + + # Validate mask_row shape + if mask_row.shape[0] not in (B, B_mul): + raise ValueError( + f"mask_row batch dimension must be B ({B}) or B_mul ({B_mul}), " f"got {mask_row.shape[0]}" + ) + if mask_row.shape[1] != N_row: + raise ValueError(f"mask_row N dimension mismatch: got {mask_row.shape[1]}, expected {N_row}") + + # Validate mask_col shape + if mask_col.shape[0] not in (B, B_mul): + raise ValueError( + f"mask_col batch dimension must be B ({B}) or B_mul ({B_mul}), " f"got {mask_col.shape[0]}" + ) + if mask_col.shape[1] != N_col: + raise ValueError(f"mask_col N dimension mismatch: got {mask_col.shape[1]}, expected {N_col}") + + # Don't materialize zero gradients for non-differentiable outputs (out_mask_denom) + # This makes grad_out_mask_denom be None in backward() instead of zeros + ctx.set_materialize_grads(False) + + # Output buffers - scalar per batch [B_mul] + # Use float64 if input is float64, otherwise float32 + output_dtype = pred_pde.dtype if pred_pde.dtype == torch.float64 else torch.float32 + out_loss_num = torch.zeros(B_mul, device=device, dtype=output_dtype) + out_mask_denom = torch.zeros(B_mul, device=device, dtype=output_dtype) + + # Block sizes (tuned for N=4096, num_bins=64) + BLOCK_M = 8 + BLOCK_N = 8 + + # Compute memory layout order for make_block_ptr + # order = argsort of strides (ascending), giving fastest-varying dim first + order_tc_row = tuple(torch.tensor(true_coords_row.stride()).argsort().tolist()) + order_tc_col = tuple(torch.tensor(true_coords_col.stride()).argsort().tolist()) + order_pc_row = tuple(torch.tensor(pred_coords_row.stride()).argsort().tolist()) + order_pc_col = tuple(torch.tensor(pred_coords_col.stride()).argsort().tolist()) + order_mask_row = tuple(torch.tensor(mask_row.stride()).argsort().tolist()) + order_mask_col = tuple(torch.tensor(mask_col.stride()).argsort().tolist()) + order_pde = tuple(torch.tensor(pred_pde.stride()).argsort().tolist()) + + # Grid + grid = (B_mul, triton.cdiv(N_row, BLOCK_M), triton.cdiv(N_col, BLOCK_N)) + + _cdist_pde_fwd_kernel[grid]( + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + out_loss_num, + out_mask_denom, + # Shape info + B_mul, + N_row, + N_col, + num_bins, # constexpr + pred_pde.stride(0), + pred_pde.stride(1), + pred_pde.stride(2), + pred_pde.stride(3), + # Coord strides + true_coords_row.stride(0), + true_coords_row.stride(1), + true_coords_row.stride(2), + true_coords_col.stride(0), + true_coords_col.stride(1), + true_coords_col.stride(2), + pred_coords_row.stride(0), + pred_coords_row.stride(1), + pred_coords_row.stride(2), + pred_coords_col.stride(0), + pred_coords_col.stride(1), + pred_coords_col.stride(2), + # Mask info + mask_row.shape[0], + mask_row.stride(0), + mask_row.stride(1), + mask_col.shape[0], + mask_col.stride(0), + mask_col.stride(1), + # Constants + max_dist, + # Block sizes (tuned for N=4096, num_bins=64) + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + SIZE_DIM_D=3, + # Flags + MASK_ROW_HAS_MUL=mask_row_has_mul, + MASK_COL_HAS_MUL=mask_col_has_mul, + MULTIPLICITY=multiplicity, + # Memory layout orders + ORDER_TC_ROW_0=order_tc_row[0], + ORDER_TC_ROW_1=order_tc_row[1], + ORDER_TC_ROW_2=order_tc_row[2], + ORDER_TC_COL_0=order_tc_col[0], + ORDER_TC_COL_1=order_tc_col[1], + ORDER_TC_COL_2=order_tc_col[2], + ORDER_PC_ROW_0=order_pc_row[0], + ORDER_PC_ROW_1=order_pc_row[1], + ORDER_PC_ROW_2=order_pc_row[2], + ORDER_PC_COL_0=order_pc_col[0], + ORDER_PC_COL_1=order_pc_col[1], + ORDER_PC_COL_2=order_pc_col[2], + ORDER_MASK_ROW_0=order_mask_row[0], + ORDER_MASK_ROW_1=order_mask_row[1], + ORDER_MASK_COL_0=order_mask_col[0], + ORDER_MASK_COL_1=order_mask_col[1], + ORDER_PDE_0=order_pde[0], + ORDER_PDE_1=order_pde[1], + ORDER_PDE_2=order_pde[2], + ORDER_PDE_3=order_pde[3], + num_warps=4, + num_stages=3, + ) + + # Save for backward + ctx.save_for_backward( + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + ) + ctx.multiplicity = multiplicity + ctx.num_bins = num_bins + ctx.max_dist = max_dist + ctx.mask_row_has_mul = mask_row_has_mul + ctx.mask_col_has_mul = mask_col_has_mul + + # out_mask_denom only depends on mask (not pred_pde), so no gradient + ctx.mark_non_differentiable(out_mask_denom) + + return out_loss_num, out_mask_denom + + @staticmethod + def backward(ctx, grad_out_loss_num, grad_out_mask_denom): + """ + Backward pass: compute gradient w.r.t. pred_pde. + + grad_out_mask_denom should be None since out_mask_denom is marked + non-differentiable and ctx.set_materialize_grads(False) is set. + """ + # Validate grad_out_mask_denom is None (out_mask_denom is non-differentiable) + if grad_out_mask_denom is not None: + raise ValueError( + "grad_out_mask_denom should be None since out_mask_denom is " + "marked non-differentiable (it only depends on mask, not pred_pde)" + ) + + ( + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + ) = ctx.saved_tensors + + multiplicity = ctx.multiplicity + num_bins = ctx.num_bins + max_dist = ctx.max_dist + mask_row_has_mul = ctx.mask_row_has_mul + mask_col_has_mul = ctx.mask_col_has_mul + + B_mul, N_row, N_col, _ = pred_pde.shape + + # Validate grad_out_loss_num shape (scalar per batch) + if grad_out_loss_num.shape != (B_mul,): + raise ValueError( + f"grad_out_loss_num shape mismatch: got {grad_out_loss_num.shape}, " f"expected ({B_mul},)" + ) + + # Output gradient buffer + grad_pred_pde = torch.zeros_like(pred_pde) + + # Block sizes (tuned for N=4096, num_bins=64) + BLOCK_M = 4 + BLOCK_N = 4 + + # Compute memory layout order for make_block_ptr + order_tc_row = tuple(torch.tensor(true_coords_row.stride()).argsort().tolist()) + order_tc_col = tuple(torch.tensor(true_coords_col.stride()).argsort().tolist()) + order_pc_row = tuple(torch.tensor(pred_coords_row.stride()).argsort().tolist()) + order_pc_col = tuple(torch.tensor(pred_coords_col.stride()).argsort().tolist()) + order_mask_row = tuple(torch.tensor(mask_row.stride()).argsort().tolist()) + order_mask_col = tuple(torch.tensor(mask_col.stride()).argsort().tolist()) + order_pde = tuple(torch.tensor(pred_pde.stride()).argsort().tolist()) + grad_out_contiguous = grad_out_loss_num.contiguous() + + # Grid + grid = (B_mul, triton.cdiv(N_row, BLOCK_M), triton.cdiv(N_col, BLOCK_N)) + + _cdist_pde_bwd_kernel[grid]( + grad_out_contiguous, + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + grad_pred_pde, + # Shape info + B_mul, + N_row, + N_col, + num_bins, + pred_pde.stride(0), + pred_pde.stride(1), + pred_pde.stride(2), + pred_pde.stride(3), + # Coord strides + true_coords_row.stride(0), + true_coords_row.stride(1), + true_coords_row.stride(2), + true_coords_col.stride(0), + true_coords_col.stride(1), + true_coords_col.stride(2), + pred_coords_row.stride(0), + pred_coords_row.stride(1), + pred_coords_row.stride(2), + pred_coords_col.stride(0), + pred_coords_col.stride(1), + pred_coords_col.stride(2), + # Mask info + mask_row.shape[0], + mask_row.stride(0), + mask_row.stride(1), + mask_col.shape[0], + mask_col.stride(0), + mask_col.stride(1), + # Stride for grad_out_loss_num + grad_out_contiguous.stride(0), + # Constants + max_dist, + # Block sizes (tuned for N=4096, num_bins=64) + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + SIZE_DIM_D=3, + # Flags + MASK_ROW_HAS_MUL=mask_row_has_mul, + MASK_COL_HAS_MUL=mask_col_has_mul, + MULTIPLICITY=multiplicity, + # Memory layout orders + ORDER_TC_ROW_0=order_tc_row[0], + ORDER_TC_ROW_1=order_tc_row[1], + ORDER_TC_ROW_2=order_tc_row[2], + ORDER_TC_COL_0=order_tc_col[0], + ORDER_TC_COL_1=order_tc_col[1], + ORDER_TC_COL_2=order_tc_col[2], + ORDER_PC_ROW_0=order_pc_row[0], + ORDER_PC_ROW_1=order_pc_row[1], + ORDER_PC_ROW_2=order_pc_row[2], + ORDER_PC_COL_0=order_pc_col[0], + ORDER_PC_COL_1=order_pc_col[1], + ORDER_PC_COL_2=order_pc_col[2], + ORDER_MASK_ROW_0=order_mask_row[0], + ORDER_MASK_ROW_1=order_mask_row[1], + ORDER_MASK_COL_0=order_mask_col[0], + ORDER_MASK_COL_1=order_mask_col[1], + ORDER_PDE_0=order_pde[0], + ORDER_PDE_1=order_pde[1], + ORDER_PDE_2=order_pde[2], + ORDER_PDE_3=order_pde[3], + num_warps=4, + num_stages=2, + ) + + # Return gradients (None for non-differentiable inputs) + return grad_pred_pde, None, None, None, None, None, None, None, None, None + + +def cdist_pde( + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + multiplicity, + num_bins=64, + max_dist=32.0, +): + """ + Compute PDE cross-entropy loss without materializing O(N_token^2) distance matrices. + + This function computes the cross-entropy portion of the PDE loss directly from + coordinates, fusing distance computation, binning, and cross-entropy into a + single kernel that only uses O(tile_size^2) local memory. + + The computation is equivalent to: + true_d = torch.cdist(true_coords_row, true_coords_col) + pred_d = torch.cdist(pred_coords_row, pred_coords_col) + target_pde = torch.abs(true_d - pred_d) + bin_index = torch.clamp(torch.floor(target_pde * num_bins / max_dist).long(), max=num_bins-1) + one_hot = F.one_hot(bin_index, num_classes=num_bins) + errors = -torch.sum(one_hot * F.log_softmax(pred_pde, dim=-1), dim=-1) + out_loss_num = torch.sum(errors * mask, dim=(-2, -1)) # sum over both dims + out_mask_denom = torch.sum(mask, dim=(-2, -1)) # sum over both dims + + Parameters + ---------- + pred_pde : torch.Tensor + Predicted PDE logits, shape [B_mul, N_row, N_col, num_bins]. + This is the only input that requires gradients. + true_coords_row : torch.Tensor + True coordinates for row tokens, shape [B_mul, N_row, 3]. + true_coords_col : torch.Tensor + True coordinates for column tokens, shape [B_mul, N_col, 3]. + pred_coords_row : torch.Tensor + Predicted coordinates for row tokens, shape [B_mul, N_row, 3]. + pred_coords_col : torch.Tensor + Predicted coordinates for column tokens, shape [B_mul, N_col, 3]. + mask_row : torch.Tensor + Mask for row tokens, shape [B, N_row] or [B_mul, N_row]. + If [B, N_row], broadcasts to B_mul. + mask_col : torch.Tensor + Mask for column tokens, shape [B, N_col] or [B_mul, N_col]. + If [B, N_col], broadcasts to B_mul. + multiplicity : int + Required. Explicit multiplicity factor where B_mul = B * multiplicity. + num_bins : int, optional + Number of distance bins for PDE. Default: 64. + max_dist : float, optional + Maximum distance for binning. Default: 32.0. + + Returns + ------- + out_loss_num : torch.Tensor + Sum of cross-entropy loss over all (i,j) pairs per batch, shape [B_mul]. + out_loss_num[b] = sum_{i,j}(CE_loss[i,j] * mask[i,j]) + out_mask_denom : torch.Tensor + Sum of mask over all (i,j) pairs per batch, shape [B_mul]. + out_mask_denom[b] = sum_{i,j}(mask[i,j]) + + Notes + ----- + To compute the final normalized PDE loss: + loss = out_loss_num / (eps + out_mask_denom) + + For distributed training, allreduce out_loss_num and out_mask_denom separately + before computing the normalized loss. + """ + return _CdistPDEImpl.apply( + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + multiplicity, + num_bins, + max_dist, + ) diff --git a/src/boltz/distributed/model/loss/triton/smooth_lddt_loss.py b/src/boltz/distributed/model/loss/triton/smooth_lddt_loss.py new file mode 100644 index 000000000..302f2546c --- /dev/null +++ b/src/boltz/distributed/model/loss/triton/smooth_lddt_loss.py @@ -0,0 +1,403 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +try: + import triton + import triton.language as tl +except ImportError: + raise ImportError("Triton is not available. Will not import smooth_lddt_loss kernels.") + + +def grid_launch_config(kwargs): + return ( + kwargs["shape_coords_axis_0"], + triton.cdiv(kwargs["shape_coords_axis_1"], kwargs["BLOCK"]), + triton.cdiv(kwargs["shape_coords_axis_1"], kwargs["BLOCK"]), + ) + + +@triton.heuristics( + # register pressure is high with BLOCK >= 64 and will spill at BLOCK = 128 + # due to the pair-wise distance computation + {"BLOCK": lambda args: 32, "num_warps": lambda args: 4}, +) +@triton.jit +def smooth_lddt_loss_fwd_kernel( + pred_coords_ptr, + true_coords_ptr, + pred_coords_t_ptr, + true_coords_t_ptr, + is_nucleotide_ptr, + coords_mask_ptr, + coords_mask_t_ptr, + num_output_ptr, + den_output_ptr, + stride_pred_b, + stride_pred_n, + stride_pred_d, + stride_true_b, + stride_true_n, + stride_true_d, + stride_pred_t_b, + stride_pred_t_n, + stride_pred_t_d, + stride_true_t_b, + stride_true_t_n, + stride_true_t_d, + stride_nuc_b, + stride_nuc_n, + stride_mask_b, + stride_mask_n, + stride_mask_t_b, + stride_mask_t_n, + nucleic_acid_cutoff, + other_cutoff, + is_self_comm: tl.constexpr, + shape_coords_axis_0, + shape_coords_axis_1, + shape_mask_axis_0, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + pid_m = tl.program_id(1) + pid_n = tl.program_id(2) + + # Batch offset + batch_idx = pid + + multiplicity = shape_coords_axis_0 // shape_mask_axis_0 + + # we can reuse the same mask batch without pre-repeat_interleave + # the mask with multiplicity + batch_idx_mask = batch_idx // multiplicity + + # Pointers to current batch + # Compute pointers by adding offsets * strides + # Note: pointers are 64-bit + + pred_coords_cur = pred_coords_ptr + batch_idx * stride_pred_b + true_coords_cur = true_coords_ptr + batch_idx * stride_true_b + pred_coords_t_cur = pred_coords_t_ptr + batch_idx * stride_pred_t_b + true_coords_t_cur = true_coords_t_ptr + batch_idx * stride_true_t_b + + is_nucleotide_cur = is_nucleotide_ptr + batch_idx_mask * stride_nuc_b + coords_mask_cur = coords_mask_ptr + batch_idx_mask * stride_mask_b + coords_mask_t_cur = coords_mask_t_ptr + batch_idx_mask * stride_mask_t_b + + # Offsets for M dimension + offs_m = pid_m * BLOCK + tl.arange(0, BLOCK) + mask_m = offs_m < shape_coords_axis_1 + + # Offsets for shape_coords_axis_1 dimension + offs_n = pid_n * BLOCK + tl.arange(0, BLOCK) + mask_n = offs_n < shape_coords_axis_1 + + # Load data + # is_nucleotide: (BLOCK_M,) + # Use float mask for simpler math + is_nuc = tl.load(is_nucleotide_cur + offs_m * stride_nuc_n, mask=mask_m, other=0.0).to(tl.int1) + + # coords_mask: (BLOCK_M,) + c_mask = tl.load(coords_mask_cur + offs_m * stride_mask_n, mask=mask_m, other=0.0) + + # coords_mask_t: (BLOCK_N,) + c_mask_t = tl.load(coords_mask_t_cur + offs_n * stride_mask_t_n, mask=mask_n, other=0.0) + + # Combined mask (BLOCK_M, BLOCK_N) + # mask = c_mask[:, None] * c_mask_t[None, :] + combined_mask = c_mask[:, None] * c_mask_t[None, :] + + # Handle diagonal + is_diag = (offs_m[:, None] == offs_n[None, :]) & is_self_comm + combined_mask = tl.where(is_diag, 0.0, combined_mask) + + # Distances and Differences + d_idx = tl.arange(0, 4) + mask_d = d_idx < 3 + + # True Coords + tc_ptr = true_coords_cur + offs_m[:, None] * stride_true_n + d_idx[None, :] * stride_true_d + tc_m = tl.load(tc_ptr, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + + tc_t_ptr = true_coords_t_cur + offs_n[:, None] * stride_true_t_n + d_idx[None, :] * stride_true_t_d + tc_t_n = tl.load(tc_t_ptr, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + + diff_true = tc_m[:, None, :] - tc_t_n[None, :, :] + true_dist_sq = tl.sum(diff_true * diff_true, axis=2) + + # Pred Coords + pc_ptr = pred_coords_cur + offs_m[:, None] * stride_pred_n + d_idx[None, :] * stride_pred_d + pc_m = tl.load(pc_ptr, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + + pc_t_ptr = pred_coords_t_cur + offs_n[:, None] * stride_pred_t_n + d_idx[None, :] * stride_pred_t_d + pc_t_n = tl.load(pc_t_ptr, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + + diff_pred = pc_m[:, None, :] - pc_t_n[None, :, :] + pred_dist_sq = tl.sum(diff_pred * diff_pred, axis=2) + + true_dist = tl.sqrt(true_dist_sq) + pred_dist = tl.sqrt(pred_dist_sq) + + # Cutoff mask + # is_nuc is (BLOCK_M,), broadcast to (BLOCK_M, BLOCK_N) + cutoff = tl.where(is_nuc[:, None], nucleic_acid_cutoff, other_cutoff) + dist_mask = true_dist < cutoff + + final_mask = combined_mask * dist_mask + + # Epsilon + dist_diff = tl.abs(true_dist - pred_dist) + + eps = tl.sigmoid(0.5 - dist_diff) + eps += tl.sigmoid(1.0 - dist_diff) + eps += tl.sigmoid(2.0 - dist_diff) + eps += tl.sigmoid(4.0 - dist_diff) + eps *= 0.25 + + # Accumulate + num_val = eps * final_mask + den_val = final_mask + + # Sum within block + # Accumulate in fp32 for precision + num_sum = tl.sum(num_val.to(tl.float32)) + den_sum = tl.sum(den_val.to(tl.float32)) + + # Atomic Add to global + tl.atomic_add(num_output_ptr + batch_idx, num_sum.to(pred_coords_ptr.dtype.element_ty)) + tl.atomic_add(den_output_ptr + batch_idx, den_sum.to(pred_coords_ptr.dtype.element_ty)) + + +@triton.heuristics( + # register pressure is high with BLOCK >= 64 and will spill at BLOCK = 128 + # due to the pair-wise distance computation + {"BLOCK": lambda args: 16, "num_warps": lambda args: 2}, +) +@triton.jit +def smooth_lddt_loss_bwd_kernel( + grad_num_reduced_ptr, + grad_den_reduced_ptr, + pred_coords_ptr, + true_coords_ptr, + pred_coords_t_ptr, + true_coords_t_ptr, + is_nucleotide_ptr, + coords_mask_ptr, + coords_mask_t_ptr, + grad_pred_local_ptr, + grad_pred_t_local_ptr, + stride_pred_b, + stride_pred_n, + stride_pred_d, + stride_true_b, + stride_true_n, + stride_true_d, + stride_pred_t_b, + stride_pred_t_n, + stride_pred_t_d, + stride_true_t_b, + stride_true_t_n, + stride_true_t_d, + stride_nuc_b, + stride_nuc_n, + stride_mask_b, + stride_mask_n, + stride_mask_t_b, + stride_mask_t_n, + stride_grad_pred_b, + stride_grad_pred_n, + stride_grad_pred_d, + stride_grad_pred_t_b, + stride_grad_pred_t_n, + stride_grad_pred_t_d, + nucleic_acid_cutoff, + other_cutoff, + is_self_comm: tl.constexpr, + shape_coords_axis_0, + shape_coords_axis_1, + shape_mask_axis_0, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + pid_m = tl.program_id(1) + pid_n = tl.program_id(2) + + batch_idx = pid + + multiplicity = shape_coords_axis_0 // shape_mask_axis_0 + + # we can reuse the same mask batch without pre-repeat_interleave + # the mask with multiplicity + batch_idx_mask = batch_idx // multiplicity + + # Pointers + pred_coords_cur = pred_coords_ptr + batch_idx * stride_pred_b + true_coords_cur = true_coords_ptr + batch_idx * stride_true_b + pred_coords_t_cur = pred_coords_t_ptr + batch_idx * stride_pred_t_b + true_coords_t_cur = true_coords_t_ptr + batch_idx * stride_true_t_b + + grad_pred_local_cur = grad_pred_local_ptr + batch_idx * stride_grad_pred_b + grad_pred_t_local_cur = grad_pred_t_local_ptr + batch_idx * stride_grad_pred_t_b + + is_nucleotide_cur = is_nucleotide_ptr + batch_idx_mask * stride_nuc_b + coords_mask_cur = coords_mask_ptr + batch_idx_mask * stride_mask_b + coords_mask_t_cur = coords_mask_t_ptr + batch_idx_mask * stride_mask_t_b + + # Gradients are scalars per batch + grad_num = tl.load(grad_num_reduced_ptr + batch_idx) + + # Offsets + offs_m = pid_m * BLOCK + tl.arange(0, BLOCK) + mask_m = offs_m < shape_coords_axis_1 + offs_n = pid_n * BLOCK + tl.arange(0, BLOCK) + mask_n = offs_n < shape_coords_axis_1 + + # --- Recompute Forward Pass Intermediates --- + + # Masks + is_nuc = tl.load(is_nucleotide_cur + offs_m * stride_nuc_n, mask=mask_m, other=0.0).to(tl.int1) + c_mask = tl.load(coords_mask_cur + offs_m * stride_mask_n, mask=mask_m, other=0.0) + c_mask_t = tl.load(coords_mask_t_cur + offs_n * stride_mask_t_n, mask=mask_n, other=0.0) + + combined_mask = c_mask[:, None] * c_mask_t[None, :] + is_diag = (offs_m[:, None] == offs_n[None, :]) & is_self_comm + combined_mask = tl.where(is_diag, 0.0, combined_mask) + + # Distances and Differences + d_idx = tl.arange(0, 4) + mask_d = d_idx < 3 + + # True Coords + tc_ptr = true_coords_cur + offs_m[:, None] * stride_true_n + d_idx[None, :] * stride_true_d + tc_m = tl.load(tc_ptr, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + + tc_t_ptr = true_coords_t_cur + offs_n[:, None] * stride_true_t_n + d_idx[None, :] * stride_true_t_d + tc_t_n = tl.load(tc_t_ptr, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + + diff_true = tc_m[:, None, :] - tc_t_n[None, :, :] + true_dist_sq = tl.sum(diff_true * diff_true, axis=2) + + # Pred Coords + pc_ptr = pred_coords_cur + offs_m[:, None] * stride_pred_n + d_idx[None, :] * stride_pred_d + pc_m = tl.load(pc_ptr, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + + pc_t_ptr = pred_coords_t_cur + offs_n[:, None] * stride_pred_t_n + d_idx[None, :] * stride_pred_t_d + pc_t_n = tl.load(pc_t_ptr, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + + diff_pred = pc_m[:, None, :] - pc_t_n[None, :, :] + pred_dist_sq = tl.sum(diff_pred * diff_pred, axis=2) + + mask_0 = (d_idx == 0)[None, None, :] + mask_1 = (d_idx == 1)[None, None, :] + mask_2 = (d_idx == 2)[None, None, :] + + diff_pred_x = tl.sum(diff_pred * mask_0, axis=2) + diff_pred_y = tl.sum(diff_pred * mask_1, axis=2) + diff_pred_z = tl.sum(diff_pred * mask_2, axis=2) + + true_dist = tl.sqrt(true_dist_sq) + pred_dist = tl.sqrt(pred_dist_sq) + + # Cutoff Mask + cutoff = tl.where(is_nuc[:, None], nucleic_acid_cutoff, other_cutoff) + dist_mask = true_dist < cutoff + final_mask = combined_mask * dist_mask + + # --- Backward Computation --- + + # Compute d_eps_d_diff in fp32 for precision + dist_diff = tl.abs(true_dist - pred_dist).to(tl.float32) + d_eps_d_diff = tl.zeros([BLOCK, BLOCK], dtype=tl.float32) + + # Loop over cutoffs: 0.5, 1.0, 2.0, 4.0 + # Unrolling for simplicity/speed in Triton + # 0.5 + val = 0.5 - dist_diff + sig = tl.sigmoid(val) + d_eps_d_diff -= sig * (1.0 - sig) + # 1.0 + val = 1.0 - dist_diff + sig = tl.sigmoid(val) + d_eps_d_diff -= sig * (1.0 - sig) + # 2.0 + val = 2.0 - dist_diff + sig = tl.sigmoid(val) + d_eps_d_diff -= sig * (1.0 - sig) + # 4.0 + val = 4.0 - dist_diff + sig = tl.sigmoid(val) + d_eps_d_diff -= sig * (1.0 - sig) + + d_eps_d_diff *= 0.25 + + # sign(pred - true) + # if pred > true: 1, else -1. But be careful about 0. + # Using tl.where + # sign_diff = tl.where(pred_dist > true_dist, 1.0, -1.0) + # Actually torch.sign returns 0 for 0. + diff_dist = pred_dist - true_dist + sign_diff = tl.where(diff_dist > 0, 1.0, tl.where(diff_dist < 0, -1.0, 0.0)) + + # d_L_d_pred_dists + # grad_num is scalar broadcasted + d_L_d_pred_dists = grad_num * final_mask * d_eps_d_diff * sign_diff + + # diff_dir = diff_vec / (pred_dist + 1e-8) + # pred_dist_safe + pred_dist_safe = pred_dist + 1e-8 + inv_dist = 1.0 / pred_dist_safe + + # Compute grad_vec (shape_coords_axis_1^2, 3) effectively + # factor = d_L_d_pred_dists * inv_dist + factor = d_L_d_pred_dists * inv_dist + + d_L_d_diff_x = factor * diff_pred_x + d_L_d_diff_y = factor * diff_pred_y + d_L_d_diff_z = factor * diff_pred_z + + # Accumulate gradients locally + # grad_pred_local (M, 3) = sum_over_N(d_L_d_diff) + # grad_pred_t_local (shape_coords_axis_1, 3) = sum_over_M(-d_L_d_diff) + + # Sum over shape_coords_axis_1 (cols) for grad_pred_local + grad_x_m = tl.sum(d_L_d_diff_x, axis=1) + grad_y_m = tl.sum(d_L_d_diff_y, axis=1) + grad_z_m = tl.sum(d_L_d_diff_z, axis=1) + + # Sum over M (rows) for grad_pred_t_local + # Note: grad_pred_t_local is -sum + grad_x_n = tl.sum(d_L_d_diff_x, axis=0) + grad_y_n = tl.sum(d_L_d_diff_y, axis=0) + grad_z_n = tl.sum(d_L_d_diff_z, axis=0) + + # Atomic Add to output buffers + # grad_pred_local: (B, M, 3) + dtype = grad_pred_local_ptr.dtype.element_ty + tl.atomic_add(grad_pred_local_cur + offs_m * stride_grad_pred_n + 0, grad_x_m.to(dtype), mask=mask_m) + tl.atomic_add(grad_pred_local_cur + offs_m * stride_grad_pred_n + 1, grad_y_m.to(dtype), mask=mask_m) + tl.atomic_add(grad_pred_local_cur + offs_m * stride_grad_pred_n + 2, grad_z_m.to(dtype), mask=mask_m) + + # grad_pred_t_local: (B, shape_coords_axis_1, 3) + # Negate + tl.atomic_add(grad_pred_t_local_cur + offs_n * stride_grad_pred_t_n + 0, (-grad_x_n).to(dtype), mask=mask_n) + tl.atomic_add(grad_pred_t_local_cur + offs_n * stride_grad_pred_t_n + 1, (-grad_y_n).to(dtype), mask=mask_n) + tl.atomic_add(grad_pred_t_local_cur + offs_n * stride_grad_pred_t_n + 2, (-grad_z_n).to(dtype), mask=mask_n) diff --git a/src/boltz/distributed/model/loss/validation.py b/src/boltz/distributed/model/loss/validation.py new file mode 100644 index 000000000..fd2e10f72 --- /dev/null +++ b/src/boltz/distributed/model/loss/validation.py @@ -0,0 +1,980 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from collections import defaultdict + +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor + +from boltz.data import const +from boltz.distributed.model.layers.atom_to_token import reconstruct_atom_to_token_global, single_repr_token_to_atom +from boltz.distributed.model.layers.elementwise_op import ( + ElementwiseOp, + elementwise_op, + scalar_tensor_op, +) +from boltz.distributed.model.layers.sharded_op import sharded_sum +from boltz.distributed.model.layers.shardwise_op import shardwise_sum +from boltz.distributed.model.layers.squeeze import shardwise_squeeze +from boltz.distributed.model.loss.diffusion import ( + weighted_rigid_align as dtensor_weighted_rigid_align, +) +from boltz.distributed.model.loss.triton.cdist_lddt import cdist_lddt +from boltz.distributed.model.validation.utils import gather_along_cp +from boltz.model.loss.confidence import ( + lddt_dist, +) + + +def clash_score( + coords_repr: torch.Tensor, + token_pad_mask: torch.Tensor, + multiplicity: int, + clash_cutoff: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute per-sample token clash count and fraction from representative atom coordinates. + + Parameters + ---------- + coords_repr : torch.Tensor + Representative atom coordinates, shape ``[B*mul, N_tokens, 3]``. + token_pad_mask : torch.Tensor + Token padding mask, shape ``[B, N_tokens]``. + multiplicity : int + Diffusion multiplicity. + clash_cutoff : float + Distance cutoff for defining a clash in Angstrom. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + - clash_atoms_count: clashing tokens per sample, shape ``[B*mul]`` + - clash_atoms_fraction: fraction of clashing valid tokens, shape ``[B*mul]`` + """ + _, _, clash_denom = cdist_lddt( + pred_coords_row=coords_repr, + pred_coords_col=coords_repr, + true_coords_row=coords_repr, + true_coords_col=coords_repr, + mask_row=token_pad_mask.float(), + mask_col=token_pad_mask.float(), + multiplicity=multiplicity, + cutoff=clash_cutoff, + do_mask_diagonal=True, + per_atom=True, + return_denom=True, + ) + B = token_pad_mask.shape[0] + clash_denom_reshaped = clash_denom.reshape(B, multiplicity, -1) + token_mask_bc = token_pad_mask[:, None, :] + + clash_atoms_count_2d = ((clash_denom_reshaped > 0) & token_mask_bc).sum(dim=2) + clash_atoms_total = token_pad_mask.sum(dim=1).clamp(min=1)[:, None] + + clash_atoms_count = clash_atoms_count_2d.reshape(-1) + clash_atoms_fraction = (clash_atoms_count_2d / clash_atoms_total).reshape(-1) + return clash_atoms_count, clash_atoms_fraction + + +def factored_lddt_loss( + true_atom_coords, + pred_atom_coords, + feats, + atom_mask, + multiplicity=1, + cardinality_weighted=False, +): + """Compute the lddt factorized into the different modalities. + + Uses triton kernel cdist_lddt to compute the lddt. + + Parameters + ---------- + true_atom_coords : torch.Tensor + Ground truth atom coordinates after symmetry correction, shape [B*mul, N_atoms, 3] + pred_atom_coords : torch.Tensor + Predicted atom coordinates, shape [B*mul, N_atoms, 3] + feats : Dict[str, torch.Tensor] + Input features with token-level tensors at base batch size B + atom_mask : torch.Tensor + Atom mask, shape [B, N_atoms] or [B*mul, N_atoms]. If [B*mul, N_atoms], + the mask is downsampled by taking every multiplicity-th row. + multiplicity : int + Diffusion batch size, by default 1 + cardinality_weighted : bool + If True, use the cardinality weighted loss, defaults to False + + Returns + ------- + Dict[str, torch.Tensor] + The lddt for each modality, each tensor shape [B*mul] + Dict[str, torch.Tensor] + The total number of pairs for each modality, each tensor shape [B*mul] + + """ + # extract necessary features + atom_type = torch.bmm(feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float()).squeeze(-1).long() + + # Use base masks and rely on cdist_lddt broadcasting across multiplicity. + if atom_mask.shape[0] == atom_type.shape[0]: + atom_mask_base = atom_mask + elif atom_mask.shape[0] == atom_type.shape[0] * multiplicity: + atom_mask_base = atom_mask[::multiplicity] # to match with atom_type shape + else: + raise ValueError( + "atom_mask batch dimension must be B or B*mul " + f"(got {atom_mask.shape[0]} vs B={atom_type.shape[0]}, mul={multiplicity})" + ) + + input_dtype = pred_atom_coords.dtype + compute_dtype = torch.promote_types(input_dtype, torch.float32) + + ligand_mask = (atom_type == const.chain_type_ids["NONPOLYMER"]).to(dtype=compute_dtype) + dna_mask = (atom_type == const.chain_type_ids["DNA"]).to(dtype=compute_dtype) + rna_mask = (atom_type == const.chain_type_ids["RNA"]).to(dtype=compute_dtype) + protein_mask = (atom_type == const.chain_type_ids["PROTEIN"]).to(dtype=compute_dtype) + + atom_mask_base = atom_mask_base.to(dtype=compute_dtype) + pred_atom_coords = pred_atom_coords.to(dtype=compute_dtype) + true_atom_coords = true_atom_coords.to(dtype=compute_dtype) + + def score_and_total(mask_row, mask_col, cutoff, symmetrize=False): + score, total = cdist_lddt( + pred_coords_row=pred_atom_coords, + pred_coords_col=pred_atom_coords, + true_coords_row=true_atom_coords, + true_coords_col=true_atom_coords, + mask_row=mask_row, + mask_col=mask_col, + multiplicity=multiplicity, + cutoff=cutoff, + do_mask_diagonal=True, + per_atom=False, + return_denom=True, + ) + if symmetrize: + total = total * 2 + score = torch.where(total > 0, score, torch.ones_like(score)) + return score, total + + mask_dna = atom_mask_base * dna_mask + mask_rna = atom_mask_base * rna_mask + mask_ligand = atom_mask_base * ligand_mask + mask_protein = atom_mask_base * protein_mask + + dna_protein_lddt, dna_protein_total = score_and_total(mask_dna, mask_protein, cutoff=30.0, symmetrize=True) + rna_protein_lddt, rna_protein_total = score_and_total(mask_rna, mask_protein, cutoff=30.0, symmetrize=True) + ligand_protein_lddt, ligand_protein_total = score_and_total(mask_ligand, mask_protein, cutoff=15.0, symmetrize=True) + dna_ligand_lddt, dna_ligand_total = score_and_total(mask_dna, mask_ligand, cutoff=30.0, symmetrize=True) + rna_ligand_lddt, rna_ligand_total = score_and_total(mask_rna, mask_ligand, cutoff=30.0, symmetrize=True) + + intra_dna_lddt, intra_dna_total = score_and_total(mask_dna, mask_dna, cutoff=30.0) + intra_rna_lddt, intra_rna_total = score_and_total(mask_rna, mask_rna, cutoff=30.0) + + chain_id = feats["asym_id"] + atom_chain_id = torch.bmm(feats["atom_to_token"].float(), chain_id.unsqueeze(-1).float()).squeeze(-1).long() + + chain_ids = torch.unique(atom_chain_id[atom_mask_base.bool()]) + + def accumulate_chain_scores(base_mask, cutoff): + score_sum = torch.zeros(true_atom_coords.shape[0], device=true_atom_coords.device) + total_sum = torch.zeros_like(score_sum) + for chain_value in chain_ids.tolist(): + chain_mask = (atom_chain_id == chain_value).float() + mask_chain = base_mask * chain_mask + if not torch.any(mask_chain): + continue + score, total = score_and_total(mask_chain, mask_chain, cutoff=cutoff) + score_sum = score_sum + score * total + total_sum = total_sum + total + score = torch.where( + total_sum > 0, + score_sum / (total_sum + 1e-10), + torch.ones_like(score_sum), + ) + return score, total_sum + + intra_ligand_lddt, intra_ligand_total = accumulate_chain_scores(mask_ligand, cutoff=15.0) + intra_protein_lddt, intra_protein_total = accumulate_chain_scores(mask_protein, cutoff=15.0) + + protein_score_sum = torch.zeros(true_atom_coords.shape[0], device=true_atom_coords.device) + protein_total_sum = torch.zeros_like(protein_score_sum) + chain_values = chain_ids.tolist() + for i, chain_i in enumerate(chain_values): + mask_i = mask_protein * (atom_chain_id == chain_i).float() + if not torch.any(mask_i): + continue + for chain_j in chain_values[i + 1 :]: + mask_j = mask_protein * (atom_chain_id == chain_j).float() + if not torch.any(mask_j): + continue + score, total = score_and_total(mask_i, mask_j, cutoff=15.0, symmetrize=True) + protein_score_sum = protein_score_sum + score * total + protein_total_sum = protein_total_sum + total + protein_protein_lddt = torch.where( + protein_total_sum > 0, + protein_score_sum / (protein_total_sum + 1e-10), + torch.ones_like(protein_score_sum), + ) + protein_protein_total = protein_total_sum + + lddt_dict = { + "dna_protein": dna_protein_lddt, + "rna_protein": rna_protein_lddt, + "ligand_protein": ligand_protein_lddt, + "dna_ligand": dna_ligand_lddt, + "rna_ligand": rna_ligand_lddt, + "intra_ligand": intra_ligand_lddt, + "intra_dna": intra_dna_lddt, + "intra_rna": intra_rna_lddt, + "intra_protein": intra_protein_lddt, + "protein_protein": protein_protein_lddt, + } + + total_dict = { + "dna_protein": dna_protein_total, + "rna_protein": rna_protein_total, + "ligand_protein": ligand_protein_total, + "dna_ligand": dna_ligand_total, + "rna_ligand": rna_ligand_total, + "intra_ligand": intra_ligand_total, + "intra_dna": intra_dna_total, + "intra_rna": intra_rna_total, + "intra_protein": intra_protein_total, + "protein_protein": protein_protein_total, + } + if not cardinality_weighted: + for key in total_dict: + total_dict[key] = (total_dict[key] > 0.0).to(dtype=input_dtype) + + lddt_dict = {key: value.to(dtype=input_dtype) for key, value in lddt_dict.items()} + total_dict = {key: value.to(dtype=input_dtype) for key, value in total_dict.items()} + + return lddt_dict, total_dict + + +def factored_token_lddt_dist_loss_triton( + pred_token_coords, + true_token_coords, + mol_type, + token_disto_mask, + asym_id, + multiplicity=1, + cardinality_weighted=False, + pred_d=None, + true_d=None, +): + """Compute the distogram lddt factorized into different modalities using cdist_lddt. + + Token-level analogue of factored_lddt_loss. When coordinates are provided for + both sides, uses the cdist_lddt triton kernel to compute pairwise distances + on-the-fly, avoiding O(N^2) materialization. When pre-computed distance + matrices are provided (e.g. from a distogram prediction), uses those directly. + + Parameters + ---------- + pred_token_coords : torch.Tensor + Predicted token representative coordinates, shape [B*mul, N_tokens, 3]. + true_token_coords : torch.Tensor + Ground truth token representative coordinates, shape [B*mul, N_tokens, 3]. + mol_type : torch.Tensor + Molecule type per token, shape [B, N_tokens]. + token_disto_mask : torch.Tensor + Token validity mask for distogram, shape [B, N_tokens]. + asym_id : torch.Tensor + Chain (asymmetric unit) identifier per token, shape [B, N_tokens]. + multiplicity : int + Diffusion multiplicity (B_mul = B * multiplicity), by default 1. + cardinality_weighted : bool + If True, return raw pair counts; if False, binarize totals. Default False. + pred_d : torch.Tensor, optional + Pre-computed predicted distance matrix, shape [B, N_tokens, N_tokens]. + When provided, overrides pred_token_coords for the predicted distances. + true_d : torch.Tensor, optional + Pre-computed true distance matrix, shape [B, N_tokens, N_tokens]. + When provided, overrides true_token_coords for the true distances. + + Returns + ------- + dict[str, torch.Tensor] + LDDT score per modality, each shape [B*mul]. + dict[str, torch.Tensor] + Total (pair count or binary indicator) per modality, each shape [B*mul]. + + """ + + input_dtype = pred_token_coords.dtype + compute_dtype = torch.promote_types(input_dtype, torch.float32) + + use_dists = pred_d is not None or true_d is not None + + token_mask = token_disto_mask.to(dtype=compute_dtype) + true_token_coords = true_token_coords.to(dtype=compute_dtype) + + ligand_mask = (mol_type == const.chain_type_ids["NONPOLYMER"]).to(dtype=compute_dtype) + dna_mask = (mol_type == const.chain_type_ids["DNA"]).to(dtype=compute_dtype) + rna_mask = (mol_type == const.chain_type_ids["RNA"]).to(dtype=compute_dtype) + protein_mask = (mol_type == const.chain_type_ids["PROTEIN"]).to(dtype=compute_dtype) + + nucleotide_mask = dna_mask + rna_mask + + mask_dna = token_mask * dna_mask + mask_rna = token_mask * rna_mask + mask_ligand = token_mask * ligand_mask + mask_protein = token_mask * protein_mask + + if use_dists: + pairwise_mask = token_mask[:, :, None] * token_mask[:, None, :] + pairwise_mask = pairwise_mask * (1 - torch.eye(token_mask.shape[1], device=token_mask.device)[None]).to( + pairwise_mask + ) + cutoff_matrix = 15 + 15 * (1 - (1 - nucleotide_mask[:, :, None]) * (1 - nucleotide_mask[:, None, :])) + eff_pred_d = pred_d if pred_d is not None else torch.cdist(pred_token_coords, pred_token_coords) + eff_true_d = true_d if true_d is not None else torch.cdist(true_token_coords, true_token_coords) + + def score_and_total(mask_row, mask_col, cutoff, symmetrize=False): + mask_2d = pairwise_mask * (mask_row[:, :, None] * mask_col[:, None, :]) + if symmetrize: + mask_2d = mask_2d + pairwise_mask * (mask_col[:, :, None] * mask_row[:, None, :]) + # Keep the same API as the cdist path; this branch intentionally uses + # per-pair cutoffs from cutoff_matrix instead of the scalar cutoff. + score, total = lddt_dist(eff_pred_d, eff_true_d, mask_2d, cutoff_matrix) + score = torch.where(total > 0, score, torch.ones_like(score)) + return score, total + else: + + def score_and_total(mask_row, mask_col, cutoff, symmetrize=False): + score, total = cdist_lddt( + pred_coords_row=pred_token_coords, + pred_coords_col=pred_token_coords, + true_coords_row=true_token_coords, + true_coords_col=true_token_coords, + mask_row=mask_row, + mask_col=mask_col, + multiplicity=multiplicity, + cutoff=cutoff, + do_mask_diagonal=True, + per_atom=False, + return_denom=True, + ) + if symmetrize: + total = total * 2 + score = torch.where(total > 0, score, torch.ones_like(score)) + return score, total + + dna_protein_lddt, dna_protein_total = score_and_total(mask_dna, mask_protein, cutoff=30.0, symmetrize=True) + rna_protein_lddt, rna_protein_total = score_and_total(mask_rna, mask_protein, cutoff=30.0, symmetrize=True) + ligand_protein_lddt, ligand_protein_total = score_and_total(mask_ligand, mask_protein, cutoff=15.0, symmetrize=True) + dna_ligand_lddt, dna_ligand_total = score_and_total(mask_dna, mask_ligand, cutoff=30.0, symmetrize=True) + rna_ligand_lddt, rna_ligand_total = score_and_total(mask_rna, mask_ligand, cutoff=30.0, symmetrize=True) + + intra_dna_lddt, intra_dna_total = score_and_total(mask_dna, mask_dna, cutoff=30.0) + intra_rna_lddt, intra_rna_total = score_and_total(mask_rna, mask_rna, cutoff=30.0) + + chain_ids = torch.unique(asym_id[token_mask.bool()]) + + def accumulate_chain_scores(base_mask, cutoff): + score_sum = torch.zeros(true_token_coords.shape[0], device=true_token_coords.device) + total_sum = torch.zeros_like(score_sum) + for chain_value in chain_ids.tolist(): + chain_mask = (asym_id == chain_value).float() + mask_chain = base_mask * chain_mask + if not torch.any(mask_chain): + continue + score, total = score_and_total(mask_chain, mask_chain, cutoff=cutoff) + score_sum = score_sum + score * total + total_sum = total_sum + total + score = torch.where( + total_sum > 0, + score_sum / (total_sum + 1e-10), + torch.ones_like(score_sum), + ) + return score, total_sum + + intra_ligand_lddt, intra_ligand_total = accumulate_chain_scores(mask_ligand, cutoff=15.0) + intra_protein_lddt, intra_protein_total = accumulate_chain_scores(mask_protein, cutoff=15.0) + + protein_score_sum = torch.zeros(true_token_coords.shape[0], device=true_token_coords.device) + protein_total_sum = torch.zeros_like(protein_score_sum) + chain_values = chain_ids.tolist() + for i, chain_i in enumerate(chain_values): + mask_i = mask_protein * (asym_id == chain_i).float() + if not torch.any(mask_i): + continue + for chain_j in chain_values[i + 1 :]: + mask_j = mask_protein * (asym_id == chain_j).float() + if not torch.any(mask_j): + continue + score, total = score_and_total(mask_i, mask_j, cutoff=15.0, symmetrize=True) + protein_score_sum = protein_score_sum + score * total + protein_total_sum = protein_total_sum + total + protein_protein_lddt = torch.where( + protein_total_sum > 0, + protein_score_sum / (protein_total_sum + 1e-10), + torch.ones_like(protein_score_sum), + ) + protein_protein_total = protein_total_sum + + lddt_dict = { + "dna_protein": dna_protein_lddt, + "rna_protein": rna_protein_lddt, + "ligand_protein": ligand_protein_lddt, + "dna_ligand": dna_ligand_lddt, + "rna_ligand": rna_ligand_lddt, + "intra_ligand": intra_ligand_lddt, + "intra_dna": intra_dna_lddt, + "intra_rna": intra_rna_lddt, + "intra_protein": intra_protein_lddt, + "protein_protein": protein_protein_lddt, + } + + total_dict = { + "dna_protein": dna_protein_total, + "rna_protein": rna_protein_total, + "ligand_protein": ligand_protein_total, + "dna_ligand": dna_ligand_total, + "rna_ligand": rna_ligand_total, + "intra_ligand": intra_ligand_total, + "intra_dna": intra_dna_total, + "intra_rna": intra_rna_total, + "intra_protein": intra_protein_total, + "protein_protein": protein_protein_total, + } + if not cardinality_weighted: + for key in total_dict: + total_dict[key] = (total_dict[key] > 0.0).to(dtype=input_dtype) + + lddt_dict = {key: value.to(dtype=input_dtype) for key, value in lddt_dict.items()} + total_dict = {key: value.to(dtype=input_dtype) for key, value in total_dict.items()} + + return lddt_dict, total_dict + + +def compute_disto_lddt( + model, + batch: dict[str, DTensor], + out: dict[str, DTensor], +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Compute distogram LDDT by gathering DTensors and calling the triton kernel. + + Distributed override of the serial Validator.compute_disto_lddt. Gathers the + required DTensors to plain tensors, converts predicted distograms to distance + matrices, then evaluates factored token LDDT for each (distogram, conformer) pair. + + Parameters + ---------- + model + The model, providing min_dist, max_dist, num_bins, num_distograms attributes. + batch : dict[str, DTensor] + Batch features as DTensors. Must contain: + - "disto_coords_ensemble": [K, N_tokens, 3] or [B, K*N_tokens, 3] + - "mol_type": [B, N_tokens] + - "token_disto_mask": [B, N_tokens] + - "asym_id": [B, N_tokens] + out : dict[str, DTensor] + Model outputs as DTensors. Must contain: + - "pdistogram": [B, N, N, D, bins] + + Returns + ------- + dict[str, torch.Tensor] + LDDT score per modality, each shape [1] (min over D, mean over K). + dict[str, torch.Tensor] + Total per modality, each shape [1]. + + """ + + disto_coords_ensemble = gather_along_cp(batch["disto_coords_ensemble"]) + mol_type = gather_along_cp(batch["mol_type"]) + token_disto_mask = gather_along_cp(batch["token_disto_mask"]) + asym_id = gather_along_cp(batch["asym_id"]) + pdistogram = gather_along_cp(out["pdistogram"]) + + boundaries = torch.linspace( + model.min_dist, + model.max_dist, + model.num_bins - 1, + device=pdistogram.device, + dtype=pdistogram.dtype, + ) + lower = torch.tensor([1.0], device=pdistogram.device, dtype=pdistogram.dtype) + upper = torch.tensor([model.max_dist + 5.0], device=pdistogram.device, dtype=pdistogram.dtype) + exp_boundaries = torch.cat((lower, boundaries, upper)) + mid_points = (exp_boundaries[:-1] + exp_boundaries[1:]) / 2 + + if "coords" in batch: + K = gather_along_cp(batch["coords"]).shape[1] + elif hasattr(model, "num_conformers"): + K = model.num_conformers + else: + raise ValueError("Unable to infer conformer count: expected `batch['coords']` or `model.num_conformers`.") + true_center = disto_coords_ensemble.reshape(K, -1, 3) + + D = model.num_distograms + device = pdistogram.device + + disto_lddt_dict = defaultdict(lambda: torch.zeros(K, D, device=device)) + disto_total_dict = defaultdict(lambda: torch.zeros(K, D, device=device)) + + for i in range(D): + preds = pdistogram[:, :, :, i] + pred_dist_i = mid_points[preds.argmax(dim=-1)] + + for k in range(K): + true_center_k = true_center[k].unsqueeze(0) + + lddt_dict_, total_dict_ = factored_token_lddt_dist_loss_triton( + pred_token_coords=true_center_k, + true_token_coords=true_center_k, + mol_type=mol_type, + token_disto_mask=token_disto_mask, + asym_id=asym_id, + pred_d=pred_dist_i, + ) + + for key in lddt_dict_: + disto_lddt_dict[key][k, i] = lddt_dict_[key].item() + disto_total_dict[key][k, i] = total_dict_[key].item() + + for key in disto_lddt_dict: + disto_lddt_dict[key] = disto_lddt_dict[key].min(dim=1).values.mean(dim=0)[None] + disto_total_dict[key] = disto_total_dict[key].min(dim=1).values.mean(dim=0)[None] + + return disto_lddt_dict, disto_total_dict + + +def get_lddt_metrics( + atom_to_token_dtensor: DTensor, + num_conformers: int, + n_samples: int, + true_coords: torch.Tensor, + true_coords_resolved_mask: torch.Tensor, + mol_type: torch.Tensor, + asym_id: torch.Tensor, + sample_atom_coords: torch.Tensor, + expand_to_diffusion_samples: bool, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Compute factored LDDT metrics by gathering DTensors and calling the triton kernel. + + Distributed override of the serial Validator.get_lddt_metrics. Gathers + atom_to_token on-demand (it is large) while reusing pre-gathered mol_type, + asym_id, and sample_atom_coords. + + Parameters + ---------- + atom_to_token_dtensor : DTensor + Sharded atom-to-token mapping DTensor with placements + (Shard(0), Shard(1), Replicate()). + num_conformers : int + Number of conformers (K) in the ensemble. + n_samples : int + Number of diffusion samples (multiplicity). + true_coords : torch.Tensor + Ground truth atom coordinates + Shape [B*mul, K, N_atoms, 3] if ``expand_to_diffusion_samples=True``, + else [K, N_atoms, 3]. + true_coords_resolved_mask : torch.Tensor + Resolved atom mask + Shape [B*mul, N_atoms] when ``expand_to_diffusion_samples=True``, + else [N_atoms]. + mol_type : torch.Tensor + Pre-gathered molecule type per token, shape [B, N_tokens]. + asym_id : torch.Tensor + Pre-gathered chain ID per token, shape [B, N_tokens]. + sample_atom_coords : torch.Tensor + Pre-gathered predicted atom coordinates, shape [B*mul, N_atoms, 3]. + expand_to_diffusion_samples : bool + If True, true coordinates/masks are already expanded to diffusion + samples. If False, this function repeats the non-expanded resolved mask + across ``n_samples`` to match ``sample_atom_coords``. + + Returns + ------- + dict[str, torch.Tensor] + LDDT score per modality, each shape [B*mul, K]. + dict[str, torch.Tensor] + Total per modality, each shape [B*mul, K]. + + """ + + atom_to_token = reconstruct_atom_to_token_global(atom_to_token_dtensor) + + if sample_atom_coords.ndim != 3: + raise ValueError( + f"sample_atom_coords must be rank 3 ([B*mul, N_atoms, 3]) (got ndim={sample_atom_coords.ndim})" + ) + + if expand_to_diffusion_samples: + if true_coords.ndim != 4: + raise ValueError( + "true_coords must be rank 4 ([B*mul, K, N_atoms, 3]) when " + f"expand_to_diffusion_samples=True (got ndim={true_coords.ndim})" + ) + if true_coords_resolved_mask.ndim != 2: + raise ValueError( + "true_coords_resolved_mask must be rank 2 when expand_to_diffusion_samples=True " + f"(got ndim={true_coords_resolved_mask.ndim})" + ) + true_coords_K = true_coords.shape[1] + true_coords_n_atoms = true_coords.shape[2] + true_coords_batch = true_coords.shape[0] + else: + if true_coords.ndim != 3: + raise ValueError( + "true_coords must be rank 3 ([K, N_atoms, 3]) when " + f"expand_to_diffusion_samples=False (got ndim={true_coords.ndim})" + ) + if true_coords_resolved_mask.ndim != 1: + raise ValueError( + "true_coords_resolved_mask must be rank 1 when expand_to_diffusion_samples=False " + f"(got ndim={true_coords_resolved_mask.ndim})" + ) + true_coords_K = true_coords.shape[0] + true_coords_n_atoms = true_coords.shape[1] + true_coords_batch = None + true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat((n_samples, 1)) + + if atom_to_token.shape[0] != 1: + raise ValueError( + "get_lddt_metrics currently expects local batch size 1 after atom_to_token reconstruction " + f"(got atom_to_token.shape[0]={atom_to_token.shape[0]})" + ) + if atom_to_token.shape[0] * n_samples != sample_atom_coords.shape[0]: + raise ValueError( + "sample_atom_coords batch must equal atom_to_token batch * n_samples " + f"(got sample_atom_coords.shape[0]={sample_atom_coords.shape[0]}, " + f"atom_to_token.shape[0]={atom_to_token.shape[0]}, n_samples={n_samples})" + ) + + K = num_conformers + if true_coords_K != K: + raise ValueError(f"true_coords conformer count ({true_coords_K}) != num_conformers ({K})") + if true_coords_batch is not None and true_coords_batch != sample_atom_coords.shape[0]: + raise ValueError( + f"true_coords batch dim ({true_coords_batch}) != " + f"sample_atom_coords batch dim ({sample_atom_coords.shape[0]})" + ) + + N_atoms = atom_to_token.shape[1] + N_tokens = atom_to_token.shape[2] + if mol_type.shape[-1] != N_tokens: + raise ValueError(f"mol_type N_tokens ({mol_type.shape[-1]}) != atom_to_token N_tokens ({N_tokens})") + if asym_id.shape[-1] != N_tokens: + raise ValueError(f"asym_id N_tokens ({asym_id.shape[-1]}) != atom_to_token N_tokens ({N_tokens})") + if sample_atom_coords.shape[1] != N_atoms: + raise ValueError( + f"sample_atom_coords N_atoms ({sample_atom_coords.shape[1]}) != atom_to_token N_atoms ({N_atoms})" + ) + if true_coords_resolved_mask.shape[1] != N_atoms: + raise ValueError( + f"true_coords_resolved_mask N_atoms ({true_coords_resolved_mask.shape[1]}) != " + f"atom_to_token N_atoms ({N_atoms})" + ) + if true_coords_n_atoms != N_atoms: + raise ValueError(f"true_coords N_atoms ({true_coords_n_atoms}) != atom_to_token N_atoms ({N_atoms})") + + feats = { + "atom_to_token": atom_to_token, + "mol_type": mol_type, + "asym_id": asym_id, + } + + all_lddt_dict = defaultdict(list) + all_total_dict = defaultdict(list) + for ensemble_idx in range(K): + if expand_to_diffusion_samples: + true_coords_k = true_coords[:, ensemble_idx] + else: + true_coords_k = true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) + + lddt_dict_k, total_dict_k = factored_lddt_loss( + true_atom_coords=true_coords_k, + pred_atom_coords=sample_atom_coords, + feats=feats, + atom_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + for key in lddt_dict_k: + all_lddt_dict[key].append(lddt_dict_k[key]) + all_total_dict[key].append(total_dict_k[key]) + + for key in all_lddt_dict: + all_lddt_dict[key] = torch.stack(all_lddt_dict[key], dim=1) + all_total_dict[key] = torch.stack(all_total_dict[key], dim=1) + + return dict(all_lddt_dict), dict(all_total_dict) + + +def weighted_minimum_rmsd_single( + pred_atom_coords: DTensor, + atom_coords: DTensor, + atom_mask: DTensor, + atom_to_token: DTensor, + mol_type: DTensor, + nucleotide_weight: float = 5.0, + ligand_weight: float = 10.0, +) -> tuple[DTensor, DTensor, DTensor]: + """Compute rmsd of the aligned atom coordinates using DTensor operations. + + This is the distributed version that operates on DTensors with placements + (Shard(0), Shard(1), Replicate()) for coords and (Shard(0), Shard(1), Replicate()) + for 2D features. + + Parameters + ---------- + pred_atom_coords : DTensor + Predicted atom coordinates with shape (B, N_atoms, 3). + Placements: (Shard(0), Shard(1), Replicate()) + atom_coords : DTensor + Ground truth atom coordinates with shape (B, N_atoms, 3). + Placements: (Shard(0), Shard(1), Replicate()) + atom_mask : DTensor + Resolved atom mask with shape (B, N_atoms). + Placements: (Shard(0), Shard(1), Replicate()) + atom_to_token : DTensor + Atom to token mapping with shape (B, N_tokens, N_atoms). + Placements: (Shard(0), Shard(1), Replicate()) + mol_type : DTensor + Molecule type per token with shape (B, N_tokens). + Placements: (Shard(0), Shard(1), Replicate()) + nucleotide_weight : float + Weight for nucleotide atoms in RMSD computation. + ligand_weight : float + Weight for ligand atoms in RMSD computation. + + Returns + ------- + tuple[DTensor, DTensor, DTensor] + - rmsd: The RMSD value with shape (B,). Placements: (Shard(0), Replicate(), Replicate()) + - atom_coords_aligned: The aligned coordinates with shape (B, N_atoms, 3). + Placements: (Shard(0), Shard(1), Replicate()) + - align_weights: The alignment weights with shape (B, N_atoms). + Placements: (Shard(0), Shard(1), Replicate()) + + """ + # Validate inputs are DTensors + if not isinstance(pred_atom_coords, DTensor): + raise TypeError(f"pred_atom_coords must be DTensor, got {type(pred_atom_coords)}") + if not isinstance(atom_coords, DTensor): + raise TypeError(f"atom_coords must be DTensor, got {type(atom_coords)}") + if not isinstance(atom_mask, DTensor): + raise TypeError(f"atom_mask must be DTensor, got {type(atom_mask)}") + if not isinstance(atom_to_token, DTensor): + raise TypeError(f"atom_to_token must be DTensor, got {type(atom_to_token)}") + if not isinstance(mol_type, DTensor): + raise TypeError(f"mol_type must be DTensor, got {type(mol_type)}") + + device_mesh = pred_atom_coords.device_mesh + + # Convert dtypes as needed + dtype = pred_atom_coords.dtype + + # Compute atom_type by mapping mol_type (token-level) to atom-level + # atom_type has shape (B, N_atoms) - placements: (Shard(0), Shard(1), Replicate()) + atom_type = single_repr_token_to_atom( + mol_type.to(dtype=dtype).unsqueeze(-1), # (B, N_tokens, 1) + atom_to_token.to(dtype=dtype), # (B, N_tokens, N_atoms) + ) # (B, N_atoms, 1) + atom_type = shardwise_squeeze(atom_type, dim=-1) # (B, N_atoms) + + # Compute nucleotide mask: is_DNA OR is_RNA + is_dna = scalar_tensor_op(float(const.chain_type_ids["DNA"]), atom_type, ElementwiseOp.EQUAL) + is_rna = scalar_tensor_op(float(const.chain_type_ids["RNA"]), atom_type, ElementwiseOp.EQUAL) + is_nucleotide = elementwise_op(is_dna, is_rna, ElementwiseOp.SUM) + + # Compute ligand mask + is_ligand = scalar_tensor_op(float(const.chain_type_ids["NONPOLYMER"]), atom_type, ElementwiseOp.EQUAL) + + # Compute weighted contributions + nucleotide_contribution = scalar_tensor_op(nucleotide_weight, is_nucleotide, ElementwiseOp.PROD) + ligand_contribution = scalar_tensor_op(ligand_weight, is_ligand, ElementwiseOp.PROD) + + # align_weights = 1 + nucleotide_weight * is_nucleotide + ligand_weight * is_ligand + align_weights = scalar_tensor_op( + 1.0, + elementwise_op(nucleotide_contribution, ligand_contribution, ElementwiseOp.SUM), + ElementwiseOp.SUM, + ) + + # Ensure atom_mask has correct placements + atom_mask_float = DTensor.from_local( + atom_mask.to_local().to(dtype=dtype), + device_mesh, + atom_mask.placements, + shape=atom_mask.shape, + stride=atom_mask.stride(), + ) + + # Perform weighted rigid alignment + with torch.no_grad(): + atom_coords_aligned_ground_truth = dtensor_weighted_rigid_align( + atom_coords.to(dtype=dtype), + pred_atom_coords.to(dtype=dtype), + align_weights.to(dtype=dtype), + mask=atom_mask_float, + ) + + # Compute MSE loss: ((pred - aligned_true) ** 2).sum(dim=-1) + diff = elementwise_op(pred_atom_coords, atom_coords_aligned_ground_truth, ElementwiseOp.SUB) + diff_sq = scalar_tensor_op(2.0, diff, ElementwiseOp.POW) + mse_loss = shardwise_sum(diff_sq, dim=-1) # sum over xyz -> (B, N_atoms) + + # Compute weighted MSE: mse_loss * align_weights * atom_mask + weighted_mse = elementwise_op(mse_loss, align_weights, ElementwiseOp.PROD) + weighted_mse = elementwise_op(weighted_mse, atom_mask_float, ElementwiseOp.PROD) + + # Compute denominator: align_weights * atom_mask + denom = elementwise_op(align_weights, atom_mask_float, ElementwiseOp.PROD) + + # Reduce along atom dimension + weighted_mse_sum = sharded_sum(weighted_mse, dim=-1) # (B,) + denom_sum = sharded_sum(denom, dim=-1) # (B,) + + # rmsd = sqrt(weighted_mse_sum / denom_sum) + ratio = elementwise_op(weighted_mse_sum, denom_sum, ElementwiseOp.DIV) + rmsd = scalar_tensor_op(0.5, ratio, ElementwiseOp.POW) # sqrt via x^0.5 + + return rmsd, atom_coords_aligned_ground_truth, align_weights + + +def compute_plddt_mae_triton( + pred_atom_coords: Tensor, + feats: dict[str, Tensor], + true_atom_coords: Tensor, + pred_lddt: Tensor, + true_coords_resolved_mask: Tensor, + multiplicity: int = 1, +) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + """Compute pLDDT MAE using triton ``cdist_lddt`` in rectangular mode. + + Uses ``cdist_lddt`` with rectangular inputs (N_token rows x N_R_set cols) + and factored masks instead of materialising full distance matrices and + pair masks. + + Parameters + ---------- + pred_atom_coords : Tensor + Predicted atom coordinates, shape ``[B*mul, N_atom, 3]``. + feats : dict[str, Tensor] + Feature dict with ``token_to_rep_atom``, ``r_set_to_rep_atom``, + ``atom_to_token``, ``mol_type``. All at base batch B. + true_atom_coords : Tensor + Ground-truth atom coordinates, shape ``[B*mul, N_atom, 3]``. + pred_lddt : Tensor + Predicted per-token lDDT, shape ``[B*mul, N_token]``. + true_coords_resolved_mask : Tensor + Per-atom resolved mask, shape ``[B*mul, N_atom]``. + multiplicity : int + Diffusion sample count (B_mul = B * multiplicity). + + Returns + ------- + mae_plddt_dict : dict[str, Tensor] + Per-modality MAE, each a scalar. + total_dict : dict[str, Tensor] + Per-modality total weight, each a scalar. + """ + token_to_rep_atom = feats["token_to_rep_atom"].float() + r_set_to_rep_atom = feats["r_set_to_rep_atom"].float() + atom_to_token = feats["atom_to_token"].float() + mol_type = feats["mol_type"] + + if multiplicity > 1: + t2r_expanded = token_to_rep_atom.repeat_interleave(multiplicity, 0) + r2r_expanded = r_set_to_rep_atom.repeat_interleave(multiplicity, 0) + else: + t2r_expanded = token_to_rep_atom + r2r_expanded = r_set_to_rep_atom + + pred_token_coords = torch.bmm(t2r_expanded, pred_atom_coords) + true_token_coords = torch.bmm(t2r_expanded, true_atom_coords) + pred_R_coords = torch.bmm(r2r_expanded, pred_atom_coords) + true_R_coords = torch.bmm(r2r_expanded, true_atom_coords) + + # Masks at B*mul level so per-sample mask variation is preserved. + resolved = true_coords_resolved_mask.float() + mask_row = torch.bmm(t2r_expanded, resolved.unsqueeze(-1)).squeeze(-1) + mask_col = torch.bmm(r2r_expanded, resolved.unsqueeze(-1)).squeeze(-1) + + # Per-column cutoff based on nucleotide type (base batch B) + is_nucleotide_token = (mol_type == const.chain_type_ids["DNA"]).float() + ( + mol_type == const.chain_type_ids["RNA"] + ).float() + is_nucleotide_atom = torch.bmm(atom_to_token, is_nucleotide_token.unsqueeze(-1)).squeeze(-1) + is_nucleotide_R = torch.bmm(r_set_to_rep_atom, is_nucleotide_atom.unsqueeze(-1)).squeeze(-1) + cutoff_col = 15.0 + 15.0 * is_nucleotide_R + + atom_indices_row = token_to_rep_atom.argmax(dim=-1) + atom_indices_col = r_set_to_rep_atom.argmax(dim=-1) + + target_lddt, mask_no_match = cdist_lddt( + pred_coords_row=pred_token_coords, + pred_coords_col=pred_R_coords, + true_coords_row=true_token_coords, + true_coords_col=true_R_coords, + mask_row=mask_row, + mask_col=mask_col, + multiplicity=multiplicity, + atom_indices_row=atom_indices_row, + atom_indices_col=atom_indices_col, + cutoff_col=cutoff_col, + do_mask_diagonal=True, + per_atom=True, + ) + + atom_mask = mask_row + if multiplicity > 1: + token_type = mol_type.repeat_interleave(multiplicity, 0) + else: + token_type = mol_type + + protein_mask = (token_type == const.chain_type_ids["PROTEIN"]).float() * atom_mask * mask_no_match + ligand_mask = (token_type == const.chain_type_ids["NONPOLYMER"]).float() * atom_mask * mask_no_match + dna_mask = (token_type == const.chain_type_ids["DNA"]).float() * atom_mask * mask_no_match + rna_mask = (token_type == const.chain_type_ids["RNA"]).float() * atom_mask * mask_no_match + + abs_err = torch.abs(target_lddt - pred_lddt) + + def _mae_and_total(mask): + total = torch.sum(mask) + mae = torch.sum(abs_err * mask) / (total + 1e-5) + return mae, total + + protein_mae, protein_total = _mae_and_total(protein_mask) + ligand_mae, ligand_total = _mae_and_total(ligand_mask) + dna_mae, dna_total = _mae_and_total(dna_mask) + rna_mae, rna_total = _mae_and_total(rna_mask) + + mae_plddt_dict = { + "protein": protein_mae, + "ligand": ligand_mae, + "dna": dna_mae, + "rna": rna_mae, + } + total_dict = { + "protein": protein_total, + "ligand": ligand_total, + "dna": dna_total, + "rna": rna_total, + } + return mae_plddt_dict, total_dict diff --git a/src/boltz/distributed/model/models/__init__.py b/src/boltz/distributed/model/models/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/model/models/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/model/models/boltz2.py b/src/boltz/distributed/model/models/boltz2.py new file mode 100644 index 000000000..5b7d922d5 --- /dev/null +++ b/src/boltz/distributed/model/models/boltz2.py @@ -0,0 +1,1394 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Distributed Boltz2 model wrapper for context-parallel training and inference. + +This module wraps a serial :class:`~boltz.model.models.boltz2.Boltz2` model +and replaces its submodules with DTensor-aware distributed counterparts, +enabling context parallelism (CP) across sequence/pair dimensions. + +Architecture overview +--------------------- +The wrapper follows the same pattern as Boltz1Distributed +(``other_versions/boltz-1x-cp/src/boltz/distributed/model/model.py``): + +1. Accept a fully-initialised serial ``Boltz2`` instance. +2. Replace each serial submodule with its distributed counterpart + (e.g. ``LinearParamsReplicated``, distributed ``MSAModule``, etc.). +3. Re-implement ``forward`` to use DTensor operations for outer products, + recycling, and trunk computation. +4. Provide training/validation/predict steps that handle DTensor + ↔ plain-tensor conversions for loss computation and logging. + +Submodule availability +---------------------- +Some serial submodules do **not** yet have distributed implementations. +These are wrapped in ``_PlaceholderModule`` which raises +``NotImplementedError`` when their code path is hit during a forward +pass. The table below summarises the status: + ++----------------------------+----------------------------------------------+ +| Serial submodule | Distributed status | ++============================+==============================================+ +| s_init, z_init_1, z_init_2 | LinearParamsReplicated (ready) | +| s_norm, z_norm | LayerNormParamsReplicated (ready) | +| s_recycle, z_recycle | LinearParamsReplicated (ready) | +| token_bonds | LinearParamsReplicated (ready) | +| msa_module | MSAModule (trunkv2 distributed, ready) | +| pairformer_module | PairformerModule (ready) | +| distogram_module | DistogramModule (trunkv2.py, ready) | +| rel_pos | RelativePositionEncoder (encoders.py, ready) | +| contact_conditioning | ContactConditioning (trunkv2.py, ready) | +| bfactor_module | BFactorModule (trunkv2.py, ready) | +| diffusion_conditioning | DiffusionConditioning (ready) | +| structure_module | AtomDiffusion (ready) | +| input_embedder | InputEmbedder (trunkv2.py, ready) | +| confidence_module | ConfidenceModule (confidencev2.py, ready) | +| template_module | TODO: needs distributed TemplateModule | +| affinity_module(s) | TODO: needs distributed AffinityModule | +| token_bonds_type | EmbeddingParamsReplicated (ready) | ++----------------------------+----------------------------------------------+ + +When a distributed counterpart becomes available, update the corresponding +``_wrap_*`` method and remove the placeholder. +""" + +import gc +import warnings +from copy import deepcopy +from typing import Any, Optional + +import numpy as np +import torch +from pytorch_lightning import LightningModule +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor import zeros as dtensor_zeros +from torchmetrics import MeanMetric + +from boltz.distributed.comm import AttentionPairBiasComm, TransposeComm +from boltz.distributed.data.feature.symmetry import ( + minimum_lddt_symmetry_coords as minimum_lddt_symmetry_coords_dtensor, +) +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op, scalar_tensor_op +from boltz.distributed.model.layers.embedding import EmbeddingParamsReplicated +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_flatten_sharded +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.outer_op import OuterOp, replicate_to_shard_outer_op +from boltz.distributed.model.layers.pairformer import PairformerModule +from boltz.distributed.model.layers.redistribute_transpose import redistribute_transpose +from boltz.distributed.model.layers.repeat_interleave import shardwise_repeat_interleave +from boltz.distributed.model.layers.squeeze import shardwise_squeeze, shardwise_unsqueeze +from boltz.distributed.model.loss.bfactor import bfactor_loss +from boltz.distributed.model.loss.confidencev2 import confidence_loss +from boltz.distributed.model.loss.distogram import distogram_loss +from boltz.distributed.model.modules.confidencev2 import ConfidenceModule +from boltz.distributed.model.modules.diffusion import AtomDiffusion +from boltz.distributed.model.modules.diffusion_conditioning import DiffusionConditioning +from boltz.distributed.model.modules.encoders import RelativePositionEncoder +from boltz.distributed.model.modules.trunkv2 import ( + BFactorModule, + ContactConditioning, + DistogramModule, + InputEmbedder, + MSAModule, +) +from boltz.distributed.model.optim.ema import DistributedEMA +from boltz.distributed.utils import update_exhaustive_strides +from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 +from boltz.model.optim.scheduler import AlphaFoldLRScheduler + + +def _ensure_numpy_compatible_dtype(t: torch.Tensor) -> torch.Tensor: + """Promote tensor dtype to at least float32 for NumPy compatibility. + + NumPy does not support bfloat16/float16. ``torch.promote_types`` returns + the wider of *t.dtype* and float32, so half-precision becomes float32 while + float32/float64 are preserved unchanged. + """ + return t.to(dtype=torch.promote_types(t.dtype, torch.float32)) + + +def _assert_no_dtensors_in_output(d: dict[str, Any], prefix: str = "") -> None: + """Raise TypeError if any DTensor values remain in the predict output dict. + + Walks the dict recursively so nested structures like ``pair_chains_iptm`` + are also checked. Called at the end of ``predict_step`` to ensure the + writer callback only receives plain tensors. + """ + for key, val in d.items(): + path = f"{prefix}.{key}" if prefix else key + if isinstance(val, DTensor): + raise TypeError( + f"predict_step output['{path}'] is a DTensor " + f"(placements={val.placements}). Convert to a plain tensor " + f"via full_tensor() or to_local() before returning." + ) + if isinstance(val, dict): + _assert_no_dtensors_in_output(val, prefix=path) + + +class _PlaceholderModule(torch.nn.Module): + """Placeholder for a serial module that does not yet have a distributed implementation. + + Stores the serial module's parameters (so they appear in + ``state_dict`` for checkpoint compatibility) but raises + ``NotImplementedError`` on forward. All parameters are frozen + because gradients can never flow through a placeholder. + """ + + def __init__(self, serial_module: torch.nn.Module, name: str) -> None: + super().__init__() + self._serial = serial_module + self._name = name + for p in self._serial.parameters(): + p.requires_grad_(False) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError( + f"Distributed version of '{self._name}' is not yet implemented. " + f"The serial module is stored as a placeholder for parameter " + f"accounting and checkpoint compatibility." + ) + + +class Boltz2(LightningModule): + """Distributed Boltz2 model with context parallelism via DTensor. + + Wraps a fully-initialised serial :class:`~boltz.model.models.boltz2.Boltz2` + and replaces submodules with DTensor-aware distributed counterparts where + available. Submodules without a distributed implementation are wrapped in + :class:`_PlaceholderModule` so their parameters are tracked in the state + dict but ``NotImplementedError`` is raised if the code path is hit. + """ + + def __init__( + self, + model_serial: SerialBoltz2, + dist_manager: DistributedManager, + ) -> None: + """Initialise the distributed Boltz2 model from a serial model. + + Parameters + ---------- + model_serial : SerialBoltz2 + Fully initialised serial Boltz2 instance. + dist_manager : DistributedManager + Distributed manager with device meshes and process groups. + + Raises + ------ + Warning + When a module required for the requested configuration lacks a + distributed implementation. + """ + super().__init__() + + self.dist_manager = dist_manager + self.device_mesh_subgroups = self.dist_manager.device_mesh_subgroups + self.rank_coords = self.device_mesh_subgroups.get_coordinate() + self.is_cp_rank_zero = self.rank_coords[1] == 0 and self.rank_coords[2] == 0 + + # Preserve hyper-parameters for checkpoint portability + self.save_hyperparameters(model_serial.hparams) + + self.has_context_parallelism = True + + # ------------------------------------------------------------------ # + # Transfer scalar config from serial model # + # ------------------------------------------------------------------ # + self.training_args = model_serial.training_args + self.validation_args = model_serial.validation_args + self.diffusion_loss_args = model_serial.diffusion_loss_args + self.predict_args = model_serial.predict_args + self.steering_args = getattr(model_serial, "steering_args", None) + self.validate_structure = model_serial.validate_structure + + self.no_random_recycling_training = model_serial.no_random_recycling_training + self.exclude_ions_from_lddt = model_serial.exclude_ions_from_lddt + self.num_bins = model_serial.num_bins + self.min_dist = model_serial.min_dist + self.max_dist = model_serial.max_dist + self.num_distograms = model_serial.num_distograms + self.aggregate_distogram = model_serial.aggregate_distogram + self.is_pairformer_compiled = False + self.is_msa_compiled = False + self.is_template_compiled = False + self.log_loss_every_steps = model_serial.log_loss_every_steps + + self.bond_type_feature = model_serial.bond_type_feature + self.run_trunk_and_structure = model_serial.run_trunk_and_structure + self.skip_run_structure = model_serial.skip_run_structure + self.predict_bfactor = model_serial.predict_bfactor + self.checkpoint_diffusion_conditioning = model_serial.checkpoint_diffusion_conditioning + + # Confidence / affinity / structure flags + self.confidence_prediction = model_serial.confidence_prediction + self.affinity_prediction = False # TODO: enable when distributed AffinityModule is ready + self.affinity_ensemble = getattr(model_serial, "affinity_ensemble", False) + self.affinity_mw_correction = getattr(model_serial, "affinity_mw_correction", True) + self.token_level_confidence = getattr(model_serial, "token_level_confidence", True) + self.alpha_pae = model_serial.alpha_pae + self.structure_prediction_training = model_serial.structure_prediction_training + + if model_serial.affinity_prediction: + warnings.warn( + "Affinity prediction is not yet implemented for context parallelism in Boltz2. " + "Disabling affinity_prediction in the distributed wrapper." + ) + + # EMA – Boltz2 uses callback-based EMA (boltz.model.optim.ema.EMA), + # and the distributed counterpart is boltz.distributed.model.optim.ema.DistributedEMA. + # The wrapper itself does *not* manage EMA – the callback does. + self.use_ema = model_serial.use_ema + self.ema_decay = model_serial.ema_decay + + # ------------------------------------------------------------------ # + # Communication helpers # + # ------------------------------------------------------------------ # + layout_group_cp = self.dist_manager.layout_subgroups["cp"] + self.transpose_comm = TransposeComm(self.dist_manager.group["cp"], layout_group_cp) + + # Process groups for losses that reduce over unique data dimensions + # only (skipping cp1 which is Replicate for single-repr). + # + # We store the individual dp and cp0 groups rather than a single + # flattened dp×cp0 group. Creating a flattened group via + # ``submesh._flatten().get_group()`` triggers ``new_group()`` calls + # that are NOT coordinated across all ranks when the submesh differs + # per cp1 position, causing NCCL deadlocks on 4+ GPU topologies. + # Two sequential all_reduces (dp then cp0) are mathematically + # equivalent: sum_{dp×cp0}(x) = sum_{cp0}(sum_{dp}(x)). + self.dp_group = self.device_mesh_subgroups.get_group(0) + self.cp0_group = self.device_mesh_subgroups.get_group(1) + self.cp_group = self.dist_manager.group["cp"] + + self._cp1_group = self.device_mesh_subgroups.get_group(2) + + # ------------------------------------------------------------------ # + # Validation infrastructure # + # ------------------------------------------------------------------ # + self.num_val_datasets = model_serial.num_val_datasets + if self.validate_structure: + self.val_group_mapper = {} # maps a dataset index to a validation group name + self.validator_mapper = {} # maps a dataset index to a validator + self.validators = model_serial.validators + + # ------------------------------------------------------------------ # + # Wrap submodules that have distributed implementations # + # ------------------------------------------------------------------ # + self.s_init = LinearParamsReplicated(model_serial.s_init, self.device_mesh_subgroups) + self.z_init_1 = LinearParamsReplicated(model_serial.z_init_1, self.device_mesh_subgroups) + self.z_init_2 = LinearParamsReplicated(model_serial.z_init_2, self.device_mesh_subgroups) + + self.s_norm = LayerNormParamsReplicated(model_serial.s_norm, self.device_mesh_subgroups) + self.z_norm = LayerNormParamsReplicated(model_serial.z_norm, self.device_mesh_subgroups) + + self.s_recycle = LinearParamsReplicated(model_serial.s_recycle, self.device_mesh_subgroups) + self.z_recycle = LinearParamsReplicated(model_serial.z_recycle, self.device_mesh_subgroups) + + self.token_bonds = LinearParamsReplicated(model_serial.token_bonds, self.device_mesh_subgroups) + + # ------------------------------------------------------------------ # + # Trunk submodules with distributed implementations # + # ------------------------------------------------------------------ # + self.msa_module = MSAModule(model_serial.msa_module, self.dist_manager) + + self.pairformer_module = PairformerModule(model_serial.pairformer_module, self.dist_manager) + + self.distogram_module = DistogramModule( + model_serial.distogram_module, + self.dist_manager, + distogram_comm=deepcopy(self.transpose_comm), + ) + + # ------------------------------------------------------------------ # + # Submodules that need distributed implementations (placeholders) # + # ------------------------------------------------------------------ # + # InputEmbedder v2: wraps atom-level encoder, feature embedding, etc. + self.input_embedder = InputEmbedder( + model_serial.input_embedder, + device_mesh=self.device_mesh_subgroups, + ) + + # RelativePositionEncoder v2 + self.rel_pos = RelativePositionEncoder( + model_serial.rel_pos, + device_mesh=self.dist_manager.device_mesh_subgroups, + transpose_comm=deepcopy(self.transpose_comm), + ) + + # ContactConditioning + self.contact_conditioning = ContactConditioning( + model_serial.contact_conditioning, + device_mesh=self.dist_manager.device_mesh_subgroups, + ) + + # bond_type_feature embedding (optional) + if self.bond_type_feature: + self.token_bonds_type = EmbeddingParamsReplicated( + model_serial.token_bonds_type, device_mesh=self.dist_manager.device_mesh_subgroups + ) + + # DiffusionConditioning + self.diffusion_conditioning = DiffusionConditioning( + model_serial.diffusion_conditioning, + device_mesh=self.device_mesh_subgroups, + ) + + # AtomDiffusion v2 (structure module) + ring_comm_diffusion = AttentionPairBiasComm( + self.dist_manager.group["cp"], + self.dist_manager.layout_subgroups["cp"], + self.dist_manager.subgroups["cp"][0], + self.dist_manager.subgroups["cp"][1], + ) + self.structure_module = AtomDiffusion( + model_serial.structure_module, + device_mesh=self.device_mesh_subgroups, + ring_comm=ring_comm_diffusion, + transpose_comm=deepcopy(self.transpose_comm), + ) + + # TemplateModule — preserve weights but skip in forward (not yet distributed) + if model_serial.use_templates: + self.template_module = _PlaceholderModule(model_serial.template_module, "TemplateModule") + # TODO: set self.use_templates = model_serial.use_templates once a real + # distributed TemplateModule replaces the _PlaceholderModule above. + self.use_templates = model_serial.use_templates and not isinstance( + getattr(self, "template_module", None), _PlaceholderModule + ) + + # BFactorModule (optional) + if self.predict_bfactor: + self.bfactor_module = BFactorModule( + model_serial.bfactor_module, + device_mesh=self.dist_manager.device_mesh_subgroups, + ) + + # ConfidenceModule v2 + if model_serial.confidence_prediction: + confidence_transpose_comm = deepcopy(self.transpose_comm) + self.confidence_module = ConfidenceModule( + model_serial.confidence_module, + dist_manager=dist_manager, + transpose_comm=confidence_transpose_comm, + ) + + # AffinityModule(s) (disabled for now) + if model_serial.affinity_prediction: + if model_serial.affinity_ensemble: + self.affinity_module1 = _PlaceholderModule(model_serial.affinity_module1, "AffinityModule1") + self.affinity_module2 = _PlaceholderModule(model_serial.affinity_module2, "AffinityModule2") + else: + self.affinity_module = _PlaceholderModule(model_serial.affinity_module, "AffinityModule") + + # Freeze parameters not involved in structure prediction training + if not self.structure_prediction_training: + for name, param in self.named_parameters(): + if ( + name.split(".")[0] not in ["confidence_module", "affinity_module"] + and "out_token_feat_update" not in name + ): + param.requires_grad = False + + # Validate: every trainable parameter must be a DTensor so that + # on_after_backward can redistribute its gradient. Non-DTensor + # trainable params would accumulate plain-tensor gradients that + # on_after_backward cannot handle. + for name, param in self.named_parameters(): + if param.requires_grad and not isinstance(param, DTensor): + raise ValueError( + f"Trainable parameter '{name}' is a plain Tensor, not a DTensor. " + f"All trainable parameters must be DTensors so that " + f"on_after_backward can redistribute their gradients. " + f"Either wrap the owning module with a DTensor-aware wrapper " + f"or freeze this parameter (requires_grad=False)." + ) + + # ====================================================================== # + # Forward pass # + # ====================================================================== # + + def forward( + self, + feats: dict[str, DTensor], + recycling_steps: int = 0, + num_sampling_steps: Optional[int] = None, + multiplicity_diffusion_train: int = 1, + diffusion_samples: int = 1, + max_parallel_samples: Optional[int] = None, + run_confidence_sequentially: bool = False, + ) -> dict[str, DTensor]: + """Forward pass through the distributed Boltz2 model. + + Performs structure prediction using DTensor context parallelism. + The trunk (input embedding → recycling → MSA → pairformer) is fully + distributed. Diffusion conditioning and structure prediction require + their distributed counterparts to be implemented. + + Parameters + ---------- + feats : dict[str, DTensor] + Input features as DTensors. + recycling_steps : int + Number of recycling iterations. + num_sampling_steps : int, optional + Number of diffusion sampling steps for inference. + multiplicity_diffusion_train : int + Training diffusion multiplicity. + diffusion_samples : int + Number of diffusion samples for inference. + max_parallel_samples : int, optional + Maximum number of parallel diffusion samples. + run_confidence_sequentially : bool + Whether to run the confidence module sequentially. + + Returns + ------- + dict[str, DTensor] + Dictionary containing model outputs. + """ + with torch.set_grad_enabled(self.training and self.structure_prediction_training): + s_inputs = self.input_embedder(feats) + + # Initialise single and pairwise embeddings + s_init = self.s_init(s_inputs) + + z2 = self.z_init_2(s_inputs) + z1 = self.z_init_1(s_inputs) + + # Outer sum: globally equivalent to z1[:, :, None, :] + z2[:, None, :, :] + # Both z1 and z2 have placements (Shard(0), Shard(1), Replicate()). + # replicate_to_shard_outer_op handles the transpose_then_redistribute + # for z2 internally and produces (Shard(0), Shard(1), Shard(2)) output. + # Its backward correctly all-reduces the column/row gradient sums + # across the Replicate axis. + z_init = replicate_to_shard_outer_op( + z1, OuterOp.SUM, axis=1, transpose_comm=self.transpose_comm, input_t=z2 + ) + + relative_position_encoding = self.rel_pos(feats) + z_init = elementwise_op(z_init, relative_position_encoding, ElementwiseOp.SUM) + z_init = elementwise_op( + z_init, + self.token_bonds(feats["token_bonds"].to(dtype=z_init.dtype)), + ElementwiseOp.SUM, + ) + + if self.bond_type_feature: + z_init = elementwise_op( + z_init, + self.token_bonds_type(feats["type_bonds"].long()), + ElementwiseOp.SUM, + ) + + z_init = elementwise_op(z_init, self.contact_conditioning(feats), ElementwiseOp.SUM) + + # Initialise recycling buffers using dtensor_zeros to avoid + # native DTensor dispatch which may trigger implicit communication. + s: DTensor = dtensor_zeros( + s_init.shape, + dtype=s_init.dtype, + device_mesh=self.device_mesh_subgroups, + placements=list(s_init.placements), + ) + z: DTensor = dtensor_zeros( + z_init.shape, + dtype=z_init.dtype, + device_mesh=self.device_mesh_subgroups, + placements=list(z_init.placements), + ) + + mask = feats["token_pad_mask"].to(dtype=s.dtype) + pair_mask = feats["token_pair_pad_mask"].to(dtype=z.dtype) + + # Redistribute s_inputs for MSAModule + # shape: (B, N, D); placements: (S(0), S(1), R) → (S(0), R, S(1)) + s_inputs_redistributed = redistribute_transpose( + s_inputs, + transpose_comm=self.transpose_comm, + output_placements=(Shard(0), Replicate(), Shard(1)), + dim0=None, + dim1=None, + ) + + if self.run_trunk_and_structure: + for i in range(recycling_steps + 1): + with torch.set_grad_enabled( + self.training and self.structure_prediction_training and (i == recycling_steps) + ): + if self.training and (i == recycling_steps) and torch.is_autocast_enabled(): + torch.clear_autocast_cache() + + # Apply recycling + s = elementwise_op(s_init, self.s_recycle(self.s_norm(s)), ElementwiseOp.SUM) + z = elementwise_op(z_init, self.z_recycle(self.z_norm(z)), ElementwiseOp.SUM) + + # Templates (optional) + if self.use_templates: + # TODO: use distributed template_module when available + z = elementwise_op( + z, + self.template_module(z, feats, pair_mask), + ElementwiseOp.SUM, + ) + + # MSA module + z = elementwise_op( + z, + self.msa_module(z, s_inputs_redistributed, feats), + ElementwiseOp.SUM, + ) + + # Pairformer + s, z = self.pairformer_module(s, z, mask=mask, pair_mask=pair_mask) + + pdistogram = self.distogram_module(z) + dict_out: dict[str, DTensor] = { + "pdistogram": pdistogram, + "s": s, + "z": z, + } + + if self.run_trunk_and_structure and (not self.skip_run_structure): + # Diffusion conditioning (distributed returns 5 values; to_keys + # is handled internally by the distributed AtomAttentionEncoder) + if self.checkpoint_diffusion_conditioning and self.training: + q, c, atom_enc_bias, atom_dec_bias, token_trans_bias = torch.utils.checkpoint.checkpoint( + self.diffusion_conditioning, + s, + z, + relative_position_encoding, + feats, + use_reentrant=False, + ) + else: + q, c, atom_enc_bias, atom_dec_bias, token_trans_bias = self.diffusion_conditioning( + s_trunk=s, + z_trunk=z, + relative_position_encoding=relative_position_encoding, + feats=feats, + ) + diffusion_conditioning_dict = { + "q": q, + "c": c, + "atom_enc_bias": atom_enc_bias, + "atom_dec_bias": atom_dec_bias, + "token_trans_bias": token_trans_bias, + } + + # Inference: reverse diffusion sampling + if (not self.training) or self.confidence_prediction: + with torch.autocast("cuda", enabled=False): + compute_dtype = torch.promote_types(s.dtype, torch.float32) + struct_out = self.structure_module.sample( + s_trunk=s.to(compute_dtype), + s_inputs=s_inputs.to(compute_dtype), + feats=feats, + num_sampling_steps=num_sampling_steps, + atom_mask=feats["atom_pad_mask"].to(compute_dtype), + multiplicity=diffusion_samples, + max_parallel_samples=max_parallel_samples, + diffusion_conditioning=diffusion_conditioning_dict, + ) + dict_out.update(struct_out) + + if self.predict_bfactor: + dict_out["pbfactor"] = self.bfactor_module(s) + + # Training: diffusion forward pass + if self.training and self.structure_prediction_training: + atom_coords = feats["coords"] + K = atom_coords.shape[1] + assert K in (multiplicity_diffusion_train, 1) + + # Expand K → multiplicity if needed, then flatten (B, K, L, 3) → (B*K, L, 3). + if K < multiplicity_diffusion_train: + atom_coords = shardwise_repeat_interleave(atom_coords, multiplicity_diffusion_train // K, dim=1) + feats["coords"] = shardwise_flatten_sharded(atom_coords, start_dim=0, end_dim=1) + + with torch.autocast("cuda", enabled=False): + compute_dtype = torch.promote_types(s.dtype, torch.float32) + struct_out = self.structure_module( + s_trunk=s.to(compute_dtype), + s_inputs=s_inputs.to(compute_dtype), + feats=feats, + multiplicity=multiplicity_diffusion_train, + diffusion_conditioning=diffusion_conditioning_dict, # noqa: F821 + ) + dict_out.update(struct_out) + + elif self.training: + # squeeze(1) removes the singleton ensemble dim: + # (B, 1, A, 3) → (B*1, A, 3) = (B, A, 3) + feats["coords"] = shardwise_flatten_sharded(feats["coords"], start_dim=0, end_dim=1) + + if self.confidence_prediction: + if "frames_idx" in feats and feats["frames_idx"].ndim == 4: + feats["frames_idx"] = shardwise_flatten_sharded(feats["frames_idx"], start_dim=0, end_dim=1) + if "frame_resolved_mask" in feats and feats["frame_resolved_mask"].ndim == 3: + feats["frame_resolved_mask"] = shardwise_flatten_sharded( + feats["frame_resolved_mask"], start_dim=0, end_dim=1 + ) + dict_out.update( + self.confidence_module( + s_inputs=s_inputs.detach(), + s=s.detach(), + z=z.detach(), + x_pred=( + dict_out["sample_atom_coords"].detach() + if not self.skip_run_structure + else shardwise_repeat_interleave(feats["coords"], diffusion_samples, dim=0) + ), + feats=feats, + pred_distogram_logits=dict_out["pdistogram"].detach(), + multiplicity=diffusion_samples, + run_sequentially=run_confidence_sequentially, + ) + ) + + # Affinity (TODO: enable when distributed AffinityModule is ready) + + return dict_out + + # ====================================================================== # + # Training step # + # ====================================================================== # + + def training_step(self, batch: dict[str, int | DTensor], batch_idx: int) -> DTensor: + """Training step with distributed loss computation.""" + # Sample recycling steps + if self.no_random_recycling_training: + recycling_steps = self.training_args.recycling_steps + else: + rgn = np.random.default_rng(self.global_step) + recycling_steps = rgn.integers(0, self.training_args.recycling_steps + 1).item() + + # Synchronise recycling steps across CP ranks via the flat CP group + recycling_steps_tensor = torch.tensor(recycling_steps, device=self.device) + cp_group_global_rank_zero = torch.distributed.get_global_rank(self.cp_group, 0) + torch.distributed.broadcast(recycling_steps_tensor, src=cp_group_global_rank_zero, group=self.cp_group) + recycling_steps = recycling_steps_tensor.item() + + if self.training_args.get("sampling_steps_random", None) is not None: + rgn_sampling = np.random.default_rng(self.global_step) + sampling_steps = rgn_sampling.choice(self.training_args.sampling_steps_random) + else: + sampling_steps = self.training_args.sampling_steps + + # Broadcast sampling_steps across CP ranks for consistency + sampling_steps_tensor = torch.tensor(int(sampling_steps), device=self.device) + torch.distributed.broadcast(sampling_steps_tensor, src=cp_group_global_rank_zero, group=self.cp_group) + sampling_steps = sampling_steps_tensor.item() + + out = self( + feats=batch, + recycling_steps=recycling_steps, + num_sampling_steps=sampling_steps, + multiplicity_diffusion_train=self.training_args.diffusion_multiplicity, + diffusion_samples=self.training_args.diffusion_samples, + ) + + # Compute losses + if self.structure_prediction_training: + disto_loss = self._compute_distogram_loss(out, batch) + diffusion_loss_dict = self.structure_module.compute_loss( + batch, + out, + multiplicity=self.training_args.diffusion_multiplicity, + **self.diffusion_loss_args, + ) + bfactor_loss_val = self._compute_bfactor_loss(out, batch) + else: + zeros_dtensor = dtensor_zeros( + (), + requires_grad=False, + device_mesh=self.device_mesh_subgroups, + placements=(Replicate(), Replicate(), Replicate()), + ) + disto_loss = zeros_dtensor + bfactor_loss_val = zeros_dtensor + diffusion_loss_dict = {"loss": zeros_dtensor, "loss_breakdown": {}} + + confidence_loss_dict = self._compute_confidence_loss(out, batch) + + # Aggregate losses + loss = elementwise_op( + elementwise_op( + elementwise_op( + scalar_tensor_op( + self.training_args.diffusion_loss_weight, + diffusion_loss_dict["loss"], + ElementwiseOp.PROD, + ), + scalar_tensor_op( + self.training_args.distogram_loss_weight, + disto_loss, + ElementwiseOp.PROD, + ), + ElementwiseOp.SUM, + ), + scalar_tensor_op( + # Default 0.0 matches serial boltz2.py (added after other weights). + self.training_args.get("bfactor_loss_weight", 0.0), + bfactor_loss_val, + ElementwiseOp.PROD, + ), + ElementwiseOp.SUM, + ), + scalar_tensor_op( + self.training_args.confidence_loss_weight, + confidence_loss_dict["loss"], + ElementwiseOp.PROD, + ), + ElementwiseOp.SUM, + ) + + if not (self.global_step % self.log_loss_every_steps): + self.log("train/distogram_loss", disto_loss.to_local()) + self.log("train/diffusion_loss", diffusion_loss_dict["loss"].to_local()) + for k, v in diffusion_loss_dict["loss_breakdown"].items(): + self.log(f"train/{k}", v.to_local() if isinstance(v, DTensor) else v) + if self.confidence_prediction: + self.log("train/confidence_loss", confidence_loss_dict["loss"].to_local()) + for k, v in confidence_loss_dict["loss_breakdown"].items(): + self.log(f"train/{k}", v.to_local() if isinstance(v, DTensor) else v) + self.log("train/loss", loss.to_local()) + self.training_log() + + return loss + + def _compute_distogram_loss(self, out: dict, batch: dict) -> DTensor: + """Compute distogram loss using the distributed implementation.""" + disto_loss, _ = distogram_loss( + out, batch, comm=self.transpose_comm, aggregate_distogram=self.aggregate_distogram + ) + return disto_loss + + def _compute_confidence_loss(self, out: dict, batch: dict) -> dict: + """Compute confidence loss if confidence_prediction is enabled.""" + if not self.confidence_prediction: + zeros = dtensor_zeros( + (), + requires_grad=False, + device_mesh=self.device_mesh_subgroups, + placements=(Replicate(), Replicate(), Replicate()), + ) + return {"loss": zeros, "loss_breakdown": {}} + + return_dict = self.get_true_coordinates( + batch, + out, + diffusion_samples=self.training_args.diffusion_samples, + symmetry_correction=self.training_args.symmetry_correction, + ) + + true_coords = return_dict["true_coords"] + true_coords_resolved_mask = return_dict["true_coords_resolved_mask"] + + return confidence_loss( + out, + batch, + true_coords, + true_coords_resolved_mask, + comm=self.transpose_comm, + token_level_confidence=self.token_level_confidence, + alpha_pae=self.alpha_pae, + multiplicity=self.training_args.diffusion_samples, + dist_manager=self.dist_manager, + group_layout=self.dist_manager.layout_subgroups["cp"], + ) + + def _compute_bfactor_loss(self, out: dict, batch: dict) -> DTensor: + """Compute bfactor loss if enabled.""" + if self.predict_bfactor: + return bfactor_loss( + out, + batch, + device_mesh=self.device_mesh_subgroups, + dp_group=self.dp_group, + cp0_group=self.cp0_group, + cp1_group=self._cp1_group, + ) + return dtensor_zeros( + (), + requires_grad=False, + device_mesh=self.device_mesh_subgroups, + placements=(Replicate(), Replicate(), Replicate()), + ) + + # ====================================================================== # + # Logging helpers # + # ====================================================================== # + + def training_log(self) -> None: + """Log training metrics.""" + self.log("train/param_norm", self.parameter_norm(self), prog_bar=False) + + lr = self.trainer.optimizers[0].param_groups[0]["lr"] + self.log("lr", lr, prog_bar=False) + + self.log("train/param_norm_msa_module", self.parameter_norm(self.msa_module), prog_bar=False) + self.log("train/param_norm_pairformer_module", self.parameter_norm(self.pairformer_module), prog_bar=False) + self.log("train/param_norm_structure_module", self.parameter_norm(self.structure_module), prog_bar=False) + + if self.confidence_prediction: + self.log( + "train/param_norm_confidence_module", + self.parameter_norm(self.confidence_module), + prog_bar=False, + ) + + def gradient_norm(self, module: torch.nn.Module) -> torch.Tensor: + """Compute L2 norm of gradients for a distributed module. + + Handles both DTensor and plain Tensor parameters (e.g. from + placeholder modules whose serial weights have not yet been + distributed). + """ + norms_sq: list[torch.Tensor] = [] + for p in module.parameters(): + if p.requires_grad and p.grad is not None: + grad = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad + norms_sq.append(grad.norm(p=2) ** 2) + if len(norms_sq) == 0: + return torch.tensor(0.0, device=self.dist_manager.device.type) + return torch.stack(norms_sq).sum().sqrt().to(device=self.dist_manager.device.type) + + def parameter_norm(self, module: torch.nn.Module) -> torch.Tensor: + """Compute L2 norm of parameters for a distributed module. + + Handles both DTensor and plain Tensor parameters. + """ + norms_sq: list[torch.Tensor] = [] + for p in module.parameters(): + if p.requires_grad: + val = p.to_local() if isinstance(p, DTensor) else p + norms_sq.append(val.norm(p=2) ** 2) + if len(norms_sq) == 0: + return torch.tensor(0.0, device=self.dist_manager.device.type) + return torch.stack(norms_sq).sum().sqrt().to(device=self.dist_manager.device.type) + + # ====================================================================== # + # Gradient redistribution # + # ====================================================================== # + + def on_after_backward(self) -> None: + """Redistribute DTensor gradients to Replicate placement after backward. + + Called after ``loss.backward()`` and before ``optimizer.step()``. + Ensures gradients are properly synchronised across context parallel + ranks via all-reduce. + + The ``__init__`` validates that all trainable parameters are DTensors, + so only DTensor or None gradients should appear here. Plain-tensor + gradients are skipped with a warning as a defensive measure. + + Note: parameter gradients on the Replicate (cp1) axis are already + synchronised per-layer via ``Partial("avg")`` in the linear/layernorm + backward (``avg_over_replicate_param_grad=True``). The redistribute + call below handles any remaining ``Partial`` placements from other + sources. + """ + for name, p in self.named_parameters(): + if p.grad is None: + continue + if isinstance(p.grad, DTensor): + p.grad = p.grad.redistribute(p.grad.device_mesh, [Replicate()] * p.grad.device_mesh.ndim) + else: + warnings.warn( + f"Parameter '{name}' has a plain-tensor gradient (type={type(p.grad)}), " + f"skipping redistribution. This should not happen — all trainable " + f"parameters should be DTensors. Check __init__ validation.", + stacklevel=2, + ) + + if not (self.global_step % self.log_loss_every_steps): + self.log("train/grad_norm", self.gradient_norm(self), prog_bar=False) + self.log("train/grad_norm_msa_module", self.gradient_norm(self.msa_module), prog_bar=False) + self.log("train/grad_norm_pairformer_module", self.gradient_norm(self.pairformer_module), prog_bar=False) + self.log("train/grad_norm_structure_module", self.gradient_norm(self.structure_module), prog_bar=False) + if self.confidence_prediction: + self.log( + "train/grad_norm_confidence_module", + self.gradient_norm(self.confidence_module), + prog_bar=False, + ) + + # ====================================================================== # + # Epoch-level hooks # + # ====================================================================== # + + def on_train_epoch_end(self) -> None: + """Log epoch-level training metrics.""" + + # ====================================================================== # + # Validation # + # ====================================================================== # + + def validation_step(self, batch: dict[str, Any], batch_idx: int) -> None: + """Validation step delegating to the distributed validator.""" + if self.validate_structure: + try: + # idx_dataset is a non-sharded feature; CollateDTensor collates + # TRAINING_METADATA_FEATURES as a Python list of tensors. + + msg = "Only batch=1 is supported for validation" + assert len(batch["idx_dataset"]) == 1, msg + assert batch["idx_dataset"][0].shape[0] == 1, msg + + idx_dataset = batch["idx_dataset"][0].item() + validator = self.validator_mapper[idx_dataset] + + out = validator.run_model(model=self, batch=batch, idx_dataset=idx_dataset) + validator.process( + model=self, + batch=batch, + out=out, + idx_dataset=idx_dataset, + transpose_comm=self.transpose_comm, + ) + except RuntimeError as e: + idx_dataset = batch["idx_dataset"][0].item() + if "out of memory" in str(e): + msg = f"| WARNING: ran out of memory, skipping batch, {idx_dataset}" + print(msg) + torch.cuda.empty_cache() + gc.collect() + return + raise e + else: + try: + out = self( + batch, + recycling_steps=self.validation_args.recycling_steps, + num_sampling_steps=self.validation_args.sampling_steps, + diffusion_samples=self.validation_args.diffusion_samples, + run_confidence_sequentially=self.validation_args.get("run_confidence_sequentially", False), + ) + except RuntimeError as e: + idx_dataset = batch["idx_dataset"][0].item() + if "out of memory" in str(e): + msg = f"| WARNING: ran out of memory, skipping batch, {idx_dataset}" + print(msg) + torch.cuda.empty_cache() + gc.collect() + return + raise e + + def on_validation_epoch_end(self) -> None: + """Aggregate all metrics for each validator.""" + if not self.validate_structure: + return + + if self.trainer.sanity_checking: + for validator in self.validator_mapper.values(): + for m in validator.modules(): + if isinstance(m, MeanMetric): + m.reset() + return + + for validator in self.validator_mapper.values(): + validator.on_epoch_end(model=self) + + def setup(self, stage: str) -> None: + """Set the model for training, validation and inference.""" + + if ( + stage != "predict" + and hasattr(self.trainer, "datamodule") + and self.trainer.datamodule + and self.validate_structure + ): + self.val_group_mapper.update(self.trainer.datamodule.val_group_mapper) + + l1 = len(self.val_group_mapper) + l2 = self.num_val_datasets + msg = ( + f"Number of validation datasets num_val_datasets={l2} " + f"does not match the number of val_group_mapper entries={l1}." + ) + assert l1 == l2, msg + + all_validator_names = [] + for validator in self.validators: + for val_name in validator.val_names: + msg = f"Validator {val_name} duplicated in validators." + assert val_name not in all_validator_names, msg + all_validator_names.append(val_name) + for val_idx, val_group in self.val_group_mapper.items(): + if val_name == val_group["label"]: + self.validator_mapper[val_idx] = validator + + msg = "Mismatch between validator names and val_group_mapper values." + assert set(all_validator_names) == {x["label"] for x in self.val_group_mapper.values()}, msg + + def get_true_coordinates( + self, + batch: dict[str, Any], + out: dict[str, Any], + diffusion_samples: int, + symmetry_correction: bool, + expand_to_diffusion_samples: bool = True, + ) -> dict[str, Any]: + """Compute true coordinates for validation/confidence loss. + + In the distributed case, coordinates are DTensors sharded across the + (DP, CP_0, CP_1) mesh. When ``symmetry_correction`` is True, each + sample is processed by :func:`minimum_lddt_symmetry_coords_dtensor` + which internally gathers coordinates along CP axes, runs the symmetry + search on plain tensors using the ``cdist_lddt`` triton kernel, then + re-shards the results. + + Parameters + ---------- + batch : dict[str, Any] + Input features as DTensors. + out : dict[str, Any] + Model outputs including ``sample_atom_coords`` as a DTensor. + diffusion_samples : int + Number of diffusion samples per batch element. + symmetry_correction : bool + Whether to apply symmetry correction. + expand_to_diffusion_samples : bool + Whether to expand coordinates to match diffusion samples. + + Returns + ------- + dict[str, Any] + Dictionary with ``true_coords``, ``true_coords_resolved_mask``, + ``rmsds``, and ``best_rmsd_recall``. + + """ + if symmetry_correction: + assert expand_to_diffusion_samples, "expand_to_diffusion_samples must be true for symmetry correction." + + return_dict: dict[str, Any] = {} + sample_atom_coords = out["sample_atom_coords"] + + if symmetry_correction: + true_coords_list: list[DTensor] = [] + true_mask_list: list[DTensor] = [] + + local_batch_size = batch["token_index"].to_local().shape[0] + for i_batch_local in range(local_batch_size): + for rep in range(diffusion_samples): + i_local = i_batch_local * diffusion_samples + rep + best_true_coords, best_true_mask = minimum_lddt_symmetry_coords_dtensor( + coords=sample_atom_coords, + feats=batch, + index_batch_local=i_batch_local, + i_batch_multiplicity_local=i_local, + ) + true_coords_list.append(best_true_coords) + true_mask_list.append(best_true_mask) + + assert len(true_coords_list) >= 1, "There should be at least 1 true coords processed" + _coords_local = torch.cat([c.to_local() for c in true_coords_list], dim=0) + _coords_global_shape = list(_coords_local.shape) + for _pi, _pl in enumerate(true_coords_list[0].placements): + if hasattr(_pl, "dim"): + _coords_global_shape[_pl.dim] *= true_coords_list[0].device_mesh.size(_pi) + _coords_global_shape = torch.Size(_coords_global_shape) + true_coords = DTensor.from_local( + _coords_local, + device_mesh=true_coords_list[0].device_mesh, + placements=true_coords_list[0].placements, + shape=_coords_global_shape, + stride=update_exhaustive_strides(_coords_local.shape, _coords_local.stride(), _coords_global_shape), + ) + _mask_local = torch.cat([m.to_local() for m in true_mask_list], dim=0) + _mask_global_shape = list(_mask_local.shape) + for _pi, _pl in enumerate(true_mask_list[0].placements): + if hasattr(_pl, "dim"): + _mask_global_shape[_pl.dim] *= true_mask_list[0].device_mesh.size(_pi) + _mask_global_shape = torch.Size(_mask_global_shape) + true_coords_resolved_mask = DTensor.from_local( + _mask_local, + device_mesh=true_mask_list[0].device_mesh, + placements=true_mask_list[0].placements, + shape=_mask_global_shape, + stride=update_exhaustive_strides(_mask_local.shape, _mask_local.stride(), _mask_global_shape), + ) + + true_coords = shardwise_unsqueeze(true_coords, dim=1) + return_dict["true_coords"] = true_coords + return_dict["true_coords_resolved_mask"] = true_coords_resolved_mask + return_dict["rmsds"] = 0 + return_dict["best_rmsd_recall"] = 0 + else: + true_coords_resolved_mask = batch["atom_resolved_mask"] + true_coords = shardwise_squeeze(batch["coords"], dim=1) + if expand_to_diffusion_samples: + true_coords = shardwise_repeat_interleave(true_coords, diffusion_samples, 0) + true_coords_resolved_mask = shardwise_repeat_interleave(true_coords_resolved_mask, diffusion_samples, 0) + + return_dict["true_coords"] = true_coords + return_dict["true_coords_resolved_mask"] = true_coords_resolved_mask + return_dict["rmsds"] = 0 + return_dict["best_rmsd_recall"] = 0 + return_dict["best_rmsd_precision"] = 0 + + return return_dict + + # ====================================================================== # + # Prediction # + # ====================================================================== # + + def predict_step( + self, batch: dict[str, DTensor], batch_idx: int, dataloader_idx: int = 0 + ) -> dict[str, torch.Tensor]: + """Prediction step with distributed inference. + + Parameters + ---------- + batch : dict[str, DTensor] + Input features as DTensors. + batch_idx : int + Index of the current batch. + dataloader_idx : int + Index of the current dataloader. + + Returns + ------- + dict[str, torch.Tensor] + Prediction results gathered on rank 0 of each CP column. + """ + try: + out = self( + batch, + recycling_steps=self.predict_args["recycling_steps"], + num_sampling_steps=self.predict_args["sampling_steps"], + diffusion_samples=self.predict_args["diffusion_samples"], + max_parallel_samples=self.predict_args.get("max_parallel_samples", None), + run_confidence_sequentially=True, + ) + pred_dict: dict[str, Any] = {"exception": False} + + # Gather coords and masks onto rank 0 of each CP column + tag_group_gather = 0 + ranks_gather = self.dist_manager.subgroups_ranks["cp"][tag_group_gather] + group_gather = self.dist_manager.subgroups["cp"][tag_group_gather] + world_size_gather = len(ranks_gather) + + coords = out["sample_atom_coords"].to_local() + mask = batch["atom_pad_mask"].to_local() + + if self.dist_manager.subgroups_rank["cp"][tag_group_gather] == 0: + gather_list_coords = [torch.empty_like(coords) for _ in range(world_size_gather)] + gather_list_mask = [torch.empty_like(mask) for _ in range(world_size_gather)] + else: + gather_list_coords = None + gather_list_mask = None + + torch.distributed.gather(coords, gather_list_coords, dst=ranks_gather[0], group=group_gather) + torch.distributed.gather(mask, gather_list_mask, dst=ranks_gather[0], group=group_gather) + + if self.dist_manager.subgroups_rank["cp"][tag_group_gather] == 0: + pred_dict["masks"] = torch.concat(gather_list_mask, dim=1) + pred_dict["coords"] = _ensure_numpy_compatible_dtype(torch.concat(gather_list_coords, dim=1)) + else: + pred_dict["masks"] = mask + pred_dict["coords"] = _ensure_numpy_compatible_dtype(coords) + + if self.confidence_prediction: + # Per-token and per-pair confidence outputs are sharded across + # CP ranks. Use full_tensor() to reassemble the global tensor + # for the writer (predict path only, not hot). + for key in ["pde", "plddt"]: + val = out[key] + val = val.full_tensor() if isinstance(val, DTensor) else val + pred_dict[key] = _ensure_numpy_compatible_dtype(val) + + # Scalar metrics — already (Shard(0), Replicate(), Replicate()), + # so to_local() gives the correct value on every rank. + for key in [ + "complex_plddt", + "complex_iplddt", + "complex_pde", + "complex_ipde", + ]: + if key in out: + val = out[key] + pred_dict[key] = val.to_local() if isinstance(val, DTensor) else val + + # Confidence score (matching serial formula) + cplddt = pred_dict["complex_plddt"] + if "iptm" in out: + iptm_val = out["iptm"].to_local() if isinstance(out["iptm"], DTensor) else out["iptm"] + ptm_val = out["ptm"].to_local() if isinstance(out["ptm"], DTensor) else out["ptm"] + use_iptm = not torch.allclose(iptm_val, torch.zeros_like(iptm_val)) + pred_dict["confidence_score"] = (4 * cplddt + (iptm_val if use_iptm else ptm_val)) / 5 + else: + pred_dict["confidence_score"] = cplddt + + if self.alpha_pae > 0 and "pae" in out: + # pae is sharded across CP — needs full_tensor() to reassemble. + val = out["pae"] + val = val.full_tensor() if isinstance(val, DTensor) else val + pred_dict["pae"] = _ensure_numpy_compatible_dtype(val) + + # ptm, iptm, *_iptm are globally reduced scalars with + # placements (Shard(0), Replicate(), Replicate()) — already + # fully reduced across CP. to_local() extracts the value + # with no communication. + for key in ["ptm", "iptm", "ligand_iptm", "protein_iptm"]: + if key in out: + val = out[key] + pred_dict[key] = val.to_local() if isinstance(val, DTensor) else val + + if "pair_chains_iptm" in out: + pci = out["pair_chains_iptm"] + if isinstance(pci, dict): + # Success path: nested dict of DTensors → + # plain tensors. Values are globally reduced + # (Shard(0), Replicate(), Replicate()). + pred_dict["pair_chains_iptm"] = { + k1: {k2: v.to_local() if isinstance(v, DTensor) else v for k2, v in inner.items()} + for k1, inner in pci.items() + } + else: + # Fallback: compute_ptms failed and the + # confidence module returned a zero tensor. + # The writer iterates pair_chains_iptm as a + # nested dict — pass an empty dict so the + # comprehension produces nothing. + pred_dict["pair_chains_iptm"] = {} + + _assert_no_dtensors_in_output(pred_dict) + return pred_dict + + except RuntimeError as e: + if "out of memory" in str(e): + print("| WARNING: ran out of memory, skipping batch") + torch.cuda.empty_cache() + gc.collect() + return {"exception": True} + raise + + # ====================================================================== # + # Optimiser # + # ====================================================================== # + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure the optimizer following the serial Boltz2 pattern.""" + param_dict = dict(self.named_parameters()) + + if self.structure_prediction_training: + all_parameter_names = [pn for pn, p in self.named_parameters() if p.requires_grad] + else: + all_parameter_names = [ + pn + for pn, p in self.named_parameters() + if p.requires_grad and ("out_token_feat_update" in pn or "confidence_module" in pn) + ] + + weight_decay = self.training_args.get("weight_decay", 0.0) + if weight_decay > 0 and self.training_args.get("weight_decay_exclude", False): + nodecay_params_names = [ + pn + for pn in all_parameter_names + if ( + "norm" in pn + or "rel_pos" in pn + or ".s_init" in pn + or ".z_init_" in pn + or "token_bonds" in pn + or "embed_atom_features" in pn + or "dist_bin_pairwise_embed" in pn + ) + ] + nodecay_params = [param_dict[pn] for pn in nodecay_params_names] + decay_params = [param_dict[pn] for pn in all_parameter_names if pn not in nodecay_params_names] + optim_groups = [ + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW( + optim_groups, + betas=(self.training_args.adam_beta_1, self.training_args.adam_beta_2), + eps=self.training_args.adam_eps, + lr=self.training_args.base_lr, + ) + else: + optimizer = torch.optim.AdamW( + [param_dict[pn] for pn in all_parameter_names], + betas=(self.training_args.adam_beta_1, self.training_args.adam_beta_2), + eps=self.training_args.adam_eps, + lr=self.training_args.base_lr, + weight_decay=weight_decay, + ) + + if self.training_args.lr_scheduler == "af3": + scheduler = AlphaFoldLRScheduler( + optimizer, + base_lr=self.training_args.base_lr, + max_lr=self.training_args.max_lr, + warmup_no_steps=self.training_args.lr_warmup_no_steps, + start_decay_after_n_steps=self.training_args.lr_start_decay_after_n_steps, + decay_every_n_steps=self.training_args.lr_decay_every_n_steps, + decay_factor=self.training_args.lr_decay_factor, + ) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + return optimizer + + # ====================================================================== # + # EMA – Boltz2 uses callback-based EMA # + # ====================================================================== # + + def configure_callbacks(self) -> list: + """Configure model callbacks. + + When EMA is enabled, returns a :class:`DistributedEMA` callback which + handles DTensor ↔ plain-tensor conversions automatically. + """ + if self.use_ema: + return [DistributedEMA(self.ema_decay)] + return [] + + # ====================================================================== # + # Checkpoint hooks # + # ====================================================================== # + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Adjust checkpoint hyperparameters on load (matching serial Boltz2).""" + lr = self.training_args.max_lr + weight_decay = self.training_args.weight_decay + if "optimizer_states" in checkpoint: + for state in checkpoint["optimizer_states"]: + for group in state["param_groups"]: + group["lr"] = lr + group["weight_decay"] = weight_decay + if "lr_schedulers" in checkpoint: + for scheduler in checkpoint["lr_schedulers"]: + scheduler["max_lr"] = lr + scheduler["base_lrs"] = [lr] * len(scheduler["base_lrs"]) + scheduler["_last_lr"] = [lr] * len(scheduler["_last_lr"]) + if "hyper_parameters" in checkpoint: + checkpoint["hyper_parameters"]["training_args"]["max_lr"] = lr + checkpoint["hyper_parameters"]["training_args"]["diffusion_multiplicity"] = ( + self.training_args.diffusion_multiplicity + ) + checkpoint["hyper_parameters"]["training_args"]["recycling_steps"] = self.training_args.recycling_steps + checkpoint["hyper_parameters"]["training_args"]["weight_decay"] = self.training_args.weight_decay diff --git a/src/boltz/distributed/model/modules/__init__.py b/src/boltz/distributed/model/modules/__init__.py new file mode 100644 index 000000000..01aa3c84b --- /dev/null +++ b/src/boltz/distributed/model/modules/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Distributed model modules with context parallelism support.""" diff --git a/src/boltz/distributed/model/modules/confidence_utils.py b/src/boltz/distributed/model/modules/confidence_utils.py new file mode 100644 index 000000000..11da59f43 --- /dev/null +++ b/src/boltz/distributed/model/modules/confidence_utils.py @@ -0,0 +1,942 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""DTensor-compatible confidence utility functions. + +This module provides DTensor implementations of confidence metric computations +used in the Boltz confidence module. Some helpers are shardwise, while others +redistribute inputs to replicated placements for global reductions. +""" + +import torch +import torch.nn.functional as F +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard + +from boltz.data import const +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.layers.outer_op import OuterOp, distributed_outer_op +from boltz.distributed.model.loss.confidencev2 import compute_frame_pred +from boltz.distributed.utils import LayoutRightMap, update_exhaustive_strides +from boltz.model.modules.confidence_utils import ( + tm_function as serial_tm_function, +) + +# Sentinel value for chain_pair_iptm entries where the chain pair does not exist +# on this rank's batch. Valid iPTM values are in [0, 1], so -1.0 is unambiguous. +CHAIN_IPTM_SENTINEL = -1.0 + +# Small constant added to denominators to avoid division by zero in TM/iPTM-style metrics. +_EPS = 1e-5 + + +class _ComputeAggregatedMetricImpl(torch.autograd.Function): + """Autograd function for computing aggregated metric from logits. + + This implements the forward and backward passes for converting binned logits + to expected metric values via softmax and weighted sum. The computation is + shardwise (no communication required) since it operates along the replicated + bins dimension. + + The metric computation follows: + probs = softmax(logits, dim=-1) + metric = sum(probs * bounds, dim=-1) + + Where bounds are bin centers computed as: + bounds[i] = (i + 0.5) * bin_width, for i in [0, num_bins) + bin_width = end / num_bins + + The backward pass uses PyTorch autograd on the local computation graph, + avoiding the need for manual gradient derivation. + + See Also + -------- + compute_aggregated_metric : The public API function that calls this. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx: FunctionCtx, + logits: DTensor, + end: float, + ) -> DTensor: + """Forward pass for computing aggregated metric. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context object for saving tensors for backward. + logits : DTensor + Input logits tensor with bins as the last dimension. + Typical shapes: + - pLDDT: (B*mult, N_token, 50) with placements (Shard(0), Shard(1),Replicate()) + - pDE/pAE: (B*mult, N_token, N_token, 64) with placements (Shard(0), Shard(1), Shard(2)) + end : float + Maximum value of the metric range. Default 1.0 for pLDDT, 32.0 for pAE. + + Returns + ------- + DTensor + Output metric tensor with the last dimension reduced. + Shape is logits.shape[:-1]. + + Raises + ------ + TypeError + If logits is not a DTensor. + ValueError + If Partial placements are present or the last dimension is sharded. + """ + # Type checking + if not isinstance(logits, DTensor): + raise TypeError(f"Expected DTensor for logits, got {type(logits)}") + + device_mesh = logits.device_mesh + placements = logits.placements + + # Validate placements + last_dim = len(logits.shape) - 1 + for i_dim_device_mesh, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError("Partial placements are not supported") + elif isinstance(placement, Shard): + # Check that the last dimension (bins) is not sharded + if placement.dim == last_dim: + raise ValueError( + f"The bins dimension (dim={last_dim}) must not be sharded for compute_aggregated_metric" + ) + # Check that sharded dimensions are evenly divided + if logits.shape[placement.dim] % device_mesh.shape[i_dim_device_mesh] != 0: + raise ValueError( + f"Uneven sharding of tensor dimension {placement.dim} of size {logits.shape[placement.dim]} " + f"along device mesh dimension {i_dim_device_mesh} of size {device_mesh.shape[i_dim_device_mesh]} is not supported" + ) + + # Detach and set requires_grad to build a local computation graph + logits_local_orig = logits.to_local().detach().requires_grad_(logits.requires_grad) + + with torch.enable_grad(): + # Promote to at least float32 to match serial compute_aggregated_metric + # which uses default-dtype (float32) torch.arange for bounds and + # sum(probs * bounds) — element-wise ops that stay float32 under autocast. + compute_dtype = torch.promote_types(logits_local_orig.dtype, torch.float32) + num_bins = logits_local_orig.shape[-1] + bin_width = end / num_bins + bounds = (torch.arange(num_bins, device=logits_local_orig.device, dtype=compute_dtype) + 0.5) * bin_width + + probs = F.softmax(logits_local_orig.to(compute_dtype), dim=-1) + + # Use sum(probs * bounds) instead of matmul(probs, bounds) to match + # serial code. Under autocast, matmul is on the "lower precision" + # list and would downcast float32 probs to BF16, while element-wise + # multiply and sum are not affected by autocast. + metric_local = torch.sum( + probs * bounds.view(*((1,) * (probs.ndim - 1)), num_bins), + dim=-1, + ) + + # Compute output shape and stride (remove last dimension) + output_shape = tuple(logits.shape[:-1]) + output_stride = LayoutRightMap(output_shape).strides + + # Save tensors for backward pass + ctx.save_for_backward(logits_local_orig, metric_local) + ctx.device_mesh = device_mesh + ctx.placements = placements + ctx.logits_shape = logits.shape + ctx.logits_stride = logits.stride() + ctx.output_shape = output_shape + ctx.output_stride = output_stride + + # Create output DTensor + result = DTensor.from_local( + metric_local.detach(), + device_mesh=device_mesh, + placements=placements, + shape=output_shape, + stride=output_stride, + ) + + return result + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( + ctx: FunctionCtx, + grad_output: DTensor, + ) -> tuple[DTensor | None, None]: + """Backward pass for computing aggregated metric. + + Computes gradients by backpropagating through the local computation graph + that was built during the forward pass. This leverages PyTorch's autograd + rather than manual gradient computation. + + Parameters + ---------- + ctx : FunctionCtx + The autograd context containing saved tensors from forward: + - logits_local_orig: Input logits tensor (local) + - metric_local: Output metric tensor that holds the computation graph + - device_mesh, placements: DTensor metadata + - logits_shape, logits_stride: Shape/stride info for DTensor reconstruction + grad_output : DTensor + Gradient of loss with respect to output metric. + + Returns + ------- + tuple[DTensor | None, None] + Gradients for each forward input in order: + - d_logits: DTensor or None, gradient for logits + - None: end parameter (non-differentiable) + + Notes + ----- + The gradient computation follows the chain rule: + + Forward: + probs = softmax(logits, dim=-1) + metric = sum(probs * bounds, dim=-1) + + Backward (using autograd): + d_logits = d_metric @ d_metric/d_probs @ d_probs/d_logits + = d_metric @ bounds @ softmax_jacobian + + Where softmax_jacobian for dim=-1 is block-diagonal per position, + making this a purely local operation. + """ + logits_local, metric_local = ctx.saved_tensors + + if not logits_local.requires_grad: + return None, None + + grad_output_local = grad_output.to_local() + + # Backprop via the local graph + (d_logits_local,) = torch.autograd.grad( + outputs=[metric_local], + inputs=[logits_local], + grad_outputs=[grad_output_local], + retain_graph=False, # Frees the local graph immediately + ) + + # Wrap gradient in DTensor + d_logits = DTensor.from_local( + d_logits_local, + device_mesh=ctx.device_mesh, + placements=ctx.placements, + shape=ctx.logits_shape, + stride=ctx.logits_stride, + ) + + return d_logits, None + + +def compute_aggregated_metric(logits: DTensor, end: float = 1.0) -> DTensor: + """Compute the metric from logits via softmax and weighted sum. + + This is the DTensor-compatible version of the serial compute_aggregated_metric + function. It converts binned logits to expected metric values by computing + the softmax-weighted average of bin centers. + + All operations are shardwise (no inter-rank communication required) since + they operate along the replicated bins dimension (last dimension). + + Parameters + ---------- + logits : DTensor + The input logits tensor with bins as the last dimension. + Typical shapes and placements: + - pLDDT: (B*mult, N_token, 50) with placements (Shard(0), Replicate()) + - pDE/pAE: (B*mult, N_token, N_token, 64) with placements (Shard(0), Shard(1), Replicate()) + end : float, optional + Maximum value of the metric range, by default 1.0. + Use 1.0 for pLDDT, 32.0 for pAE. + + Returns + ------- + DTensor + The computed metric tensor with shape logits.shape[:-1]. + Placements are preserved from input. + + Examples + -------- + >>> # pLDDT computation + >>> plddt_logits = ... # DTensor with shape (B*mult, N_token, 50) + >>> plddt = compute_aggregated_metric(plddt_logits, end=1.0) + >>> # plddt has shape (B*mult, N_token) + + >>> # pAE computation + >>> pae_logits = ... # DTensor with shape (B*mult, N_token, N_token, 64) + >>> pae = compute_aggregated_metric(pae_logits, end=32.0) + >>> # pae has shape (B*mult, N_token, N_token) + + See Also + -------- + boltz.model.modules.confidence_utils.compute_aggregated_metric : Serial version. + """ + return _ComputeAggregatedMetricImpl.apply(logits, end) + + +class _LocalShardedSum(torch.autograd.Function): + """Sum over a sharded dimension with all-reduce within the placement group. + + Forward: local sum over reduced_dim, then all_reduce(SUM) over the process + group that shards that dimension. Backward: gradient is replicated along + reduced_dim (evenly to all shards). + + Parameters + ---------- + x_local : torch.Tensor + Local shard of the DTensor (i.e. ``dtensor.to_local()``). Must have + the same shape and layout as the local piece implied by the global + DTensor shape and input_placements on this rank. + reduced_dim : int + Dimension to reduce (0-based). Must be one of the dimensions that has + ``Shard(d)`` in input_placements; the all-reduce runs over the process + group for that placement. + input_placements : tuple[object, ...] + Placements of the original DTensor (e.g. ``(Shard(0), Shard(1), Replicate())``). + Used to select which mesh dimension to all-reduce and to iterate + ``device_mesh.get_all_groups()``. + device_mesh : DeviceMesh + The DTensor device mesh. Must match the mesh used to shard the tensor + that produced x_local. + + Returns + ------- + torch.Tensor + Local shard of the sum result. The dimension reduced_dim is removed; + on ranks that shard that dimension, the local chunk is the same (replicated). + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, + x_local: torch.Tensor, + reduced_dim: int, + input_placements: tuple[object, ...], + device_mesh, + ) -> torch.Tensor: + """Sum over reduced_dim and all-reduce across ranks that shard it.""" + output_local = torch.sum(x_local, dim=reduced_dim, keepdim=False) + for placement, placement_group in zip(input_placements, device_mesh.get_all_groups()): + if isinstance(placement, Shard) and placement.dim == reduced_dim: + torch.distributed.all_reduce( + output_local, + op=torch.distributed.ReduceOp.SUM, + group=placement_group, + ) + ctx.input_local_shape = x_local.shape + ctx.reduced_dim = reduced_dim + return output_local + + @staticmethod + def backward(ctx, grad_output_local: torch.Tensor) -> tuple[torch.Tensor | None, None, None, None]: + """Replicate grad_output along reduced_dim for upstream.""" + dx_local = grad_output_local.unsqueeze(ctx.reduced_dim) + dx_local = dx_local.expand(ctx.input_local_shape).clone(memory_format=torch.contiguous_format) + return dx_local, None, None, None + + +class _LocalShardedMax(torch.autograd.Function): + """Max over a sharded dimension with all-reduce within the placement group. + + Forward: local max over reduced_dim, then all_reduce(MAX) over the process + group that shards that dimension. Backward: gradient flows only to the + local argmax elements. + + Parameters + ---------- + x_local : torch.Tensor + Local shard of the DTensor. Same semantics as _LocalShardedSum: must + be the local piece implied by the global shape and input_placements. + reduced_dim : int + Dimension to reduce. Must correspond to a Shard in input_placements. + input_placements : tuple[object, ...] + Placements of the original DTensor; used to find the group for all_reduce(MAX). + device_mesh : DeviceMesh + Device mesh for the DTensor. + + Returns + ------- + torch.Tensor + Local shard of the max result (reduced_dim removed). After all-reduce, + all ranks in the group that shard reduced_dim hold the same values. + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, + x_local: torch.Tensor, + reduced_dim: int, + input_placements: tuple[object, ...], + device_mesh, + ) -> torch.Tensor: + """Max over reduced_dim and all-reduce across ranks that shard it.""" + output_local_keepdim = torch.amax(x_local, dim=reduced_dim, keepdim=True) + for placement, placement_group in zip(input_placements, device_mesh.get_all_groups()): + if isinstance(placement, Shard) and placement.dim == reduced_dim: + torch.distributed.all_reduce( + output_local_keepdim, + op=torch.distributed.ReduceOp.MAX, + group=placement_group, + ) + ctx.reduced_dim = reduced_dim + ctx.save_for_backward(x_local, output_local_keepdim) + return output_local_keepdim.squeeze(reduced_dim) + + @staticmethod + def backward(ctx, grad_output_local: torch.Tensor) -> tuple[torch.Tensor | None, None, None, None]: + """Route gradient only to elements that matched the max.""" + x_local, output_local_keepdim = ctx.saved_tensors + grad_output_local = grad_output_local.unsqueeze(ctx.reduced_dim) + mask = x_local == output_local_keepdim + dx_local = grad_output_local * mask + dx_local = dx_local.expand(x_local.shape).clone(memory_format=torch.contiguous_format) + return dx_local, None, None, None + + +def _reduced_placements( + input_placements: tuple[object, ...], input_shape: torch.Size, reduced_dim: int +) -> tuple[object, ...]: + """Placements for a tensor after reducing (removing) one dimension. + + Used to build the output DTensor layout when a reduction (e.g. sum or max) + over reduced_dim is performed: the reduced dimension is squeezed out, so + placements must be updated accordingly. + + Parameters + ---------- + input_placements : tuple[object, ...] + Placements of the tensor before reduction (e.g. (Shard(0), Shard(1), Replicate())). + input_shape : torch.Size + Shape of the tensor before reduction. Used to get ndim and to map + placement dim indices after removing reduced_dim. + reduced_dim : int + The dimension that was reduced (0-based). This dimension is removed + in the output. + + Returns + ------- + tuple[object, ...] + Placements for the reduced tensor. The placement that was Shard(reduced_dim) + becomes Replicate(); any Shard(d) with d > reduced_dim becomes + Shard(d - 1); Replicate() is unchanged. + """ + ndim = len(input_shape) + shift = torch.zeros(ndim, dtype=torch.int64) + shift[reduced_dim] = 1 + map_dims = (torch.arange(ndim, dtype=torch.int64) - shift.cumsum(0)).tolist() + output_placements = [] + for placement in input_placements: + if isinstance(placement, Shard) and placement.dim == reduced_dim: + output_placements.append(Replicate()) + elif isinstance(placement, Shard): + output_placements.append(Shard(map_dims[placement.dim])) + elif isinstance(placement, Replicate): + output_placements.append(placement) + return tuple(output_placements) + + +def _reduced_shape_stride( + input_shape: torch.Size, input_stride: tuple[int, ...], reduced_dim: int +) -> tuple[torch.Size, tuple[int, ...]]: + """Shape and stride for a tensor after removing (squeezing) one dimension. + + Used together with _reduced_placements to construct the correct shape and + stride for DTensor.from_local() when building a reduced output (e.g. after + sum/max over reduced_dim). + + Parameters + ---------- + input_shape : torch.Size + Shape before reduction. + input_stride : tuple[int, ...] + Stride of the tensor before reduction (must match input_shape layout). + reduced_dim : int + Dimension that was reduced and is to be removed (0-based). + + Returns + ------- + tuple[torch.Size, tuple[int, ...]] + output_shape: input_shape with reduced_dim removed (length ndim - 1). + output_stride: strides with the reduced_dim entry removed, consistent + with the new shape. + """ + shape_output = list(input_shape) + shape_output[reduced_dim] = 1 + shape_output = tuple(shape_output) + strides_output = update_exhaustive_strides(input_shape, input_stride, shape_output) + shape_output = tuple(dim for i, dim in enumerate(shape_output) if i != reduced_dim) + strides_output = tuple(stride for i, stride in enumerate(strides_output) if i != reduced_dim) + return torch.Size(shape_output), strides_output + + +class _ComputePtmsImpl(torch.autograd.Function): + """Fused pTM/ipTM computation from PAE logits and token/chain masks. + + Computes pTM, ipTM, ligand ipTM, protein ipTM, and per-chain-pair ipTM + using distributed outer ops (TransposeComm) and local sharded sum/max + (_LocalShardedSum, _LocalShardedMax). Aggregation is performed only within + each CP group so that results match serial semantics when the serial + reference is run on the same DP chunk. + + See Also + -------- + compute_ptms : Public API that builds masks and calls this. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( # type: ignore[override] + ctx: FunctionCtx, + mask_collinear_pred: DTensor, + token_pad_mask: DTensor, + asym_id_base: DTensor, + mol_type: DTensor, + logits: DTensor, + multiplicity: int, + transpose_comm: TransposeComm, + ) -> tuple[DTensor, DTensor, DTensor, DTensor, dict[int, dict[int, DTensor]]]: + """Compute pTM, ipTM, ligand/protein ipTM, and chain_pair_iptm from local shards. + + Parameters + ---------- + mask_collinear_pred : DTensor + Mask of shape (B, mult, N_token) or (B*mult, N_token), True where + the token is valid for PTM (non-collinear frame). Typically + (Shard(0), Shard(1), Replicate()). Must be on the same device_mesh + as other inputs. + token_pad_mask : DTensor + Token padding mask, shape (B, N_token). True = valid token. + Placements (Shard(0), Shard(1), Replicate()). Repeated internally + by multiplicity for the computation. + asym_id_base : DTensor + Chain ID per token, shape (B, N_token). Used to build inter-chain + pair masks for ipTM and chain_pair_iptm. Same placements as + token_pad_mask. Unique values are gathered within cp_axis_0 to form + chain_pair_iptm keys. + mol_type : DTensor + Token molecule type (e.g. PROTEIN, NONPOLYMER), shape (B, N_token). + Used for ligand_iptm and protein_iptm masks. Same mesh and + placement convention as token_pad_mask. + logits : DTensor + PAE logits, shape (B*mult, N_token, N_token, num_bins). Placements + (Shard(0), Shard(1), Shard(2)). Bins dimension must be replicated. + multiplicity : int + Number of diffusion samples per batch element. Batch dimension + of logits is B * multiplicity. + transpose_comm : TransposeComm + Communication helper for the CP subgroup; used by distributed_outer_op + to build full pair masks from sharded rows/columns. + + Returns + ------- + tuple[DTensor, DTensor, DTensor, DTensor, dict[int, dict[int, DTensor]]] + ptm, iptm, ligand_iptm, protein_iptm (each shape (B,) DTensors), and + chain_pair_iptm mapping (idx1, idx2) -> DTensor of shape (B,). + Dict keys are the union of chain IDs across all ranks (world-level), + so all DP ranks have identical key sets. Entries where a chain pair + does not exist on this DP rank's batch are filled with sentinel + value -1.0. + + Requirements + ------------ + - All DTensor inputs must use the same device_mesh. + - token_pad_mask, asym_id_base, mol_type must have the same placements + (typically (Shard(0), Shard(1), Replicate())). + - logits placements must have the last dimension (bins) replicated. + + Raises + ------ + TypeError + If any of mask_collinear_pred, token_pad_mask, asym_id_base, + mol_type, or logits is not a DTensor. + ValueError + If any input's device_mesh differs from token_pad_mask.device_mesh, + or if transpose_comm is invalid for the current mesh. + """ + if not isinstance(mask_collinear_pred, DTensor): + raise TypeError(f"Expected DTensor for mask_collinear_pred, got {type(mask_collinear_pred)}") + if not isinstance(token_pad_mask, DTensor): + raise TypeError(f"Expected DTensor for token_pad_mask, got {type(token_pad_mask)}") + if not isinstance(asym_id_base, DTensor): + raise TypeError(f"Expected DTensor for asym_id, got {type(asym_id_base)}") + if not isinstance(mol_type, DTensor): + raise TypeError(f"Expected DTensor for mol_type, got {type(mol_type)}") + if not isinstance(logits, DTensor): + raise TypeError(f"Expected DTensor for logits, got {type(logits)}") + + device_mesh = token_pad_mask.device_mesh + if mask_collinear_pred.device_mesh != device_mesh: + raise ValueError( + "mask_collinear_pred must be on the same device mesh as token_pad_mask, " + f"got {mask_collinear_pred.device_mesh} and {device_mesh}" + ) + if asym_id_base.device_mesh != device_mesh: + raise ValueError( + f"asym_id must be on the same device mesh as token_pad_mask, got {asym_id_base.device_mesh} and {device_mesh}" + ) + if mol_type.device_mesh != device_mesh: + raise ValueError( + f"mol_type must be on the same device mesh as token_pad_mask, got {mol_type.device_mesh} and {device_mesh}" + ) + if logits.device_mesh != device_mesh: + raise ValueError( + f"logits must be on the same device mesh as token_pad_mask, got {logits.device_mesh} and {device_mesh}" + ) + + group_replicate = device_mesh.get_group("cp_axis_1") + + maski_local = mask_collinear_pred.to_local().bool() + maski_local = maski_local.reshape(-1, maski_local.shape[-1]) + mask_pad_local = token_pad_mask.to_local().bool().repeat_interleave(multiplicity, dim=0) + asym_id_local = asym_id_base.to_local().repeat_interleave(multiplicity, dim=0) + + pair_mask_row_local = mask_pad_local & maski_local + pair_mask_ptm_local = distributed_outer_op( + pair_mask_row_local, + op=OuterOp.BITAND, + axis=1, + input_t=mask_pad_local, + transpose_comm=transpose_comm, + group_replicate=group_replicate, + ) + pair_mask_iptm_equal_local = distributed_outer_op( + asym_id_local, + op=OuterOp.EQUAL, + axis=1, + transpose_comm=transpose_comm, + group_replicate=group_replicate, + ) + pair_mask_iptm_local = pair_mask_ptm_local & (~pair_mask_iptm_equal_local) + + token_type_local = mol_type.to_local().repeat_interleave(multiplicity, dim=0) + is_ligand_token_local = token_type_local == const.chain_type_ids["NONPOLYMER"] + is_protein_token_local = token_type_local == const.chain_type_ids["PROTEIN"] + ligand_iptm_mask_row_local = distributed_outer_op( + is_ligand_token_local.bool(), + op=OuterOp.BITAND, + axis=1, + transpose_comm=transpose_comm, + group_replicate=group_replicate, + input_t=is_protein_token_local, + ) + ligand_iptm_mask_col_local = distributed_outer_op( + is_protein_token_local.bool(), + op=OuterOp.BITAND, + axis=1, + transpose_comm=transpose_comm, + group_replicate=group_replicate, + input_t=is_ligand_token_local, + ) + ligand_iptm_mask_local = (ligand_iptm_mask_row_local | ligand_iptm_mask_col_local) & pair_mask_iptm_local + + protein_iptm_mask_local = distributed_outer_op( + is_protein_token_local.bool(), + op=OuterOp.BITAND, + axis=1, + transpose_comm=transpose_comm, + group_replicate=group_replicate, + ) + protein_iptm_mask_local = protein_iptm_mask_local & pair_mask_iptm_local + + reduced_dim = token_pad_mask.ndim - 1 + n_res_local = _LocalShardedSum.apply( + mask_pad_local, + reduced_dim, + token_pad_mask.placements, + device_mesh, + ).unsqueeze(reduced_dim) + logits_local = logits.to_local().detach() + n_res_local = n_res_local.detach() + + num_bins = logits_local.shape[-1] + bin_width = 32.0 / num_bins + # Use at least float32 for bin centers to match serial code's default-dtype torch.arange + compute_dtype = torch.promote_types(logits_local.dtype, torch.float32) + pae_value = (torch.arange(num_bins, device=logits_local.device, dtype=compute_dtype) + 0.5) * bin_width + pae_value = pae_value.unsqueeze(0) + tm_value = serial_tm_function(pae_value, n_res_local).unsqueeze(1).unsqueeze(2) + probs = F.softmax(logits_local.to(compute_dtype), dim=-1) + tm_expected_value_local = torch.sum(probs * tm_value, dim=-1) + + reduced_dim = tm_expected_value_local.ndim - 1 + ptm_shape1, ptm_stride1 = _reduced_shape_stride( + logits.shape[:-1], LayoutRightMap(tuple(logits.shape[:-1])).strides, reduced_dim + ) + ptm_placements1 = _reduced_placements(logits.placements, logits.shape[:-1], reduced_dim) + ptm_shape2, ptm_stride2 = _reduced_shape_stride(ptm_shape1, ptm_stride1, 1) + output_placements = _reduced_placements(ptm_placements1, ptm_shape1, 1) + mask_placements = (Shard(0), Shard(1), Shard(2)) + + ptm_mask_local = pair_mask_ptm_local.bool() + ptm_numerator_local = _LocalShardedSum.apply( + tm_expected_value_local.masked_fill(~ptm_mask_local, 0), + reduced_dim, + logits.placements, + device_mesh, + ) + ptm_denominator_local = _LocalShardedSum.apply( + ptm_mask_local, + reduced_dim, + mask_placements, + device_mesh, + ) + ptm_local = ptm_numerator_local / (ptm_denominator_local.to(tm_expected_value_local.dtype) + _EPS) + ptm_local = _LocalShardedMax.apply(ptm_local, 1, ptm_placements1, device_mesh) + + iptm_mask_local = pair_mask_iptm_local.bool() + iptm_numerator_local = _LocalShardedSum.apply( + tm_expected_value_local.masked_fill(~iptm_mask_local, 0), + reduced_dim, + logits.placements, + device_mesh, + ) + iptm_denominator_local = _LocalShardedSum.apply( + iptm_mask_local, + reduced_dim, + mask_placements, + device_mesh, + ) + iptm_local = iptm_numerator_local / (iptm_denominator_local.to(tm_expected_value_local.dtype) + _EPS) + iptm_local = _LocalShardedMax.apply(iptm_local, 1, ptm_placements1, device_mesh) + + ligand_mask_local = ligand_iptm_mask_local.bool() + ligand_num_local = _LocalShardedSum.apply( + tm_expected_value_local.masked_fill(~ligand_mask_local, 0), + reduced_dim, + logits.placements, + device_mesh, + ) + ligand_den_local = _LocalShardedSum.apply( + ligand_mask_local, + reduced_dim, + mask_placements, + device_mesh, + ) + ligand_local = ligand_num_local / (ligand_den_local.to(tm_expected_value_local.dtype) + _EPS) + ligand_local = _LocalShardedMax.apply(ligand_local, 1, ptm_placements1, device_mesh) + + protein_mask_local = protein_iptm_mask_local.bool() + protein_num_local = _LocalShardedSum.apply( + tm_expected_value_local.masked_fill(~protein_mask_local, 0), + reduced_dim, + logits.placements, + device_mesh, + ) + protein_den_local = _LocalShardedSum.apply( + protein_mask_local, + reduced_dim, + mask_placements, + device_mesh, + ) + protein_local = protein_num_local / (protein_den_local.to(tm_expected_value_local.dtype) + _EPS) + protein_local = _LocalShardedMax.apply(protein_local, 1, ptm_placements1, device_mesh) + + chain_pair_iptm: dict[int, dict[int, DTensor]] = {} + local_asym_ids = set(torch.unique(asym_id_local).tolist()) + cp_axis_0_group = device_mesh.get_group("cp_axis_0") + cp_obj_list = [None] * torch.distributed.get_world_size(group=cp_axis_0_group) + torch.distributed.all_gather_object(cp_obj_list, local_asym_ids, group=cp_axis_0_group) + cp_asym_ids = set().union(*cp_obj_list) + + dp_group = device_mesh.get_group("dp") + dp_obj_list = [None] * torch.distributed.get_world_size(group=dp_group) + torch.distributed.all_gather_object(dp_obj_list, cp_asym_ids, group=dp_group) + world_asym_ids_list = sorted(set().union(*dp_obj_list)) + for idx1 in world_asym_ids_list: + chain_iptm: dict[int, DTensor] = {} + for idx2 in world_asym_ids_list: + if idx1 not in cp_asym_ids or idx2 not in cp_asym_ids: + iptm_chain_local = torch.full( + (mask_pad_local.size(0),), + CHAIN_IPTM_SENTINEL, + device=mask_pad_local.device, + dtype=tm_expected_value_local.dtype, + ) + else: + mask_pair_chain_row = maski_local & (asym_id_local == idx2) & mask_pad_local + mask_pair_chain_col = (asym_id_local == idx1) & mask_pad_local + mask_pair_chain_local = distributed_outer_op( + mask_pair_chain_row, + op=OuterOp.BITAND, + axis=1, + transpose_comm=transpose_comm, + group_replicate=group_replicate, + input_t=mask_pair_chain_col, + ) + mask_pair_chain_local = mask_pair_chain_local.bool() + numerator_local = _LocalShardedSum.apply( + tm_expected_value_local.masked_fill(~mask_pair_chain_local, 0), + reduced_dim, + logits.placements, + device_mesh, + ) + denominator_local = _LocalShardedSum.apply( + mask_pair_chain_local, + reduced_dim, + mask_placements, + device_mesh, + ) + iptm_chain_local = numerator_local / (denominator_local.to(tm_expected_value_local.dtype) + _EPS) + iptm_chain_local = _LocalShardedMax.apply(iptm_chain_local, 1, ptm_placements1, device_mesh) + + chain_iptm[idx2] = DTensor.from_local( + iptm_chain_local, + device_mesh=device_mesh, + placements=output_placements, + shape=ptm_shape2, + stride=ptm_stride2, + ) + chain_pair_iptm[idx1] = chain_iptm + + ptm = DTensor.from_local( + ptm_local, + device_mesh=device_mesh, + placements=output_placements, + shape=ptm_shape2, + stride=ptm_stride2, + ) + iptm = DTensor.from_local( + iptm_local, + device_mesh=device_mesh, + placements=output_placements, + shape=ptm_shape2, + stride=ptm_stride2, + ) + ligand_iptm = DTensor.from_local( + ligand_local, + device_mesh=device_mesh, + placements=output_placements, + shape=ptm_shape2, + stride=ptm_stride2, + ) + protein_iptm = DTensor.from_local( + protein_local, + device_mesh=device_mesh, + placements=output_placements, + shape=ptm_shape2, + stride=ptm_stride2, + ) + + return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward( # type: ignore[override] + ctx: FunctionCtx, + grad_ptm: DTensor, + grad_iptm: DTensor, + grad_ligand_iptm: DTensor, + grad_protein_iptm: DTensor, + grad_chain_pair: dict[int, dict[int, DTensor]], + ) -> tuple[None, None, None, None, None, None, None]: + """Backward is not supported; returns None for all inputs.""" + return None, None, None, None, None, None, None + + +def compute_ptms( + logits: DTensor, + x_preds: DTensor, + feats: dict[str, DTensor], + multiplicity: int, + transpose_comm: TransposeComm, +) -> tuple[DTensor, DTensor, DTensor, DTensor, dict[int, dict[int, DTensor]]]: + """Compute pTM and ipTM scores for DTensor inputs. + + This redistributes PAE logits and token-level features to replicated placements + to compute global reductions (max/sum) across tokens and chains. + + Args: + logits: DTensor of shape (batch * multiplicity, num_tokens, num_tokens, num_bins) + with placements (Shard(0), Shard(1), Shard(2)) containing PAE prediction logits. + x_preds: DTensor of shape (batch * multiplicity, num_atoms, 3) with placements + (Shard(0), Shard(1), Replicate()) containing predicted atom coordinates. + feats: DTensor feature dict with required keys (frames_idx, asym_id, atom_to_token, + atom_pad_mask, atom_resolved_mask, mol_type, token_pad_mask). + multiplicity: Number of copies per sample in the batch dimension. + transpose_comm: TransposeComm object for distributed outer operations. + + Returns: + Tuple containing: + - ptm: DTensor of shape (batch,) with confidence scores for predicted templates + - iptm: DTensor of shape (batch,) with interface confidence scores + - ligand_iptm: DTensor of shape (batch,) with ligand-protein interface scores + - protein_iptm: DTensor of shape (batch,) with protein-protein interface scores + - chain_pair_iptm: Dict mapping chain pairs to their interface confidence + DTensors. Keys are the world-level union of chain IDs (homogeneous across + all DP ranks). Entries where a chain pair does not exist on this DP rank's + batch are filled with sentinel value CHAIN_IPTM_SENTINEL (-1.0). + """ + feats_keys = { + "frames_idx", + "asym_id", + "atom_to_token", + "atom_pad_mask", + "atom_resolved_mask", + "mol_type", + "token_pad_mask", + } + if any(k not in feats for k in feats_keys): + raise ValueError(f"feats must contain the following keys: {feats_keys}, got {feats.keys()}") + if not isinstance(logits, DTensor): + raise TypeError(f"Expected DTensor for logits, got {type(logits)}") + if not isinstance(x_preds, DTensor): + raise TypeError(f"Expected DTensor for x_preds, got {type(x_preds)}") + if feats["frames_idx"].ndim == 4: + raise ValueError( + f"frames_idx has unsqueezed ensemble dim (ndim=4, shape={feats['frames_idx'].shape}). " + "Only E=1 is supported; squeeze the ensemble dim before calling compute_ptms." + ) + + device_mesh = logits.device_mesh + if x_preds.device_mesh != device_mesh: + raise ValueError( + f"x_preds must be on the same device mesh as logits, got {x_preds.device_mesh} and {device_mesh}" + ) + for key in feats_keys: + if feats[key].device_mesh != device_mesh: + raise ValueError( + f"feats[{key}] must be on the same device mesh as logits, got {feats[key].device_mesh} and {device_mesh}" + ) + + _, mask_collinear_pred = compute_frame_pred( + x_preds, + feats["frames_idx"], + feats, + multiplicity, + inference=True, + ) + + ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm = _ComputePtmsImpl.apply( + mask_collinear_pred, + feats["token_pad_mask"], + feats["asym_id"], + feats["mol_type"], + logits, + multiplicity, + transpose_comm, + ) + + return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm diff --git a/src/boltz/distributed/model/modules/confidencev2.py b/src/boltz/distributed/model/modules/confidencev2.py new file mode 100644 index 000000000..0a85c6277 --- /dev/null +++ b/src/boltz/distributed/model/modules/confidencev2.py @@ -0,0 +1,886 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""DTensor-based v2 confidence module and heads. + +This module provides the distributed implementation of the Boltz-2 ConfidenceModule +and ConfidenceHeads, including support for separate intra/inter-chain heads, +token-level pLDDT confidence, and updated iPLDDT weighting. + +Only ``token_level_confidence=True`` is supported. The atom-level path +(``token_level_confidence=False``) raises ``NotImplementedError``. + +Placement conventions (3-D device mesh: [dp, cp_axis_0, cp_axis_1]): + s: (Shard(0), Shard(1), Replicate()) — single representation + z: (Shard(0), Shard(1), Shard(2)) — pair representation + d: (Shard(0), Shard(1), Shard(2)) — distance matrix + x_pred:(Shard(0), Shard(1), Replicate()) — predicted coords + scalar metrics: (Shard(0), Replicate(), Replicate()) +""" + +import warnings +from copy import deepcopy + +import torch +from torch import nn +from torch.autograd.function import FunctionCtx +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard + +from boltz.data import const +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.atom_to_token import single_repr_rep_atom_to_token +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op +from boltz.distributed.model.layers.embedding import EmbeddingParamsReplicated +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.outer_op import OuterOp, replicate_to_shard_outer_op +from boltz.distributed.model.layers.pairformer import PairformerModule +from boltz.distributed.model.layers.redistribute_transpose import redistribute_transpose +from boltz.distributed.model.layers.repeat_interleave import shardwise_repeat_interleave +from boltz.distributed.model.layers.shardwise_op import shardwise_distogram +from boltz.distributed.model.modules.confidence_utils import ( + compute_aggregated_metric, + compute_ptms, +) +from boltz.distributed.model.modules.encoders import RelativePositionEncoder +from boltz.distributed.model.modules.trunkv2 import ContactConditioning +from boltz.distributed.utils import update_exhaustive_strides +from boltz.model.modules.confidencev2 import ( + IPLDDT_INTERFACE_WEIGHT, + IPLDDT_LIGAND_WEIGHT, + IPLDDT_NON_INTERFACE_WEIGHT, +) +from boltz.model.modules.confidencev2 import ConfidenceHeads as SerialConfidenceHeadsV2 +from boltz.model.modules.confidencev2 import ConfidenceModule as SerialConfidenceModuleV2 + + +class _ShardwiseWhere(torch.autograd.Function): + """Select between two 4-D pair DTensors using a 3-D boolean condition, shardwise. + + Computes ``torch.where(cond[..., None], a, b)`` on local shards. + Gradients flow to *a* where ``cond`` is True and to *b* where False. + + Communication budget: 0 collectives — purely shardwise. + + Parameters (forward) + ---------- + cond_local : Tensor + Plain (non-DTensor) bool tensor of shape ``(B_local*mult, N_row, N_col)``. + a, b : DTensor + Shape ``(B*mult, N, N, D)`` with identical placements (typically + ``(Shard(0), Shard(1), Shard(2))``). + """ + + @staticmethod + def forward(ctx: FunctionCtx, cond_local: torch.Tensor, a: DTensor, b: DTensor) -> DTensor: + if not isinstance(a, DTensor) or not isinstance(b, DTensor): + raise TypeError(f"Expected DTensors for a and b, got {type(a)} and {type(b)}") + if a.device_mesh != b.device_mesh or a.placements != b.placements: + raise ValueError("a and b must share the same device_mesh and placements") + for p in a.placements: + if isinstance(p, Partial): + raise ValueError("Partial placements are not supported") + expected_cond_shape = a.to_local().shape[:3] + if cond_local.shape != expected_cond_shape: + raise ValueError( + f"cond_local shape {tuple(cond_local.shape)} does not match " + f"local a shape prefix {tuple(expected_cond_shape)}" + ) + + a_local = a.to_local() + b_local = b.to_local() + cond_expanded = cond_local.unsqueeze(-1) + result_local = torch.where(cond_expanded, a_local, b_local) + + ctx.save_for_backward(cond_local) + ctx._a_requires_grad = a.requires_grad + ctx._b_requires_grad = b.requires_grad + ctx._device_mesh = a.device_mesh + ctx._placements = a.placements + ctx._shape = a.shape + ctx._stride = a.stride() + + return DTensor.from_local( + result_local, + device_mesh=a.device_mesh, + placements=a.placements, + shape=a.shape, + stride=a.stride(), + ) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: DTensor): + (cond_local,) = ctx.saved_tensors + go_local = grad_output.to_local() + cond_expanded = cond_local.unsqueeze(-1) + zero = torch.zeros_like(go_local) + + d_a = ( + DTensor.from_local( + torch.where(cond_expanded, go_local, zero), + device_mesh=ctx._device_mesh, + placements=ctx._placements, + shape=ctx._shape, + stride=ctx._stride, + ) + if ctx._a_requires_grad + else None + ) + d_b = ( + DTensor.from_local( + torch.where(cond_expanded, zero, go_local), + device_mesh=ctx._device_mesh, + placements=ctx._placements, + shape=ctx._shape, + stride=ctx._stride, + ) + if ctx._b_requires_grad + else None + ) + return None, d_a, d_b + + +class ConfidenceHeads(nn.Module): + """DTensor-based v2 confidence heads. + + Wraps the serial ``ConfidenceHeadsV2`` layer, distributing parameters with + ``LinearParamsReplicated`` and adding sharded metric computation. + + Compared to the v1 distributed ``ConfidenceHeads``: + * PAE is always computed (no ``compute_pae`` flag). + * Optional ``use_separate_heads`` splits PAE/PDE into intra/inter-chain projections. + * iPLDDT weights updated to ``ligand=20, interface=10, non_interface=1``. + * PTM/iPTM always computed (with try/except fallback). + + Only ``token_level_confidence=True`` is supported. Constructing with + ``token_level_confidence=False`` raises ``NotImplementedError``. + """ + + def __init__( + self, + layer: SerialConfidenceHeadsV2, + device_mesh: torch.distributed.device_mesh.DeviceMesh, + transpose_comm: TransposeComm, + ): + super().__init__() + + # token_level_confidence = True is the default setting in the public checkpoint + if not layer.token_level_confidence: + raise NotImplementedError( + "ConfidenceHeads distributed v2 only supports token_level_confidence=True. " + "The atom-level confidence path is not implemented for DTensor." + ) + + self.device_mesh = device_mesh + self.transpose_comm = transpose_comm + self.token_level_confidence = layer.token_level_confidence + self.use_separate_heads = layer.use_separate_heads + + # --- PAE / PDE heads --- + if self.use_separate_heads: + self.to_pae_intra_logits = LinearParamsReplicated(layer.to_pae_intra_logits, device_mesh) + self.to_pae_inter_logits = LinearParamsReplicated(layer.to_pae_inter_logits, device_mesh) + + self.to_pde_intra_logits = LinearParamsReplicated(layer.to_pde_intra_logits, device_mesh) + self.to_pde_inter_logits = LinearParamsReplicated(layer.to_pde_inter_logits, device_mesh) + else: + self.to_pae_logits = LinearParamsReplicated(layer.to_pae_logits, device_mesh) + self.to_pde_logits = LinearParamsReplicated(layer.to_pde_logits, device_mesh) + + # --- pLDDT / resolved heads --- + self.to_plddt_logits = LinearParamsReplicated(layer.to_plddt_logits, device_mesh) + self.to_resolved_logits = LinearParamsReplicated(layer.to_resolved_logits, device_mesh) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward( + self, + s: DTensor, + z: DTensor, + x_pred: DTensor, + d: DTensor, + feats: dict, + pred_distogram_logits: DTensor, + multiplicity: int = 1, + ) -> dict[str, DTensor]: + """Compute confidence logits and aggregated metrics. + + Parameters + ---------- + s : DTensor + Single representation. Shape ``(B*mult, N, D_s)``, + placements ``(Shard(0), Shard(1), Replicate())``. + z : DTensor + Pair representation. Shape ``(B*mult, N, N, D_z)``, + placements ``(Shard(0), Shard(1), Shard(2))``. + x_pred : DTensor + Predicted atom coordinates. Shape ``(B*mult, N_atoms, 3)``, + placements ``(Shard(0), Shard(1), Replicate())``. + d : DTensor + Token-level distance matrix. Shape ``(B*mult, N, N)``, + placements ``(Shard(0), Shard(1), Shard(2))``. + feats : dict[str, DTensor] + Feature dictionary. Required keys: ``token_pad_mask``, ``asym_id``, + ``mol_type``. + pred_distogram_logits : DTensor + Predicted distogram logits. Shape ``(B, N, N, 64)``, + placements ``(Shard(0), Shard(1), Shard(2))``. + multiplicity : int + Number of diffusion samples per input. + + Returns + ------- + dict[str, DTensor] + """ + self._validate_inputs(s, z, x_pred, d, feats, pred_distogram_logits) + + plddt_logits = self.to_plddt_logits(s) + resolved_logits = self.to_resolved_logits(s) + + # Build same_chain mask once; reused by _ShardwiseWhere (when use_separate_heads) + # and the no-grad iPLDDT metrics block (always). + with torch.no_grad(): + same_chain_base = replicate_to_shard_outer_op( + feats["asym_id"], OuterOp.EQUAL, axis=1, transpose_comm=self.transpose_comm + ).to_local() # (B_local, N_row, N_col) + + if self.use_separate_heads: + # M = same_chain mask, A = intra, B = inter → torch.where(M, A, B) + same_chain = same_chain_base.repeat_interleave(multiplicity, dim=0) if multiplicity > 1 else same_chain_base + + pae_logits = _ShardwiseWhere.apply(same_chain, self.to_pae_intra_logits(z), self.to_pae_inter_logits(z)) + + # proj(z + z^T) = proj(z) + proj(z^T) + z_pde_intra = self.to_pde_intra_logits(z) + z_pde_intra_T = redistribute_transpose( + z_pde_intra, self.transpose_comm, (Shard(0), Shard(1), Shard(2)), 1, 2 + ) + pde_intra = elementwise_op(z_pde_intra, z_pde_intra_T, ElementwiseOp.SUM) + + z_pde_inter = self.to_pde_inter_logits(z) + z_pde_inter_T = redistribute_transpose( + z_pde_inter, self.transpose_comm, (Shard(0), Shard(1), Shard(2)), 1, 2 + ) + pde_inter = elementwise_op(z_pde_inter, z_pde_inter_T, ElementwiseOp.SUM) + + pde_logits = _ShardwiseWhere.apply(same_chain, pde_intra, pde_inter) + else: + # Original path from boltz1 + pae_logits = self.to_pae_logits(z) + + z_proj = self.to_pde_logits(z) + z_proj_T = redistribute_transpose(z_proj, self.transpose_comm, (Shard(0), Shard(1), Shard(2)), 1, 2) + pde_logits = elementwise_op(z_proj, z_proj_T, ElementwiseOp.SUM) + + out_dict: dict[str, DTensor] = { + "plddt_logits": plddt_logits, + "pde_logits": pde_logits, + "resolved_logits": resolved_logits, + "pae_logits": pae_logits, + } + + # ================================================================== + # No-grad aggregated metrics (inference / logging only) + # ================================================================== + with torch.no_grad(): + token_pad_mask = feats["token_pad_mask"] + mask_local = token_pad_mask.to_local() # (B_local, N_local) + B_local = mask_local.shape[0] + N_local = mask_local.shape[1] + + # ---- pLDDT ---- + plddt = compute_aggregated_metric(plddt_logits) + plddt_local = plddt.to_local() # (B_local*mult, N_local) + plddt_reshaped = plddt_local.reshape(B_local, multiplicity, N_local) + + masked_plddt = plddt_reshaped * mask_local.unsqueeze(1) + num_local = masked_plddt.sum(dim=-1) # (B_local, mult) + den_local = mask_local.sum(dim=-1, keepdim=True) # (B_local, 1) + + group_cp0 = self.device_mesh.get_group("cp_axis_0") + torch.distributed.all_reduce(num_local, op=torch.distributed.ReduceOp.SUM, group=group_cp0) + torch.distributed.all_reduce(den_local, op=torch.distributed.ReduceOp.SUM, group=group_cp0) + + complex_plddt_local = (num_local / den_local).reshape(B_local * multiplicity) + complex_plddt = DTensor.from_local( + complex_plddt_local, + device_mesh=self.device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + shape=(plddt.shape[0],), + stride=(1,), + ) + + # ---- iPLDDT (v2 weights: ligand=20, interface=10, non_interface=1) ---- + ligand_weight = IPLDDT_LIGAND_WEIGHT + interface_weight = IPLDDT_INTERFACE_WEIGHT + non_interface_weight = IPLDDT_NON_INTERFACE_WEIGHT + + mol_type_local = feats["mol_type"].to_local() # (B_local, N_local) + is_ligand_local = (mol_type_local == const.chain_type_ids["NONPOLYMER"]).float() + + d_local = d.to_local() # (B_local*mult, N_row, N_col) + is_contact_local = (d_local < 8).float() + + is_diff_chain_local = (~same_chain_base).float() + + # NOTE: because we use a square grid for now, N_row == N_col + N_row = d_local.shape[1] + N_col = d_local.shape[2] + is_contact_4d = is_contact_local.reshape(B_local, multiplicity, N_row, N_col) + is_diff_chain_4d = is_diff_chain_local.unsqueeze(1) + non_ligand_4d = (1 - is_ligand_local).unsqueeze(1).unsqueeze(-1) + + interface_product = is_contact_4d * is_diff_chain_4d * non_ligand_4d + token_interface_mask_local = interface_product.max(dim=-1).values # (B_local, mult, N_row) + + group_cp1 = self.device_mesh.get_group("cp_axis_1") + torch.distributed.all_reduce( + token_interface_mask_local, + op=torch.distributed.ReduceOp.MAX, + group=group_cp1, + ) + + is_ligand_3d = is_ligand_local.unsqueeze(1) # (B_local, 1, N_local) + token_non_interface_mask = (1 - token_interface_mask_local) * (1 - is_ligand_3d) + iplddt_weight_local = ( + is_ligand_3d * ligand_weight + + token_interface_mask_local * interface_weight + + token_non_interface_mask * non_interface_weight + ) # (B_local, mult, N_local) + + masked_iplddt_w = mask_local.unsqueeze(1) * iplddt_weight_local + num_iplddt = (plddt_reshaped * masked_iplddt_w).sum(dim=-1) + den_iplddt = masked_iplddt_w.sum(dim=-1) + + torch.distributed.all_reduce(num_iplddt, op=torch.distributed.ReduceOp.SUM, group=group_cp0) + torch.distributed.all_reduce(den_iplddt, op=torch.distributed.ReduceOp.SUM, group=group_cp0) + + complex_iplddt_local = (num_iplddt / den_iplddt).reshape(B_local * multiplicity) + complex_iplddt = DTensor.from_local( + complex_iplddt_local, + device_mesh=self.device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + shape=(plddt.shape[0],), + stride=(1,), + ) + + # ---- PDE / iPDE ---- + pde = compute_aggregated_metric(pde_logits, end=32) + + pred_disto_local = pred_distogram_logits.to_local() + if pred_disto_local.ndim == 5: + if pred_disto_local.shape[-2] != 1: + raise ValueError( + f"ConfidenceHeads expects num_distograms=1, " f"got shape {pred_disto_local.shape}" + ) + pred_disto_local = pred_disto_local.squeeze(-2) + pred_disto_prob = torch.softmax(pred_disto_local, dim=-1) + contacts_mask = torch.zeros((1, 1, 1, 64), dtype=pred_disto_prob.dtype, device=pred_disto_prob.device) + contacts_mask[:, :, :, :20] = 1.0 + prob_contact_local = (pred_disto_prob.unsqueeze(1) * contacts_mask).sum(-1) # (B_local, 1, N_row, N_col) + + pde_local = pde.to_local().reshape(B_local, multiplicity, N_local, N_local) + + row_mask = mask_local # (B_local, N_local) + col_mask = redistribute_transpose( + token_pad_mask, + self.transpose_comm, + (Shard(0), Replicate(), Shard(1)), + dim0=None, + dim1=None, + ).to_local() + + prob_contact_local = prob_contact_local * row_mask[:, None, :, None] * col_mask[:, None, None, :] + + mesh_coord = self.device_mesh.get_coordinate() + if mesh_coord[1] == mesh_coord[2]: + diag_idx = torch.arange(0, N_local, device=prob_contact_local.device) + prob_contact_local[:, :, diag_idx, diag_idx] = 0 + + num_pde = (pde_local * prob_contact_local).sum(dim=(2, 3)) + den_pde = prob_contact_local.sum(dim=(2, 3)) + + torch.distributed.all_reduce(num_pde, op=torch.distributed.ReduceOp.SUM, group=self.transpose_comm.group) + torch.distributed.all_reduce(den_pde, op=torch.distributed.ReduceOp.SUM, group=self.transpose_comm.group) + + complex_pde_local = (num_pde / den_pde).reshape(B_local * multiplicity) + complex_pde = DTensor.from_local( + complex_pde_local, + device_mesh=self.device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + shape=(pde.shape[0],), + stride=(1,), + ) + + # iPDE + token_intf_pair = prob_contact_local * is_diff_chain_local.unsqueeze(1) + num_ipde = (pde_local * token_intf_pair).sum(dim=(2, 3)) + den_ipde = token_intf_pair.sum(dim=(2, 3)) + + torch.distributed.all_reduce(num_ipde, op=torch.distributed.ReduceOp.SUM, group=self.transpose_comm.group) + torch.distributed.all_reduce(den_ipde, op=torch.distributed.ReduceOp.SUM, group=self.transpose_comm.group) + + complex_ipde_local = (num_ipde / (den_ipde + 1e-5)).reshape(B_local * multiplicity) + complex_ipde = DTensor.from_local( + complex_ipde_local, + device_mesh=self.device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + shape=(pde.shape[0],), + stride=(1,), + ) + + # ---- PAE ---- + pae = compute_aggregated_metric(pae_logits, end=32) + + out_dict["plddt"] = plddt + out_dict["pde"] = pde + out_dict["pae"] = pae + out_dict["complex_plddt"] = complex_plddt + out_dict["complex_iplddt"] = complex_iplddt + out_dict["complex_pde"] = complex_pde + out_dict["complex_ipde"] = complex_ipde + + # --- PTM / iPTM --- + # No try-except here: the serial v2 code has a broad `except Exception` + # fallback that silently replaces PTM scores with zeros. We intentionally + # omit it in the distributed path so that any error surfaces immediately. + # If the serial path's fallback ever triggers, an equivalence test will + # catch the mismatch (serial zeros vs distributed crash). + ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms( + pae_logits, + x_pred, + feats, + multiplicity, + self.transpose_comm, + ) + out_dict["ptm"] = ptm + out_dict["iptm"] = iptm + out_dict["ligand_iptm"] = ligand_iptm + out_dict["protein_iptm"] = protein_iptm + out_dict["pair_chains_iptm"] = pair_chains_iptm + + return out_dict + + # ------------------------------------------------------------------ + # Input validation + # ------------------------------------------------------------------ + + def _validate_inputs( + self, + s: DTensor, + z: DTensor, + x_pred: DTensor, + d: DTensor, + feats: dict, + pred_distogram_logits: DTensor, + ) -> None: + for name, tensor in [("s", s), ("z", z), ("x_pred", x_pred), ("d", d)]: + if not isinstance(tensor, DTensor): + raise TypeError(f"Expected DTensor for {name}, got {type(tensor)}") + + expected = { + "s": (Shard(0), Shard(1), Replicate()), + "z": (Shard(0), Shard(1), Shard(2)), + "x_pred": (Shard(0), Shard(1), Replicate()), + "d": (Shard(0), Shard(1), Shard(2)), + } + for name, tensor in [("s", s), ("z", z), ("x_pred", x_pred), ("d", d)]: + if tensor.placements != expected[name]: + raise ValueError(f"Expected {name} placements {expected[name]}, got {tensor.placements}") + + if pred_distogram_logits.placements != (Shard(0), Shard(1), Shard(2)): + raise ValueError( + f"Expected pred_distogram_logits placements (Shard(0), Shard(1), Shard(2)), " + f"got {pred_distogram_logits.placements}" + ) + + for key in ("token_pad_mask", "asym_id", "mol_type"): + feat = feats[key] + if not isinstance(feat, DTensor): + raise TypeError(f"Expected DTensor for feats['{key}'], got {type(feat)}") + if feat.placements != (Shard(0), Shard(1), Replicate()): + raise ValueError( + f"Expected feats['{key}'] placements (Shard(0), Shard(1), Replicate()), " f"got {feat.placements}" + ) + + N_global = feats["token_pad_mask"].shape[1] + if s.shape[0] != z.shape[0]: + raise ValueError(f"Batch dim mismatch: s.shape[0]={s.shape[0]} vs z.shape[0]={z.shape[0]}") + if s.shape[1] != N_global: + raise ValueError(f"Token dim mismatch: s.shape[1]={s.shape[1]} vs N_global={N_global}") + if z.shape[1] != N_global or z.shape[2] != N_global: + raise ValueError( + f"Pair dims must equal N_global={N_global}, got z.shape[1]={z.shape[1]}, z.shape[2]={z.shape[2]}" + ) + + +class ConfidenceModule(nn.Module): + """Distributed ConfidenceModule v2 (Algorithm 31). + + Wraps the serial :class:`~boltz.model.modules.confidencev2.ConfidenceModule`, + distributing submodule parameters with DTensor-compatible layers and using + sharded operations for pair computations. + + The forward pass: + + 1. Normalize ``s_inputs``, ``s``, ``z`` + 2. Optional ``s`` / ``z`` conditioning (``add_s_input_to_s``, ``add_z_input_to_z``) + 3. ``repeat_interleave`` s for multiplicity + 4. Outer-sum ``s_to_z`` pair update + 5. Optional: ``add_s_to_z_prod`` + 6. Distogram chain: representative-atom projection → pairwise cdist → + binning → embedding + 7. ``repeat_interleave`` z for multiplicity + 8. Pairformer stack + 9. :class:`ConfidenceHeads` for logit projections and aggregated metrics + + Only ``token_level_confidence=True`` is supported. + + Communication budget (forward only): + + - Norms / linears: no collectives (params are Replicate) + - ``replicate_to_shard_outer_op``: 1 all-to-all per call + - ``single_repr_rep_atom_to_token``: shardwise (0 collectives) + - ``replicate_to_shard_outer_op(CDIST)``: 1 all-to-all + - ``shardwise_distogram``: 0 collectives + - ``PairformerModule``: O(depth) collectives (ring attention + triangle) + - ``ConfidenceHeads``: O(1) all-reduces for aggregated metrics + + Parameters + ---------- + module : SerialConfidenceModuleV2 + Initialised serial module whose weights are wrapped / transferred. + dist_manager : DistributedManager + Distributed manager with the 3-D device mesh (dp, cp_axis_0, cp_axis_1). + transpose_comm : TransposeComm + Base transpose-communication handle. Deep copies are created + internally for submodules that store their own handle. + """ + + def __init__( + self, + module: SerialConfidenceModuleV2, + dist_manager: DistributedManager, + transpose_comm: TransposeComm, + ) -> None: + super().__init__() + + if not module.token_level_confidence: + raise NotImplementedError( + "ConfidenceModule distributed v2 only supports token_level_confidence=True. " + "The atom-level confidence path is not implemented for DTensor." + ) + + self.device_mesh = dist_manager.device_mesh_subgroups + self.transpose_comm = transpose_comm + + self.no_update_s = module.no_update_s + self.add_s_to_z_prod = module.add_s_to_z_prod + self.add_s_input_to_s = module.add_s_input_to_s + self.add_z_input_to_z = module.add_z_input_to_z + self.return_latent_feats = module.return_latent_feats + + # ---- Buffer (plain tensor, not DTensor) ---- + self.register_buffer("boundaries", module.boundaries) + + # ---- LayerNorms ---- + self.s_inputs_norm = LayerNormParamsReplicated(module.s_inputs_norm, self.device_mesh) + if not self.no_update_s: + self.s_norm = LayerNormParamsReplicated(module.s_norm, self.device_mesh) + self.z_norm = LayerNormParamsReplicated(module.z_norm, self.device_mesh) + + # ---- s → z projections ---- + self.s_to_z = LinearParamsReplicated(module.s_to_z, self.device_mesh) + self.s_to_z_transpose = LinearParamsReplicated(module.s_to_z_transpose, self.device_mesh) + + if self.add_s_to_z_prod: + self.s_to_z_prod_in1 = LinearParamsReplicated(module.s_to_z_prod_in1, self.device_mesh) + self.s_to_z_prod_in2 = LinearParamsReplicated(module.s_to_z_prod_in2, self.device_mesh) + self.s_to_z_prod_out = LinearParamsReplicated(module.s_to_z_prod_out, self.device_mesh) + + # ---- Optional s_input → s ---- + if self.add_s_input_to_s: + self.s_input_to_s = LinearParamsReplicated(module.s_input_to_s, self.device_mesh) + + # ---- Optional z-input conditioning (rel_pos, bonds, contacts) ---- + if self.add_z_input_to_z: + self.rel_pos = RelativePositionEncoder( + module.rel_pos, + device_mesh=self.device_mesh, + transpose_comm=deepcopy(transpose_comm), + ) + self.token_bonds = LinearParamsReplicated(module.token_bonds, self.device_mesh) + self.bond_type_feature = getattr(module, "bond_type_feature", False) + if self.bond_type_feature: + self.token_bonds_type = EmbeddingParamsReplicated(module.token_bonds_type, self.device_mesh) + self.contact_conditioning = ContactConditioning(module.contact_conditioning, device_mesh=self.device_mesh) + + # ---- Distogram embedding ---- + self.dist_bin_pairwise_embed = EmbeddingParamsReplicated(module.dist_bin_pairwise_embed, self.device_mesh) + + # ---- Pairformer ---- + self.pairformer_stack = PairformerModule(module.pairformer_stack, dist_manager) + + # ---- Confidence heads ---- + self.confidence_heads = ConfidenceHeads( + module.confidence_heads, + self.device_mesh, + deepcopy(transpose_comm), + ) + + def forward( + self, + s_inputs: DTensor, + s: DTensor, + z: DTensor, + x_pred: DTensor, + feats: dict, + pred_distogram_logits: DTensor, + multiplicity: int = 1, + run_sequentially: bool = False, + ) -> dict[str, DTensor]: + """Forward pass through the distributed confidence module. + + Parameters + ---------- + s_inputs : DTensor + Input single representation, shape ``(B, N, D_s)``, + placements ``(Shard(0), Shard(1), Replicate())``. + s : DTensor + Trunk single representation (detached), same shape/placements. + z : DTensor + Trunk pair representation (detached), shape ``(B, N, N, D_z)``, + placements ``(Shard(0), Shard(1), Shard(2))``. + x_pred : DTensor + Predicted atom coordinates, shape ``(B*mult, N_atoms, 3)``, + placements ``(Shard(0), Shard(1), Replicate())``. + feats : dict[str, DTensor] + Feature dictionary. + pred_distogram_logits : DTensor + Predicted distogram logits, shape ``(B, N, N, K, bins)`` or ``(B, N, N, bins)``. + multiplicity : int + Number of diffusion samples. + run_sequentially : bool + If True and multiplicity > 1, run each multiplicity sample through + the confidence module one at a time to reduce peak memory usage. + + Returns + ------- + dict[str, DTensor] + Confidence outputs including logits and aggregated metrics. + """ + if run_sequentially and multiplicity > 1: + return self._forward_sequentially(s_inputs, s, z, x_pred, feats, pred_distogram_logits, multiplicity) + + # ---- 1. Normalize inputs ---- + s_inputs = self.s_inputs_norm(s_inputs) + if not self.no_update_s: + s = self.s_norm(s) + + # ---- 2. Optional s_input addition to s ---- + if self.add_s_input_to_s: + s = elementwise_op(s, self.s_input_to_s(s_inputs), ElementwiseOp.SUM) + + # ---- 3. Normalize z ---- + z = self.z_norm(z) + + # ---- 4. Optional z-input conditioning ---- + if self.add_z_input_to_z: + z = elementwise_op(z, self.rel_pos(feats), ElementwiseOp.SUM) + safe_dtype = z.dtype if z.dtype.is_floating_point else torch.float32 + z = elementwise_op( + z, + self.token_bonds(feats["token_bonds"].to(dtype=safe_dtype)), + ElementwiseOp.SUM, + ) + if self.bond_type_feature: + z = elementwise_op( + z, + self.token_bonds_type(feats["type_bonds"].long()), + ElementwiseOp.SUM, + ) + z = elementwise_op(z, self.contact_conditioning(feats), ElementwiseOp.SUM) + + # ---- 5. Repeat s for multiplicity ---- + s = shardwise_repeat_interleave(s, multiplicity, dim=0) + + # ---- 6. Outer-sum s → z ---- + # Serial: z += s_to_z(s_inputs)[:, :, None, :] + s_to_z_T(s_inputs)[:, None, :, :] + s_to_z_pair = replicate_to_shard_outer_op( + self.s_to_z(s_inputs), + OuterOp.SUM, + axis=1, + transpose_comm=self.transpose_comm, + input_t=self.s_to_z_transpose(s_inputs), + ) + z = elementwise_op(z, s_to_z_pair, ElementwiseOp.SUM) + + # ---- 7. Optional outer-product s → z ---- + if self.add_s_to_z_prod: + z_prod = replicate_to_shard_outer_op( + self.s_to_z_prod_in1(s_inputs), + OuterOp.PROD, + axis=1, + transpose_comm=self.transpose_comm, + input_t=self.s_to_z_prod_in2(s_inputs), + ) + z = elementwise_op(z, self.s_to_z_prod_out(z_prod), ElementwiseOp.SUM) + + # ---- 8. Distogram: x_pred → representative-atom token repr → cdist → bin → embed ---- + token_to_rep_atom = feats["token_to_rep_atom"] + x_pred_repr = single_repr_rep_atom_to_token(x_pred, token_to_rep_atom) + + d = replicate_to_shard_outer_op(x_pred_repr, OuterOp.CDIST, axis=1, transpose_comm=self.transpose_comm) + distogram = shardwise_distogram(d, self.boundaries) + distogram = self.dist_bin_pairwise_embed(distogram) + + # ---- 9. Repeat z for multiplicity and add distogram ---- + z = shardwise_repeat_interleave(z, multiplicity, dim=0) + z = elementwise_op(z, distogram, ElementwiseOp.SUM) + + # ---- 10. Masks for pairformer ---- + mask = shardwise_repeat_interleave(feats["token_pad_mask"], multiplicity, dim=0) + pair_mask = shardwise_repeat_interleave(feats["token_pair_pad_mask"], multiplicity, dim=0) + mask = mask.to(dtype=s.dtype) + pair_mask = pair_mask.to(dtype=z.dtype) + + # ---- 11. Pairformer ---- + s, z = self.pairformer_stack(s, z, mask=mask, pair_mask=pair_mask) + + # ---- 12. Output dict ---- + out_dict: dict[str, DTensor] = {} + if self.return_latent_feats: + out_dict["s_conf"] = s + out_dict["z_conf"] = z + + # ---- 13. Confidence heads ---- + out_dict.update( + self.confidence_heads( + s=s, + z=z, + x_pred=x_pred, + d=d, + feats=feats, + pred_distogram_logits=pred_distogram_logits, + multiplicity=multiplicity, + ) + ) + return out_dict + + def _forward_sequentially( + self, + s_inputs: DTensor, + s: DTensor, + z: DTensor, + x_pred: DTensor, + feats: dict, + pred_distogram_logits: DTensor, + multiplicity: int, + ) -> dict[str, DTensor]: + """Run the confidence module one multiplicity sample at a time. + + This trades throughput for peak memory: instead of processing all + ``multiplicity`` samples in a single forward pass, each sample is + processed independently with ``multiplicity=1`` and the results + are re-assembled at the end. + """ + x_pred_local = x_pred.to_local() + assert ( + x_pred_local.shape[0] % multiplicity == 0 + ), f"x_pred.shape[0] must be divisible by multiplicity, got {x_pred.shape[0]} and multiplicity {multiplicity}" + B_local = x_pred_local.shape[0] // multiplicity + B_global = x_pred.shape[0] // multiplicity + + if B_local > 1: + warnings.warn( + "B_local > 1 could cause deadlocking issues with pair_chains_iptm " + "when chain counts are different on different dp groups" + ) + + x_pred_single_shape = torch.Size([B_global, *x_pred.shape[1:]]) + x_pred_unflat = x_pred_local.unflatten(0, (B_local, multiplicity)) + x_pred_single_stride = update_exhaustive_strides(x_pred.shape, x_pred.stride(), x_pred_single_shape) + + out_dicts: list[dict] = [] + for mult_idx in range(multiplicity): + x_pred_sample = DTensor.from_local( + x_pred_unflat[:, mult_idx : mult_idx + 1].flatten(0, 1), + device_mesh=x_pred.device_mesh, + placements=x_pred.placements, + shape=x_pred_single_shape, + stride=x_pred_single_stride, + ) + out_dicts.append( + self.forward( + s_inputs, + s, + z, + x_pred_sample, + feats, + pred_distogram_logits, + multiplicity=1, + run_sequentially=False, + ) + ) + + out_dict: dict[str, DTensor] = {} + B_global_mult = x_pred.shape[0] + for key in out_dicts[0]: + if key != "pair_chains_iptm": + ref = out_dicts[0][key] + stacked = torch.stack([o[key].to_local() for o in out_dicts], dim=1) + stacked_flattened = stacked.flatten(0, 1) + out_shape = torch.Size([B_global_mult, *ref.shape[1:]]) + out_dict[key] = DTensor.from_local( + stacked_flattened, + device_mesh=ref.device_mesh, + placements=ref.placements, + shape=out_shape, + stride=update_exhaustive_strides(ref.shape, ref.stride(), out_shape), + ) + else: + pair_chains_iptm: dict = {} + for idx1 in out_dicts[0][key]: + chain_iptm: dict = {} + for idx2 in out_dicts[0][key][idx1]: + ref = out_dicts[0][key][idx1][idx2] + stacked = torch.stack([o[key][idx1][idx2].to_local() for o in out_dicts], dim=1) + stacked_flattened = stacked.flatten(0, 1) + ref_shape = torch.Size([B_global_mult, *ref.shape[1:]]) + chain_iptm[idx2] = DTensor.from_local( + stacked_flattened, + device_mesh=ref.device_mesh, + placements=ref.placements, + shape=ref_shape, + stride=update_exhaustive_strides(ref.shape, ref.stride(), ref_shape), + ) + pair_chains_iptm[idx1] = chain_iptm + out_dict[key] = pair_chains_iptm + + return out_dict diff --git a/src/boltz/distributed/model/modules/diffusion.py b/src/boltz/distributed/model/modules/diffusion.py new file mode 100644 index 000000000..cd6989c21 --- /dev/null +++ b/src/boltz/distributed/model/modules/diffusion.py @@ -0,0 +1,1314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-compatible DiffusionModule and AtomDiffusion for Context Parallelism. + +Supports both Boltz-1 and Boltz-2 serial DiffusionModule, controlled by the +``internalized_conditioning`` flag (auto-detected from the serial layer type): + +- **Internalized** (Boltz-1): module owns ``pairwise_conditioner``, + ``AtomAttentionEncoder`` computes q/c/p internally. Forward takes raw + ``z_trunk`` + ``relative_position_encoding``. +- **Externalized** (Boltz-2): conditioning is pre-computed by a separate + ``DiffusionConditioning`` module. Forward receives a + ``diffusion_conditioning`` dict. + +The token-level DiffusionTransformer uses ring attention (all-to-all), while the +atom-level attention (inside AtomAttentionEncoder/Decoder) uses window-batched +attention in both modes. +""" + +import warnings +from copy import deepcopy +from math import exp, sqrt + +import torch +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard, full, zeros + +from boltz.data import const +from boltz.distributed.comm import AttentionPairBiasComm, TransposeComm +from boltz.distributed.data.feature.featurizer import pack_atom_features +from boltz.distributed.model.layers.atom_to_token import single_repr_token_to_atom +from boltz.distributed.model.layers.elementwise_op import ( + ElementwiseOp, + elementwise_op, + scalar_tensor_op, + single_tensor_op, +) +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.repeat_interleave import shardwise_repeat_interleave +from boltz.distributed.model.layers.replicate_op import ReplicateOp, replicate_op +from boltz.distributed.model.layers.sharded_op import sharded_sum +from boltz.distributed.model.layers.shardwise_op import shardwise_sum +from boltz.distributed.model.layers.squeeze import shardwise_unsqueeze +from boltz.distributed.model.layers.utils import distributed_pack_and_pad, distributed_unpad_and_unpack +from boltz.distributed.model.loss.diffusion import ( + smooth_lddt_loss, + smooth_lddt_loss_triton, + weighted_rigid_align, +) +from boltz.distributed.model.modules.encoders import ( + AtomAttentionDecoder, + AtomAttentionEncoder, + PairwiseConditioning, + SingleConditioning, +) +from boltz.distributed.model.modules.transformers import DiffusionTransformer +from boltz.distributed.model.modules.utils import center_random_augmentation +from boltz.distributed.utils import LayoutRightMap, create_distributed_randn +from boltz.model.modules.diffusion import AtomDiffusion as SerialAtomDiffusionV1 +from boltz.model.modules.diffusion import DiffusionModule as SerialDiffusionModuleV1 +from boltz.model.modules.diffusionv2 import AtomDiffusion as SerialAtomDiffusionV2 +from boltz.model.modules.diffusionv2 import DiffusionModule as SerialDiffusionModuleV2 +from boltz.model.modules.utils import default + + +class DiffusionModule(nn.Module): + """DTensor DiffusionModule for Context Parallelism. + + Supports both Boltz-1 and Boltz-2 serial DiffusionModule via the + ``internalized_conditioning`` flag (auto-detected from the serial layer): + + - **Internalized** (Boltz-1): owns ``pairwise_conditioner``, encoder + computes q/c/p internally. Forward takes ``z_trunk`` and + ``relative_position_encoding``. + - **Externalized** (Boltz-2): receives pre-computed conditioning from + ``DiffusionConditioning``. Forward takes ``diffusion_conditioning`` dict. + + In both modes, atom features (``feats``) and ``r_noisy`` are expected in + **unpacked** layout (with intersperse padding from the CP DTensor data + loader). The module calls ``pack_atom_features`` and + ``distributed_pack_and_pad`` internally to form a self-contained pack/unpack + closure. The ``diffusion_conditioning`` dict (V2 only) arrives in packed + layout as an inter-layer output from ``DiffusionConditioning`` — this + coupling is managed between the two modules and is unaffected by data + loading pipeline changes. + + The token-level DiffusionTransformer uses ring attention and the atom-level + attention uses window-batched attention. + """ + + def __init__( + self, + layer: SerialDiffusionModuleV1 | SerialDiffusionModuleV2, + device_mesh: DeviceMesh, + ring_comm: AttentionPairBiasComm | None = None, + ): + """Initialize the DTensor DiffusionModule. + + Parameters + ---------- + layer : SerialDiffusionModuleV1 | SerialDiffusionModuleV2 + The serial DiffusionModule. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + ring_comm : AttentionPairBiasComm or None, optional + Ring communication object for the token-level DiffusionTransformer. + + """ + super().__init__() + if not isinstance(layer, (SerialDiffusionModuleV1, SerialDiffusionModuleV2)): + raise TypeError(f"Expected SerialDiffusionModuleV1 or SerialDiffusionModuleV2, got {type(layer)}") + + # Internalized: module owns pairwise_conditioner, encoder computes q/c/p internally. + # Externalized: receives pre-computed conditioning from DiffusionConditioning. + self.internalized_conditioning = isinstance(layer, SerialDiffusionModuleV1) + + if isinstance(layer, SerialDiffusionModuleV2): + warnings.warn( + "CPU offloading-based activation checkpointing by default is " + "not used for Boltz-2 so we do not use it in DTensor DiffusionModule.", + UserWarning, + stacklevel=2, + ) + elif isinstance(layer, SerialDiffusionModuleV1): + warnings.warn( + "CPU offloading-based activation checkpointing can't be passed to DTensor DiffusionModule via " + "the input v1 serial layer. We will implement a custom flag in the future to enable it.", + UserWarning, + stacklevel=2, + ) + + # Sanity: serial layer must have pairwise_conditioner iff internalized + if self.internalized_conditioning and not hasattr(layer, "pairwise_conditioner"): + raise ValueError("internalized_conditioning=True but serial layer has no pairwise_conditioner") + if not self.internalized_conditioning and hasattr(layer, "pairwise_conditioner"): + raise ValueError("internalized_conditioning=False but serial layer has pairwise_conditioner") + + self.device_mesh = device_mesh + self.sigma_data = layer.sigma_data + self.atoms_per_window_queries = layer.atoms_per_window_queries + self.atoms_per_window_keys = layer.atoms_per_window_keys + self.activation_checkpointing = getattr(layer, "activation_checkpointing", False) + + # Common sub-modules (identical attribute names in V1 and V2 serial) + self.single_conditioner = SingleConditioning(layer.single_conditioner, device_mesh) + self.atom_attention_encoder = AtomAttentionEncoder(layer.atom_attention_encoder, device_mesh) + self.s_to_a_linear = nn.Sequential( + LayerNormParamsReplicated(layer.s_to_a_linear[0], device_mesh), + LinearParamsReplicated(layer.s_to_a_linear[1], device_mesh), + ) + self.token_transformer = DiffusionTransformer(layer.token_transformer, device_mesh, ring_comm=ring_comm) + self.a_norm = LayerNormParamsReplicated(layer.a_norm, device_mesh) + self.atom_attention_decoder = AtomAttentionDecoder(layer.atom_attention_decoder, device_mesh) + + # Internalized-only: owns pairwise_conditioner + if self.internalized_conditioning: + self.pairwise_conditioner = PairwiseConditioning(layer.pairwise_conditioner, device_mesh) + + def forward( + self, + s_inputs: DTensor, + s_trunk: DTensor, + r_noisy: DTensor, + times: DTensor, + feats: dict[str, DTensor], + # Externalized conditioning (when internalized_conditioning=False) + diffusion_conditioning: dict[str, DTensor] | None = None, + # Internalized conditioning inputs (when internalized_conditioning=True) + z_trunk: DTensor | None = None, + relative_position_encoding: DTensor | None = None, + # Common + multiplicity: int = 1, + model_cache: dict[str, dict[str, DTensor]] | None = None, + ) -> DTensor | dict[str, DTensor]: + """Forward pass of the DTensor DiffusionModule. + + Parameters + ---------- + s_inputs : DTensor + Input single representation, shape (B, N, token_s). + Placements: (Shard(0), Shard(1), Replicate()). + s_trunk : DTensor + Trunk single representation, shape (B, N, token_s). + Placements: (Shard(0), Shard(1), Replicate()). + r_noisy : DTensor + Noisy atom coordinates, shape (B*M, N_atoms, 3). + Placements: (Shard(0), Shard(1), Replicate()). + times : DTensor + Time embeddings, shape (B*M,). + Placements: (Shard(0), Replicate(), Replicate()). + feats : dict[str, DTensor] + Unpacked atom features (with intersperse padding from the CP DTensor + data loader). Must include ``token_pad_mask`` (token-level). + ``pack_atom_features`` is called internally for both V1 and V2. + diffusion_conditioning : dict[str, DTensor] or None + Externalized conditioning (required when internalized_conditioning=False). + Produced by ``DiffusionConditioning`` in packed layout (inter-layer + output, not from data loader): + - "q": (B, N_atoms_packed, atom_s), placements (S(0), S(1), R) + - "c": (B, N_atoms_packed, atom_s), placements (S(0), S(1), R) + - "atom_enc_bias": (B, K, W, H, total_enc_heads), placements (S(0), S(1), R) + - "atom_dec_bias": (B, K, W, H, total_dec_heads), placements (S(0), S(1), R) + - "token_trans_bias": (B, N, N, total_trans_heads), placements (S(0), S(1), S(2)) + z_trunk : DTensor or None + Trunk pair representation (required when internalized_conditioning=True). + Shape (B, N, N, token_z), placements (S(0), S(1), S(2)). + relative_position_encoding : DTensor or None + Relative position encoding (required when internalized_conditioning=True). + Shape (B, N, N, token_z), placements (S(0), S(1), S(2)). + multiplicity : int + Number of diffusion samples per batch element. + model_cache : dict or None, optional + Model cache for inference optimization (internalized path only). + + Returns + ------- + DTensor or dict[str, DTensor] + Internalized: ``{"r_update": DTensor, "token_a": DTensor}`` + Externalized: ``r_update`` DTensor directly. + + """ + # ------------------------------------------------------------------ + # Input sanity checks + # ------------------------------------------------------------------ + # Conditioning mode consistency + if self.internalized_conditioning: + if diffusion_conditioning is not None: + raise ValueError( + "internalized_conditioning: diffusion_conditioning must be None " + "(conditioning is computed internally from z_trunk + relative_position_encoding)" + ) + if z_trunk is None or relative_position_encoding is None: + raise ValueError("internalized_conditioning: z_trunk and relative_position_encoding are required") + else: + if diffusion_conditioning is None: + raise ValueError("externalized_conditioning: diffusion_conditioning dict is required") + if z_trunk is not None or relative_position_encoding is not None: + raise ValueError( + "externalized_conditioning: z_trunk and relative_position_encoding must be None " + "(conditioning is pre-computed in diffusion_conditioning dict)" + ) + + # Placement checks + expected_single = (Shard(0), Shard(1), Replicate()) + expected_pair = (Shard(0), Shard(1), Shard(2)) + expected_times = (Shard(0), Replicate(), Replicate()) + if s_inputs.placements != expected_single: + raise ValueError(f"s_inputs has incorrect placements: {s_inputs.placements} != {expected_single}") + if s_trunk.placements != expected_single: + raise ValueError(f"s_trunk has incorrect placements: {s_trunk.placements} != {expected_single}") + if r_noisy.placements != expected_single: + raise ValueError(f"r_noisy has incorrect placements: {r_noisy.placements} != {expected_single}") + if times.placements != expected_times: + raise ValueError(f"times has incorrect placements: {times.placements} != {expected_times}") + if self.internalized_conditioning: + if z_trunk.placements != expected_pair: + raise ValueError(f"z_trunk has incorrect placements: {z_trunk.placements} != {expected_pair}") + if relative_position_encoding.placements != expected_pair: + raise ValueError( + f"relative_position_encoding has incorrect placements:" + f" {relative_position_encoding.placements} != {expected_pair}" + ) + + # Shape checks: s_inputs/s_trunk batch should NOT include multiplicity + if s_inputs.shape[0] != feats["token_pad_mask"].shape[0]: + raise ValueError( + f"s_inputs batch {s_inputs.shape[0]} != feats['token_pad_mask'] batch" + f" {feats['token_pad_mask'].shape[0]} (s_inputs should not include multiplicity)" + ) + if r_noisy.shape[0] != feats["atom_pad_mask"].shape[0] * multiplicity: + raise ValueError( + f"r_noisy batch {r_noisy.shape[0]} != atom_pad_mask batch" + f" {feats['atom_pad_mask'].shape[0]} * multiplicity {multiplicity}" + ) + + # ------------------------------------------------------------------ + # 1. Single conditioning (identical in both paths) + # ------------------------------------------------------------------ + s_trunk_mult = shardwise_repeat_interleave(s_trunk, multiplicity, 0) + s_inputs_mult = shardwise_repeat_interleave(s_inputs, multiplicity, 0) + if self.activation_checkpointing and not self.internalized_conditioning and self.training: + s, normed_fourier = torch.utils.checkpoint.checkpoint( + self.single_conditioner, times, s_trunk_mult, s_inputs_mult, use_reentrant=False + ) + else: + s, normed_fourier = self.single_conditioner(times, s_trunk_mult, s_inputs_mult) + + # Promote to at least float32 for numerical stability, preserving higher precision + compute_dtype = torch.promote_types(r_noisy.dtype, torch.float32) + + # ------------------------------------------------------------------ + # Pack atom features and r_noisy (shared by V1 and V2) + # ------------------------------------------------------------------ + # Atom features (feats) are expected in unpacked layout (with intersperse padding + # from the CP DTensor data loader). Each module calls pack_atom_features internally + # to form a self-contained pack/unpack closure, so that no external caller needs to + # pre-pack features. This ensures all modules accept atom features directly as + # produced by the data loader, and future refactoring of the data loading pipeline + # will not require changes to these modules. + # + # The pack/unpack lifecycle for window batching: + # a) upstream input r_noisy and atom feats have interspersed atom padding + # due to CP data sharding requirements + # b) pack_atom_features and distributed_pack_and_pad convert these inputs + # to packed format for the window batching of AtomAttentionEncoder and + # AtomAttentionDecoder. The q, c and p returned from AtomAttentionEncoder + # are also packed. + # c) AtomAttentionDecoder takes packed q, c and p and packed atom_pad_mask + # and atom_to_token_ids_global and outputs r_update_packed. + # d) distributed_unpad_and_unpack reverts r_update_packed to r_update + # e) all packed features are discarded after the forward pass. + # TODO: pack_atom_features and distributed_pack_and_pad should be moved + # to the data featurizing layer to save the extra compute and space + # due to interspersed padding. + W = self.atoms_per_window_queries + _keys_atom_features_packed = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + } + feats_packed = pack_atom_features(feats, _keys_atom_features_packed, W) + + atom_mask_mul = shardwise_repeat_interleave(feats["atom_pad_mask"].bool(), multiplicity, 0) + atom_mask_mul_expanded = shardwise_unsqueeze(atom_mask_mul, dim=-1) + r_noisy_packed, atom_mask_r_noisy_packed = distributed_pack_and_pad(r_noisy, atom_mask_mul_expanded, W, axis=1) + + if self.internalized_conditioning: + # ---- Internalized path (Boltz-1) ---- + # Compute pairwise conditioning z (skipped if cached) + if model_cache is None or len(model_cache) == 0: + z = self.pairwise_conditioner(z_trunk, relative_position_encoding) + else: + z = None + + # Atom attention encoder: computes q/c/p internally from s_trunk + z + a, q_skip, c_skip, p_skip = self.atom_attention_encoder( + feats=feats_packed, + s_trunk=s_trunk, + z=z, + r=r_noisy_packed, + multiplicity=multiplicity, + model_cache=model_cache, + ) + + # Token processing (token_pad_mask is token-level, not atom-level — use raw feats) + a = elementwise_op(a, self.s_to_a_linear(s), ElementwiseOp.SUM) + mask = feats["token_pad_mask"] + a = self.token_transformer( + a, mask=mask.to(a.dtype), s=s, z=z, multiplicity=multiplicity, model_cache=model_cache + ) + a = self.a_norm(a) + + # Atom attention decoder with internally-computed p_skip + r_update = self.atom_attention_decoder( + a=a, + q=q_skip, + c=c_skip, + p=p_skip, + feats=feats_packed, + multiplicity=multiplicity, + model_cache=model_cache, + ) + + # Unpack r_update + r_update = distributed_unpad_and_unpack( + r_update, atom_mask_r_noisy_packed, atom_mask_mul_expanded, axis=1, keep_input_padding=False + ) + return {"r_update": r_update, "token_a": a.detach()} + + else: + # ---- Externalized path (Boltz-2) ---- + # The diffusion_conditioning dict (q, c, atom_enc_bias, atom_dec_bias, + # token_trans_bias) is produced by DiffusionConditioning and arrives in + # packed layout. Unlike atom features from the data loader, these are + # inter-layer outputs whose format is managed between the producing and + # consuming modules. This coupling is intentional: changes to the data + # loading pipeline do not affect the conditioning interface between + # DiffusionConditioning and DiffusionModule. + a, q_skip, c_skip, p_skip = self.atom_attention_encoder( + feats=feats_packed, + q=diffusion_conditioning["q"].to(compute_dtype), + c=diffusion_conditioning["c"].to(compute_dtype), + atom_enc_bias=diffusion_conditioning["atom_enc_bias"].to(compute_dtype), + r=r_noisy_packed, + multiplicity=multiplicity, + ) + + # Token processing with pre-computed token_trans_bias + a = elementwise_op(a, self.s_to_a_linear(s), ElementwiseOp.SUM) + mask = feats["token_pad_mask"] + a = self.token_transformer( + a, + mask=mask.to(compute_dtype), + s=s, + z=diffusion_conditioning["token_trans_bias"].to(compute_dtype), + multiplicity=multiplicity, + ) + a = self.a_norm(a) + + # Atom attention decoder with pre-computed atom_dec_bias + r_update = self.atom_attention_decoder( + a=a, + q=q_skip, + c=c_skip, + p=diffusion_conditioning["atom_dec_bias"].to(compute_dtype), + feats=feats_packed, + multiplicity=multiplicity, + ) + + # Unpack r_update + r_update = distributed_unpad_and_unpack( + r_update, atom_mask_r_noisy_packed, atom_mask_mul_expanded, axis=1, keep_input_padding=False + ) + return r_update + + +class AtomDiffusion(nn.Module): + """DTensor AtomDiffusion for Context Parallelism. + + Wraps DiffusionModule with diffusion scheduling (noise preconditioning, + training forward, and sampling). Scalar diffusion math (c_skip, c_out, c_in, + c_noise, loss_weight) is identical between V1 and V2. + + Supports both V1 (internalized conditioning) and V2 (externalized conditioning) + via the ``internalized_conditioning`` flag inherited from the wrapped DiffusionModule. + + Atom features (``feats``) and atom-level tensors (``r_noisy``, ``coords``, + ``noise``) are expected in **unpacked** layout (with intersperse padding from + the CP DTensor data loader). The wrapped ``DiffusionModule`` calls + ``pack_atom_features`` internally — no external packing is needed. + The ``diffusion_conditioning`` dict (V2 only) arrives in packed layout as an + inter-layer output from ``DiffusionConditioning``; this format is managed + between the producing and consuming modules and is unaffected by data loading + pipeline changes. + """ + + def __init__( + self, + layer: SerialAtomDiffusionV1 | SerialAtomDiffusionV2, + device_mesh: DeviceMesh, + ring_comm: AttentionPairBiasComm | None = None, + transpose_comm: TransposeComm | None = None, + ): + """Initialize the DTensor AtomDiffusion. + + Parameters + ---------- + layer : SerialAtomDiffusionV1 | SerialAtomDiffusionV2 + The serial AtomDiffusion module. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + ring_comm : AttentionPairBiasComm or None, optional + Ring communication for the token-level DiffusionTransformer. + transpose_comm : TransposeComm or None, optional + Transpose communication for smooth LDDT loss. Required when + add_smooth_lddt_loss is True in compute_loss. + + """ + super().__init__() + if not isinstance(layer, (SerialAtomDiffusionV1, SerialAtomDiffusionV2)): + raise TypeError(f"Expected SerialAtomDiffusionV1 or SerialAtomDiffusionV2, got {type(layer)}") + + self.device_mesh = device_mesh + self.transpose_comm = transpose_comm + + # Copy diffusion parameters (identical in V1 and V2) + self.sigma_min = layer.sigma_min + self.sigma_max = layer.sigma_max + self.sigma_data = layer.sigma_data + self.rho = layer.rho + self.P_mean = layer.P_mean + self.P_std = layer.P_std + self.num_sampling_steps = layer.num_sampling_steps + self.gamma_0 = layer.gamma_0 + self.gamma_min = layer.gamma_min + self.noise_scale = layer.noise_scale + self.step_scale = layer.step_scale + self.coordinate_augmentation = layer.coordinate_augmentation + self.alignment_reverse_diff = layer.alignment_reverse_diff + self.synchronize_sigmas = layer.synchronize_sigmas + self.token_s = layer.token_s + + # Convert the score model to DTensor version + self.score_model = DiffusionModule( + layer.score_model, + device_mesh, + ring_comm=ring_comm, + ) + + # Derive conditioning mode from the wrapped DiffusionModule + self.internalized_conditioning = self.score_model.internalized_conditioning + + # V1-only attributes + self.use_inference_model_cache = getattr(layer, "use_inference_model_cache", False) + self.accumulate_token_repr = getattr(layer, "accumulate_token_repr", False) + if self.accumulate_token_repr: + # v2 doesn't have accumulate_token_repr + if isinstance(layer, SerialAtomDiffusionV2): + raise ValueError("accumulate_token_repr should not exist in AtomDiffusionV2") + # TODO: wrap out_token_feat_update with DTensor layers if accumulate_token_repr is needed + warnings.warn("OutTokenFeatUpdate is not implemented in DTensor mode. Skipping.") + + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + self.transpose_comm = deepcopy(transpose_comm) # for self.compute_loss + self.v2 = isinstance(layer, SerialAtomDiffusionV2) + + @property + def device(self): + """Get the device type of the model.""" + return self.device_mesh.device_type + + # ------------------------------------------------------------------ + # Diffusion preconditioning (DTensor scalar ops) + # ------------------------------------------------------------------ + + def _check_sigma_placement(self, sigma: DTensor) -> None: + """Validate that sigma has the expected placements (Shard(0), Replicate(), Replicate()). + + Sigma is a 1-D noise-level tensor sharded across the DP axis and + replicated across CP axes. All preconditioning helpers (c_skip, c_out, + c_in, c_noise, loss_weight) call this before operating on sigma. + """ + expected = (Shard(0), Replicate(), Replicate()) + if sigma.placements != expected: + raise ValueError(f"Sigma tensor has incorrect placements: {sigma.placements} != {expected}") + + def c_skip(self, sigma: DTensor) -> DTensor: + """Skip-connection scaling: sigma_data^2 / (sigma^2 + sigma_data^2). + + Weights the direct pass-through of noised coordinates in the + preconditioning formula: denoised = c_skip * noised + c_out * net_out. + + Parameters + ---------- + sigma : DTensor + Noise levels, shape (B*M,). Placements: (Shard(0), Replicate(), Replicate()). + + Returns + ------- + DTensor + Same shape and placements as sigma. + + """ + self._check_sigma_placement(sigma) + sigma_sq = scalar_tensor_op(2, sigma, ElementwiseOp.POW) + denom = scalar_tensor_op(self.sigma_data**2, sigma_sq, ElementwiseOp.SUM) + return scalar_tensor_op(self.sigma_data**2, denom, ElementwiseOp.DIV) + + def c_out(self, sigma: DTensor) -> DTensor: + """Output scaling: sigma * sigma_data / sqrt(sigma^2 + sigma_data^2). + + Weights the network output in the preconditioning formula: + denoised = c_skip * noised + c_out * net_out. + + Parameters + ---------- + sigma : DTensor + Noise levels, shape (B*M,). Placements: (Shard(0), Replicate(), Replicate()). + + Returns + ------- + DTensor + Same shape and placements as sigma. + + """ + self._check_sigma_placement(sigma) + numer = scalar_tensor_op(self.sigma_data, sigma, ElementwiseOp.PROD) + sigma_sq = scalar_tensor_op(2, sigma, ElementwiseOp.POW) + denom = scalar_tensor_op(self.sigma_data**2, sigma_sq, ElementwiseOp.SUM) + denom = scalar_tensor_op(0.5, denom, ElementwiseOp.POW) + return elementwise_op(numer, denom, ElementwiseOp.DIV) + + def c_in(self, sigma: DTensor) -> DTensor: + """Input scaling: 1 / sqrt(sigma^2 + sigma_data^2). + + Normalizes noised coordinates before feeding into the score model + in preconditioned_network_forward. + + Parameters + ---------- + sigma : DTensor + Noise levels, shape (B*M,). Placements: (Shard(0), Replicate(), Replicate()). + + Returns + ------- + DTensor + Same shape and placements as sigma. + + """ + self._check_sigma_placement(sigma) + sigma_sq = scalar_tensor_op(2, sigma, ElementwiseOp.POW) + denom = scalar_tensor_op(self.sigma_data**2, sigma_sq, ElementwiseOp.SUM) + denom = scalar_tensor_op(0.5, denom, ElementwiseOp.POW) + return scalar_tensor_op(1, denom, ElementwiseOp.DIV) + + def c_noise(self, sigma: DTensor) -> DTensor: + """Noise conditioning: log(sigma / sigma_data) * 0.25. + + Produces the time embedding input for the score model's + SingleConditioning / FourierEmbedding layers. + + Parameters + ---------- + sigma : DTensor + Noise levels, shape (B*M,). Placements: (Shard(0), Replicate(), Replicate()). + + Returns + ------- + DTensor + Same shape and placements as sigma. + + """ + self._check_sigma_placement(sigma) + scaled = scalar_tensor_op(1 / self.sigma_data, sigma, ElementwiseOp.PROD) + scaled_local = scaled.to_local().clamp(min=1e-20) + scaled = DTensor.from_local( + scaled_local, + device_mesh=scaled.device_mesh, + placements=scaled.placements, + shape=scaled.shape, + stride=scaled.stride(), + ) + log_sigma = single_tensor_op(scaled, ElementwiseOp.LOG) + return scalar_tensor_op(0.25, log_sigma, ElementwiseOp.PROD) + + def loss_weight(self, sigma: DTensor) -> DTensor: + """Diffusion loss weighting: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2. + + Used by compute_loss to weight the MSE loss at each noise level. + + Parameters + ---------- + sigma : DTensor + Noise levels, shape (B*M,). Placements: (Shard(0), Replicate(), Replicate()). + + Returns + ------- + DTensor + Same shape and placements as sigma. + + """ + self._check_sigma_placement(sigma) + sigma_sq = scalar_tensor_op(2, sigma, ElementwiseOp.POW) + numer = scalar_tensor_op(self.sigma_data**2, sigma_sq, ElementwiseOp.SUM) + denom = scalar_tensor_op(self.sigma_data**2, sigma_sq, ElementwiseOp.PROD) + return elementwise_op(numer, denom, ElementwiseOp.DIV) + + def noise_distribution(self, batch_size: int, dtype: torch.dtype = torch.float32) -> DTensor: + """Sample noise levels from the training distribution. + + Generates sigma_data * exp(P_mean + P_std * randn(batch_size)). + Called by forward() to produce per-sample noise levels for the + diffusion training step. + + Parameters + ---------- + batch_size : int + Number of samples (typically B*M after multiplicity expansion). + dtype : torch.dtype, optional + Dtype for the generated noise levels. Should match the model's + compute dtype (e.g. feats["coords"].dtype). Default torch.float32. + + Returns + ------- + DTensor + Noise levels, shape (batch_size,). + Placements: (Shard(0), Replicate(), Replicate()). + + """ + noise = create_distributed_randn( + (batch_size,), + device_mesh=self.device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + dtype=dtype, + ) + noise = scalar_tensor_op(self.P_std, noise, ElementwiseOp.PROD) + noise = single_tensor_op(noise, ElementwiseOp.EXP) + noise = scalar_tensor_op(self.sigma_data * exp(self.P_mean), noise, ElementwiseOp.PROD) + return noise + + # ------------------------------------------------------------------ + # Preconditioned network forward + # ------------------------------------------------------------------ + + def preconditioned_network_forward( + self, + noised_atom_coords: DTensor, + sigma: float | DTensor, + network_condition_kwargs: dict, + ) -> DTensor | tuple[DTensor, DTensor]: + """Preconditioned forward pass: c_skip * x + c_out * score_model(c_in * x, c_noise). + + Parameters + ---------- + noised_atom_coords : DTensor + Noisy atom coordinates, shape (B*M, N_atoms, 3). + Placements: (Shard(0), Shard(1), Replicate()). + sigma : float or DTensor + Noise level. If float, broadcast to all batch elements. + If DTensor, shape (B*M,) with placements (Shard(0), Replicate(), Replicate()). + network_condition_kwargs : dict + Conditioning arguments for the score model. + + Returns + ------- + DTensor or tuple[DTensor, DTensor] + Internalized: ``(denoised_coords, token_a)`` + Externalized: ``denoised_coords`` + + """ + batch_size = noised_atom_coords.shape[0] + + if isinstance(sigma, float): + sigma = full( + (batch_size,), + sigma, + dtype=noised_atom_coords.dtype, + device_mesh=self.device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + ) + + # Expand sigma to (B, 3) for element-wise multiply with (B, N, 3) + padded_sigma = shardwise_repeat_interleave(shardwise_unsqueeze(sigma, dim=-1), 3, -1) + + # r_noisy = c_in(sigma) * noised_atom_coords + r_noisy = replicate_op(noised_atom_coords, self.c_in(padded_sigma), 1, ReplicateOp.PROD) + times = self.c_noise(sigma) + + net_out = self.score_model( + r_noisy=r_noisy, + times=times, + **network_condition_kwargs, + ) + + if self.internalized_conditioning: + # V1: score_model returns dict {"r_update": ..., "token_a": ...} + r_update = net_out["r_update"] + token_a = net_out["token_a"] + else: + # V2: score_model returns r_update DTensor directly + r_update = net_out + + # denoised = c_skip(sigma) * noised_atom_coords + c_out(sigma) * r_update + skip_term = replicate_op(noised_atom_coords, self.c_skip(padded_sigma), 1, ReplicateOp.PROD) + out_term = replicate_op(r_update, self.c_out(padded_sigma), 1, ReplicateOp.PROD) + denoised_coords = elementwise_op(skip_term, out_term, ElementwiseOp.SUM) + + if self.internalized_conditioning: + return denoised_coords, token_a + return denoised_coords + + # ------------------------------------------------------------------ + # Sampling schedule + # ------------------------------------------------------------------ + + def sample_schedule(self, num_sampling_steps: int | None = None) -> torch.Tensor: + """Generate sigma schedule for sampling. Returns plain Tensor (scalar schedule).""" + num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps) + if num_sampling_steps < 2: + raise ValueError(f"Need at least 2 sampling steps, got {num_sampling_steps}") + inv_rho = 1 / self.rho + steps = torch.arange(num_sampling_steps, device=self.device, dtype=torch.float32) + sigmas = ( + self.sigma_max**inv_rho + + steps / (num_sampling_steps - 1) * (self.sigma_min**inv_rho - self.sigma_max**inv_rho) + ) ** self.rho + sigmas = sigmas * self.sigma_data + sigmas = torch.nn.functional.pad(sigmas, (0, 1), value=0.0) + return sigmas + + # ------------------------------------------------------------------ + # Training forward + # ------------------------------------------------------------------ + + def forward( + self, + s_inputs: DTensor, + s_trunk: DTensor, + feats: dict[str, DTensor], + # Externalized conditioning (when internalized_conditioning=False) + diffusion_conditioning: dict[str, DTensor] | None = None, + # Internalized conditioning inputs (when internalized_conditioning=True) + z_trunk: DTensor | None = None, + relative_position_encoding: DTensor | None = None, + # Common + multiplicity: int = 1, + ) -> dict[str, DTensor]: + """Training forward: add noise, run preconditioned network, return denoised coords. + + Parameters + ---------- + s_inputs : DTensor + Input single representation, shape (B, N, token_s). + s_trunk : DTensor + Trunk single representation, shape (B, N, token_s). + feats : dict[str, DTensor] + Pre-packed atom features. Must contain 'coords' and 'atom_pad_mask'. + diffusion_conditioning : dict[str, DTensor] or None + Externalized conditioning (required when internalized_conditioning=False). + z_trunk : DTensor or None + Trunk pair representation (required when internalized_conditioning=True). + relative_position_encoding : DTensor or None + Relative position encoding (required when internalized_conditioning=True). + multiplicity : int + Number of diffusion samples per batch element. + + Returns + ------- + dict[str, DTensor] + denoised_atom_coords, sigmas, aligned_true_atom_coords. + + """ + # Sanity: ensure forward args and feats shapes match conditioning mode + coords = feats["coords"] + atom_pad_mask = feats["atom_pad_mask"] + B, N_atoms = atom_pad_mask.shape[0], atom_pad_mask.shape[1] + expected_coords_shape = (B * multiplicity, N_atoms, 3) + mode = "V1" if self.internalized_conditioning else "V2" + if self.internalized_conditioning: + if diffusion_conditioning is not None: + raise ValueError("internalized_conditioning: diffusion_conditioning must be None") + if z_trunk is None or relative_position_encoding is None: + raise ValueError("internalized_conditioning: z_trunk and relative_position_encoding are required") + else: + if diffusion_conditioning is None: + raise ValueError("externalized_conditioning: diffusion_conditioning dict is required") + if z_trunk is not None or relative_position_encoding is not None: + raise ValueError("externalized_conditioning: z_trunk and relative_position_encoding must be None") + if coords.shape != expected_coords_shape: + raise ValueError( + f"{mode}: feats['coords'] expected shape {expected_coords_shape} " + f"(B*M, N_atoms, 3) from atom_pad_mask (B={B}, N={N_atoms}) and multiplicity={multiplicity}, " + f"got {coords.shape}. Caller must expand coords with multiplicity before calling forward()." + ) + batch_size = B + + coords_dtype = feats["coords"].dtype + if self.synchronize_sigmas: + sigmas = self.noise_distribution(batch_size, dtype=coords_dtype) + sigmas = shardwise_repeat_interleave(sigmas, multiplicity, 0) + else: + sigmas = self.noise_distribution(batch_size * multiplicity, dtype=coords_dtype) + + padded_sigmas = shardwise_repeat_interleave(shardwise_unsqueeze(sigmas, dim=-1), 3, -1) + + # Process atom coordinates + atom_coords = feats["coords"] + atom_mask = feats["atom_pad_mask"] + atom_mask = shardwise_repeat_interleave(atom_mask, multiplicity, 0) + + atom_coords = center_random_augmentation(atom_coords, atom_mask, augmentation=self.coordinate_augmentation) + + noise = create_distributed_randn( + atom_coords.shape, + device_mesh=self.device_mesh, + placements=atom_coords.placements, + dtype=atom_coords.dtype, + ) + + # Add noise: noised = coords + sigma * noise + noised_atom_coords = elementwise_op( + atom_coords, + replicate_op(noise, padded_sigmas, 1, ReplicateOp.PROD), + ElementwiseOp.SUM, + ) + + # Build network_condition_kwargs based on conditioning mode + if self.internalized_conditioning: + network_condition_kwargs = { + "s_inputs": s_inputs, + "s_trunk": s_trunk, + "feats": feats, + "multiplicity": multiplicity, + "z_trunk": z_trunk, + "relative_position_encoding": relative_position_encoding, + } + else: + network_condition_kwargs = { + "s_inputs": s_inputs, + "s_trunk": s_trunk, + "feats": feats, + "multiplicity": multiplicity, + "diffusion_conditioning": diffusion_conditioning, + } + + # Preconditioned network forward + precond_result = self.preconditioned_network_forward( + noised_atom_coords, + sigmas, + network_condition_kwargs=network_condition_kwargs, + ) + + # V1 returns (denoised, token_a), V2 returns denoised directly + if self.internalized_conditioning: + denoised_atom_coords = precond_result[0] + else: + denoised_atom_coords = precond_result + + result = { + "denoised_atom_coords": denoised_atom_coords, + "sigmas": sigmas, + "aligned_true_atom_coords": atom_coords, + } + if self.internalized_conditioning: + result["noised_atom_coords"] = noised_atom_coords + return result + + # ------------------------------------------------------------------ + # Sampling (inference) + # ------------------------------------------------------------------ + + def sample( + self, + atom_mask: DTensor, + num_sampling_steps: int | None = None, + multiplicity: int = 1, + max_parallel_samples: int | None = None, + train_accumulate_token_repr: bool = False, + **network_condition_kwargs, + ) -> dict[str, DTensor | None]: + """Sample from the diffusion model (inference denoising loop). + + Parameters + ---------- + atom_mask : DTensor + Atom mask, shape (B, N_atoms). + Placements: (Shard(0), Shard(1), Replicate()). + num_sampling_steps : int or None, optional + Number of sampling steps. If None, uses default. + multiplicity : int, optional + Multiplicity factor, by default 1. + max_parallel_samples : int or None, optional + Maximum multiplicity samples processed per chunk. If None, + all samples are processed in a single call. + train_accumulate_token_repr : bool, optional + V1-specific flag for accumulating token representations. + Not yet implemented in DTensor mode; raises NotImplementedError if True. + **network_condition_kwargs + Additional conditioning. For externalized: s_inputs, s_trunk, feats, + diffusion_conditioning. For internalized: s_inputs, s_trunk, feats, + z_trunk, relative_position_encoding. + + Returns + ------- + dict[str, DTensor | None] + sample_atom_coords: denoised coordinates, diff_token_repr: always None + (accumulate_token_repr not yet implemented in DTensor mode). + + """ + if train_accumulate_token_repr: + raise NotImplementedError("train_accumulate_token_repr not implemented in DTensor mode yet") + + if max_parallel_samples is None: + max_parallel_samples = multiplicity + + num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps) + atom_mask = shardwise_repeat_interleave(atom_mask, multiplicity, 0) + shape = (*atom_mask.shape, 3) + + # Sampling schedule (plain Tensor, deterministic) + sigmas = self.sample_schedule(num_sampling_steps) + gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0) + sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:])) + + # Initial noise + init_sigma = sigmas[0].item() + atom_coords = create_distributed_randn( + shape, + device_mesh=self.device_mesh, + placements=atom_mask.placements, + scale=init_sigma, + ) + atom_coords_denoised = None + + # V1: model_cache for inference optimization + model_cache = {} if self.use_inference_model_cache else None + + # V1: token representation tracking for accumulate_token_repr + token_repr = None + token_a = None + + # Denoising loop + for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas): + aug = self.coordinate_augmentation + result = center_random_augmentation( + atom_coords, + atom_mask, + augmentation=aug, + return_second_coords=True, + second_coords=atom_coords_denoised, + return_roto=aug, + ) + if aug: + atom_coords, atom_coords_denoised, _ = result + else: + atom_coords, atom_coords_denoised = result + + sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item() + t_hat = sigma_tm * (1 + gamma) + noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2) + eps = create_distributed_randn( + shape, + device_mesh=self.device_mesh, + placements=atom_mask.placements, + scale=sqrt(noise_var), + ) + atom_coords_noisy = elementwise_op(atom_coords, eps, ElementwiseOp.SUM) + + with torch.no_grad(): + placements = atom_coords_noisy.placements + noisy_local = atom_coords_noisy.to_local() + denoised_local = torch.zeros_like(noisy_local) + if noisy_local.shape[0] % multiplicity != 0: + # this should only happen if all upstream DTensor modules have removed + # the non-even sharding requirements and that we actual have a unevenly + # sharded batch dimension + raise ValueError( + f"noisy_local.shape[0] is not divisible by multiplicity: {noisy_local.shape[0]} % {multiplicity} = {noisy_local.shape[0] % multiplicity}" + ) + B_local = noisy_local.shape[0] // multiplicity + + sample_ids = torch.arange(multiplicity, device=self.device) + n_chunks = (multiplicity + max_parallel_samples - 1) // max_parallel_samples + sample_ids_chunks = sample_ids.chunk(n_chunks) + + for sample_ids_chunk in sample_ids_chunks: + # noisy_local is (B_local*M, N, 3). + # Unflatten to (B_local, M, N, 3) and index the M axis so each + # chunk selects the correct multiplicity slices, then reflatten + # to (B_local*chunk_M, N, 3) and rebuild as DTensor. + chunk_M = sample_ids_chunk.numel() + noisy_chunk_local = noisy_local.unflatten(0, (B_local, multiplicity))[:, sample_ids_chunk].flatten( + 0, 1 + ) + chunk_global_shape = ( + atom_coords_noisy.shape[0] * chunk_M // multiplicity, + atom_coords_noisy.shape[1], + 3, + ) + noisy_chunk_dt = DTensor.from_local( + noisy_chunk_local, + device_mesh=self.device_mesh, + placements=placements, + shape=chunk_global_shape, + stride=LayoutRightMap(chunk_global_shape).strides, + ) + + precond_kwargs = dict(multiplicity=chunk_M, **network_condition_kwargs) + if model_cache is not None: + precond_kwargs["model_cache"] = model_cache + + precond_result = self.preconditioned_network_forward( + noisy_chunk_dt, + t_hat, + network_condition_kwargs=precond_kwargs, + ) + + if self.internalized_conditioning: + denoised_chunk_dt, token_a = precond_result + else: + denoised_chunk_dt = precond_result + + denoised_local.unflatten(0, (B_local, multiplicity))[:, sample_ids_chunk] = ( + denoised_chunk_dt.to_local().unflatten(0, (B_local, chunk_M)) + ) + + atom_coords_denoised = DTensor.from_local( + denoised_local, + device_mesh=self.device_mesh, + placements=placements, + shape=atom_coords_noisy.shape, + stride=atom_coords_noisy.stride(), + ) + + # TODO: accumulate_token_repr support (requires DTensor wrapping of OutTokenFeatUpdate) + + # Alignment reverse diffusion: align noisy coords to denoised coords + if self.alignment_reverse_diff: + atom_coords_noisy = weighted_rigid_align( + atom_coords_noisy, + atom_coords_denoised, + atom_mask, + atom_mask, + ) + + # Next step: x_{t+1} = x_noisy + step_scale * (sigma_t - t_hat) * (x_noisy - x_denoised) / t_hat + denoised_over_sigma = scalar_tensor_op( + 1 / t_hat, + elementwise_op(atom_coords_noisy, atom_coords_denoised, ElementwiseOp.SUB), + ElementwiseOp.PROD, + ) + atom_coords = elementwise_op( + atom_coords_noisy, + scalar_tensor_op(self.step_scale * (sigma_t - t_hat), denoised_over_sigma, ElementwiseOp.PROD), + ElementwiseOp.SUM, + ) + + return {"sample_atom_coords": atom_coords, "diff_token_repr": token_repr} + + # ------------------------------------------------------------------ + # Compute loss + # ------------------------------------------------------------------ + + def compute_loss( + self, + feats: dict[str, DTensor], + out_dict: dict[str, DTensor], + add_smooth_lddt_loss: bool = True, + nucleotide_loss_weight: float = 5.0, + ligand_loss_weight: float = 10.0, + multiplicity: int = 1, + filter_by_plddt: float = 0.0, + use_triton_kernel: bool = True, + ) -> dict[str, DTensor]: + """Compute loss for the diffusion model. + + Parameters + ---------- + feats : dict[str, DTensor] + Features for the diffusion model. + out_dict : dict[str, DTensor] + Output dictionary from the diffusion model. + add_smooth_lddt_loss : bool, optional + Whether to add smooth LDDT loss. + nucleotide_loss_weight : float, optional + Weight for nucleotide loss. + ligand_loss_weight : float, optional + Weight for ligand loss. + multiplicity : int, optional + Multiplicity factor, by default 1. + filter_by_plddt : float, optional + Filter by pLDDT threshold, by default 0.0. + use_triton_kernel : bool, optional + Whether to use Triton kernel for smooth LDDT loss. + + Returns + ------- + dict[str, DTensor] + Loss dictionary containing "loss" and "loss_breakdown" keys. + "loss" is the total loss tensor. + "loss_breakdown" is a dictionary containing "mse_loss" and "smooth_lddt_loss" keys. + "mse_loss" is the MSE loss tensor. + "smooth_lddt_loss" is the smooth LDDT loss tensor. + """ + if not self.v2 and filter_by_plddt != 0.0: + raise ValueError("filter_by_plddt is only supported for V2") + + with torch.autocast("cuda", enabled=False): + safe_dtype = torch.promote_types(torch.float32, out_dict["denoised_atom_coords"].dtype) + denoised_atom_coords = out_dict["denoised_atom_coords"].to(dtype=safe_dtype) + sigmas = out_dict["sigmas"].to(dtype=safe_dtype) + + resolved_atom_mask_uni = feats["atom_resolved_mask"].to(dtype=safe_dtype) + resolved_atom_mask = shardwise_repeat_interleave(resolved_atom_mask_uni, multiplicity, 0) + + if self.v2 and filter_by_plddt > 0: + if "plddt" not in feats: + raise RuntimeError("Missing required plddt data in feats for plddt filtering") + plddt_mask = scalar_tensor_op(filter_by_plddt, feats["plddt"], ElementwiseOp.LT) + resolved_atom_mask_uni_plddt_masked = elementwise_op( + resolved_atom_mask_uni, plddt_mask.to(dtype=safe_dtype), ElementwiseOp.PROD + ) + resolved_atom_mask_plddt_masked = shardwise_repeat_interleave( + resolved_atom_mask_uni_plddt_masked, multiplicity, 0 + ) + else: + resolved_atom_mask_uni_plddt_masked = resolved_atom_mask_uni + resolved_atom_mask_plddt_masked = resolved_atom_mask + + atom_type = single_repr_token_to_atom( + feats["mol_type"].to(dtype=safe_dtype), feats["atom_to_token"].to(dtype=safe_dtype) + ) + atom_type_mult = shardwise_repeat_interleave(atom_type, multiplicity, 0) + + is_nucleotide_mult = elementwise_op( + scalar_tensor_op( + const.chain_type_ids["DNA"], + atom_type_mult, + ElementwiseOp.EQUAL, + ), + scalar_tensor_op( + const.chain_type_ids["RNA"], + atom_type_mult, + ElementwiseOp.EQUAL, + ), + ElementwiseOp.SUM, # or equivalently OR + ) + + nucleotide_loss_weights = scalar_tensor_op( + nucleotide_loss_weight, + is_nucleotide_mult, + ElementwiseOp.PROD, + ) + + ligand_loss_weights = scalar_tensor_op( + ligand_loss_weight, + scalar_tensor_op( + const.chain_type_ids["NONPOLYMER"], + atom_type_mult, + ElementwiseOp.EQUAL, + ), + ElementwiseOp.PROD, + ) + align_weights = scalar_tensor_op( + 1.0, + elementwise_op( + nucleotide_loss_weights, + ligand_loss_weights, + ElementwiseOp.SUM, + ), + ElementwiseOp.SUM, + ) + + with torch.no_grad(): + atom_coords = out_dict["aligned_true_atom_coords"].to(dtype=safe_dtype) + atom_coords_aligned_ground_truth = weighted_rigid_align( + atom_coords, + denoised_atom_coords, + align_weights.to(dtype=safe_dtype), + mask=resolved_atom_mask.to(dtype=safe_dtype), + ) + # Cast back + atom_coords_aligned_ground_truth: DTensor = atom_coords_aligned_ground_truth.to( + dtype=denoised_atom_coords.dtype + ) + + # Weighted MSE loss of denoised atom positions (match serial v2 formula) + mse_loss = elementwise_op(denoised_atom_coords, atom_coords_aligned_ground_truth, ElementwiseOp.SUB) + mse_loss = scalar_tensor_op(2.0, mse_loss, ElementwiseOp.POW) + mse_loss = shardwise_sum(mse_loss, dim=-1) + mse_loss = elementwise_op(mse_loss, resolved_atom_mask_plddt_masked, ElementwiseOp.PROD) + + resolved_align_weights = elementwise_op(align_weights, resolved_atom_mask_plddt_masked, ElementwiseOp.PROD) + denom = sharded_sum( + scalar_tensor_op(3.0, resolved_align_weights, ElementwiseOp.PROD), + dim=-1, + ) + if self.v2: + denom = scalar_tensor_op(1e-5, denom, ElementwiseOp.SUM) + + mse_loss = elementwise_op(mse_loss, resolved_align_weights, ElementwiseOp.PROD) + mse_loss = sharded_sum(mse_loss, dim=-1) + mse_loss = elementwise_op(mse_loss, denom, ElementwiseOp.DIV) + loss_weights = self.loss_weight(sigmas) + + mse_loss = elementwise_op(mse_loss, loss_weights, ElementwiseOp.PROD) + mse_loss = scalar_tensor_op( + 1.0 / mse_loss.shape[0], + sharded_sum(mse_loss, dim=0), + ElementwiseOp.PROD, + ) + total_loss = mse_loss + + if add_smooth_lddt_loss: + is_nucleotide = elementwise_op( + scalar_tensor_op(const.chain_type_ids["DNA"], atom_type, ElementwiseOp.EQUAL), + scalar_tensor_op(const.chain_type_ids["RNA"], atom_type, ElementwiseOp.EQUAL), + ElementwiseOp.SUM, + ) + loss_func = ( + smooth_lddt_loss_triton + if use_triton_kernel and self.device_mesh.device_type == "cuda" + else smooth_lddt_loss + ) + + lddt_loss = loss_func( + denoised_atom_coords, + atom_coords, + is_nucleotide=is_nucleotide, + coords_mask=resolved_atom_mask_uni_plddt_masked, + comm=self.transpose_comm, + multiplicity=multiplicity, + v2=self.v2, + ) + total_loss = elementwise_op(total_loss, lddt_loss, ElementwiseOp.SUM) + else: + lddt_loss = zeros( + total_loss.shape, + requires_grad=False, + device_mesh=total_loss.device_mesh, + placements=total_loss.placements, + ) + + loss_breakdown = { + "mse_loss": mse_loss, + "smooth_lddt_loss": lddt_loss, + } + + return {"loss": total_loss, "loss_breakdown": loss_breakdown} diff --git a/src/boltz/distributed/model/modules/diffusion_conditioning.py b/src/boltz/distributed/model/modules/diffusion_conditioning.py new file mode 100644 index 000000000..20e1b5bbd --- /dev/null +++ b/src/boltz/distributed/model/modules/diffusion_conditioning.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-compatible DiffusionConditioning module for Context Parallelism. + +V2-only module that pre-computes conditioning (pair features, atom encoder, +bias projections) outside the diffusion loop. This is a Boltz-2-specific +refactoring that moves conditioning from inside the diffusion step to a +one-time pre-computation. +""" + +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.nn import Module + +from boltz.distributed.data.feature.featurizer import pack_atom_features +from boltz.distributed.model.layers.cat_and_chunk import shardwise_cat +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.modules.encoders import AtomEncoder, PairwiseConditioning +from boltz.model.modules.diffusion_conditioning import DiffusionConditioning as SerialDiffusionConditioning + + +class DiffusionConditioning(Module): + """DTensor DiffusionConditioning for Context Parallelism (V2 only). + + Pre-computes: + - Pairwise conditioning (z) + - Atom encoder (q, c, p) + - Bias projections for atom encoder, atom decoder, and token transformer layers + """ + + def __init__( + self, + layer: SerialDiffusionConditioning, + device_mesh: DeviceMesh, + ): + """Initialize the DTensor DiffusionConditioning. + + Parameters + ---------- + layer : SerialDiffusionConditioning + The serial DiffusionConditioning module (V2 only). + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + + """ + super().__init__() + assert isinstance( + layer, SerialDiffusionConditioning + ), f"Expected SerialDiffusionConditioning, got {type(layer)}" + self.device_mesh = device_mesh + self.atoms_per_window_queries = layer.atom_encoder.atoms_per_window_queries + + # Wrap child modules + self.pairwise_conditioner = PairwiseConditioning( + layer=layer.pairwise_conditioner, + device_mesh=device_mesh, + ) + self.atom_encoder = AtomEncoder( + layer=layer.atom_encoder, + device_mesh=device_mesh, + ) + + # Bias projection layers: ModuleList of Sequential(LayerNorm, Linear) + self.atom_enc_proj_z = nn.ModuleList() + for serial_seq in layer.atom_enc_proj_z: + self.atom_enc_proj_z.append( + nn.Sequential( + LayerNormParamsReplicated(serial_seq[0], device_mesh), + LinearParamsReplicated(serial_seq[1], device_mesh), + ) + ) + + self.atom_dec_proj_z = nn.ModuleList() + for serial_seq in layer.atom_dec_proj_z: + self.atom_dec_proj_z.append( + nn.Sequential( + LayerNormParamsReplicated(serial_seq[0], device_mesh), + LinearParamsReplicated(serial_seq[1], device_mesh), + ) + ) + + self.token_trans_proj_z = nn.ModuleList() + for serial_seq in layer.token_trans_proj_z: + self.token_trans_proj_z.append( + nn.Sequential( + LayerNormParamsReplicated(serial_seq[0], device_mesh), + LinearParamsReplicated(serial_seq[1], device_mesh), + ) + ) + + def forward( + self, + s_trunk: DTensor, + z_trunk: DTensor, + relative_position_encoding: DTensor, + feats: dict[str, DTensor], + ) -> tuple[DTensor, DTensor, DTensor, DTensor, DTensor]: + """Forward pass of the DTensor DiffusionConditioning. + + Parameters + ---------- + s_trunk : DTensor + Token single representation with shape (B, N, token_s). + Placements: (Shard(0), Shard(1), Replicate()). + z_trunk : DTensor + Token pair representation with shape (B, N, N, token_z). + Placements: (Shard(0), Shard(1), Shard(2)). + relative_position_encoding : DTensor + Relative position encoding with shape (B, N, N, token_z). + Placements: (Shard(0), Shard(1), Shard(2)). + feats : dict[str, DTensor] + Unpacked atom features (with intersperse padding from the CP DTensor + data loader). The module calls ``pack_atom_features`` internally. + + Returns + ------- + tuple[DTensor, DTensor, DTensor, DTensor, DTensor] + q : DTensor with shape (B, N_atoms_packed, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + c : DTensor with shape (B, N_atoms_packed, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + atom_enc_bias : DTensor with shape (B, K, W, H, total_atom_enc_heads). + Placements: (Shard(0), Shard(1), Replicate()). + Window-batched atom pair repr; K is sharded, W/H are local window dims. + atom_dec_bias : DTensor with shape (B, K, W, H, total_atom_dec_heads). + Placements: (Shard(0), Shard(1), Replicate()). + Window-batched atom pair repr; K is sharded, W/H are local window dims. + token_trans_bias : DTensor with shape (B, N, N, total_token_trans_heads). + Placements: (Shard(0), Shard(1), Shard(2)). + Token pair repr; both token dims are sharded. + + """ + # Atom features (feats) are expected in unpacked layout (with intersperse padding + # from the CP DTensor data loader). Each module calls pack_atom_features internally + # to form a self-contained pack/unpack closure, so that no external caller needs to + # pre-pack features. This ensures all modules accept atom features directly as + # produced by the data loader, and future refactoring of the data loading pipeline + # will not require changes to these modules. + _keys_atom_features_packed = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + } + feats_packed = pack_atom_features(feats, _keys_atom_features_packed, self.atoms_per_window_queries) + + # Pairwise conditioning + z = self.pairwise_conditioner(z_trunk, relative_position_encoding) + + # Atom encoder: q, c, p (all in packed layout) + q, c, p = self.atom_encoder(feats=feats_packed, s_trunk=s_trunk, z=z) + + # Atom encoder bias projections: project p (window-batched atom pair) through each layer, concatenate + # p: (B, K, W, H, atom_z) with placements (S(0), S(1), R) — K sharded, W/H local window dims + atom_enc_bias_list = [] + for proj in self.atom_enc_proj_z: + atom_enc_bias_list.append(proj(p)) + atom_enc_bias = shardwise_cat(atom_enc_bias_list, dim=-1) + + # Atom decoder bias projections + atom_dec_bias_list = [] + for proj in self.atom_dec_proj_z: + atom_dec_bias_list.append(proj(p)) + atom_dec_bias = shardwise_cat(atom_dec_bias_list, dim=-1) + + # Token transformer bias projections: project z (token pair) + # z: (B, N, N, token_z) with placements (S(0), S(1), S(2)) + token_trans_bias_list = [] + for proj in self.token_trans_proj_z: + token_trans_bias_list.append(proj(z)) + token_trans_bias = shardwise_cat(token_trans_bias_list, dim=-1) + + return q, c, atom_enc_bias, atom_dec_bias, token_trans_bias diff --git a/src/boltz/distributed/model/modules/encoders.py b/src/boltz/distributed/model/modules/encoders.py new file mode 100644 index 000000000..74cf5433f --- /dev/null +++ b/src/boltz/distributed/model/modules/encoders.py @@ -0,0 +1,1491 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-compatible encoder/decoder modules for Context Parallelism. + +Compatible with both Boltz-1x and Boltz-2 serial modules. Focuses on the +window-batching variant. + +Modules: +- RelativePositionEncoder: pairwise relative position features → linear projection +- AtomAttentionDecoder: token → atom position updates +- AtomAttentionEncoder / _atom_encoder: atom-level encoder with window batching +""" + +import copy +from math import pi + +import torch +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor import zeros as dtensor_zeros +from torch.nn import Module +from torch.nn.functional import one_hot + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.layers.cat_and_chunk import shardwise_cat +from boltz.distributed.model.layers.elementwise_op import ( + ElementwiseOp, + elementwise_op, + scalar_tensor_op, + single_tensor_op, +) +from boltz.distributed.model.layers.flatten_and_unflatten import ( + shardwise_flatten, + shardwise_flatten_sharded, + shardwise_unflatten_sharded, +) +from boltz.distributed.model.layers.gather import distributed_gather +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.outer_gather import distributed_outer_gather +from boltz.distributed.model.layers.redistribute_transpose import redistribute_transpose +from boltz.distributed.model.layers.repeat_interleave import shardwise_repeat_interleave +from boltz.distributed.model.layers.replicate_op import ReplicateOp, replicate_op +from boltz.distributed.model.layers.scatter import distributed_scatter_reduce +from boltz.distributed.model.layers.shardwise_op import ShardwiseOuterOp, shardwise_outer_op, shardwise_sum +from boltz.distributed.model.layers.squeeze import shardwise_unsqueeze +from boltz.distributed.model.layers.transition import Transition +from boltz.distributed.model.layers.utils import convert_single_repr_to_window_batched_key +from boltz.distributed.model.modules.transformers import AtomTransformer +from boltz.distributed.model.modules.utils import validate_window_batching_parameters +from boltz.model.modules.encoders import AtomAttentionDecoder as AtomAttentionDecoderBoltz1 +from boltz.model.modules.encoders import AtomAttentionEncoder as AtomAttentionEncoderBoltz1 +from boltz.model.modules.encoders import FourierEmbedding as FourierEmbeddingBoltz1 +from boltz.model.modules.encoders import PairwiseConditioning as PairwiseConditioningBoltz1 +from boltz.model.modules.encoders import RelativePositionEncoder as SerialRelativePositionEncoderV1 +from boltz.model.modules.encoders import SingleConditioning as SingleConditioningBoltz1 +from boltz.model.modules.encodersv2 import AtomAttentionDecoder as AtomAttentionDecoderBoltz2 +from boltz.model.modules.encodersv2 import AtomAttentionEncoder as AtomAttentionEncoderBoltz2 +from boltz.model.modules.encodersv2 import AtomEncoder as AtomEncoderBoltz2 +from boltz.model.modules.encodersv2 import FourierEmbedding as FourierEmbeddingBoltz2 +from boltz.model.modules.encodersv2 import PairwiseConditioning as PairwiseConditioningBoltz2 +from boltz.model.modules.encodersv2 import RelativePositionEncoder as SerialRelativePositionEncoderV2 +from boltz.model.modules.encodersv2 import SingleConditioning as SingleConditioningBoltz2 + + +class RelativePositionEncoder(Module): + """DTensor RelativePositionEncoder for Boltz-1x and Boltz-2. + + Computes pairwise relative position features from single-representation + features (``asym_id``, ``residue_index``, ``entity_id``, ``token_index``, + ``sym_id``, ``cyclic_period``) and projects them through a linear layer. + + Under context parallelism the single features are sharded along the token + dimension. Pairwise outer comparisons (``feat_i[:, :, None] op feat_j[:, None, :]``) + require the "column" shard from the transposed rank, obtained via + ``redistribute_transpose``. + + All intermediate computation (outer ops, clipping, one-hot encoding) is + non-differentiable and operates on local tensors. Only the final linear + projection (``LinearParamsReplicated``) is differentiable. + + Communication budget (forward): + - 6 ``redistribute_transpose`` calls (one per feature key) for the + column shards. Each is a single P2P send/recv. + - 0 additional collectives (the ``LinearParamsReplicated`` backward + handles gradient all-reduce). + + Supports both Boltz-1x and Boltz-2 serial modules: + - Boltz-1x: ``RelativePositionEncoder(token_z, r_max, s_max)`` + - Boltz-2 adds ``fix_sym_check`` and ``cyclic_pos_enc`` flags. + """ + + # Feature keys that need column-shard transpose + _KEYS_TO_TRANSPOSE = ("asym_id", "entity_id", "residue_index", "token_index", "sym_id", "cyclic_period") + + def __init__( + self, + layer: Module, + device_mesh: DeviceMesh, + transpose_comm: TransposeComm, + ) -> None: + """Initialize the distributed RelativePositionEncoder. + + Parameters + ---------- + layer : Module + Serial RelativePositionEncoder (v1 or v2). + device_mesh : DeviceMesh + The device mesh (subgroups mesh with dp, cp_axis_0, cp_axis_1). + transpose_comm : TransposeComm + Transpose communication for distributed outer operations. + Separate deep copies are created for each feature key. + """ + if not isinstance(layer, (SerialRelativePositionEncoderV1, SerialRelativePositionEncoderV2)): + raise TypeError(f"Expected SerialRelativePositionEncoderV1 or V2, got {type(layer)}") + super().__init__() + self.device_mesh = device_mesh + self.r_max = layer.r_max + self.s_max = layer.s_max + + # V2-only flags (default to V1 behaviour). + # V1 has no cyclic_pos_enc attr but always applies the cyclic correction, + # so default to True to unify V1 and V2-with-flag-on into the same branch. + self.fix_sym_check = getattr(layer, "fix_sym_check", False) + self.cyclic_pos_enc = getattr(layer, "cyclic_pos_enc", True) + + # Wrap the linear layer + self.linear_layer = LinearParamsReplicated(layer.linear_layer, device_mesh=device_mesh) + + # One TransposeComm per feature key (separate P2P buffers). + # TransposeComm is not an nn.Module, so store as plain dict + individual attrs. + self._transpose_comms: dict[str, TransposeComm] = {} + for i, key in enumerate(self._KEYS_TO_TRANSPOSE): + tc = transpose_comm if i == 0 else copy.deepcopy(transpose_comm) + self._transpose_comms[key] = tc + setattr(self, f"_tc_{key}", tc) # for pickling visibility + + def forward(self, feats: dict[str, DTensor]) -> DTensor: + """Compute relative position embeddings. + + Parameters + ---------- + feats : dict[str, DTensor] + Must contain keys: ``asym_id``, ``entity_id``, ``residue_index``, + ``token_index``, ``sym_id``, ``cyclic_period``. + Each has shape ``(B, N)`` with placements + ``(Shard(0), Shard(1), Replicate())``. + + Returns + ------- + DTensor + Relative position embeddings, shape ``(B, N, N, token_z)``, + placements ``(Shard(0), Shard(1), Shard(2))``. + """ + expected_placements = (Shard(0), Shard(1)) + for key in self._KEYS_TO_TRANSPOSE: + dt = feats[key] + if not isinstance(dt, DTensor): + raise TypeError(f"Expected DTensor for '{key}', got {type(dt)}") + # Check first two placements (third may be Replicate for 3D mesh) + if dt.placements[:2] != expected_placements: + raise ValueError( + f"Expected '{key}' placements to start with {expected_placements}, got {dt.placements}" + ) + + # Get column shards via redistribute_transpose. + # Row feats: (B, N) placements (Shard(0), Shard(1), Replicate()) + # Column feats: swap cp_axis_0 ↔ cp_axis_1 → (Shard(0), Replicate(), Shard(1)) + # so each rank gets the transpose peer's token shard for outer ops. + col_placements = (Shard(0), Replicate(), Shard(1)) + + feats_col = {} + for key in self._KEYS_TO_TRANSPOSE: + feats_col[key] = redistribute_transpose( + feats[key], + transpose_comm=self._transpose_comms[key], + output_placements=col_placements, + dim0=None, + dim1=None, + ) + + # Extract local tensors for non-differentiable feature computation + row = {k: feats[k].to_local() for k in self._KEYS_TO_TRANSPOSE} + col = {k: feats_col[k].to_local() for k in self._KEYS_TO_TRANSPOSE} + + # Pairwise comparisons: row[:, :, None] op col[:, None, :] + b_same_chain = torch.eq(row["asym_id"][:, :, None], col["asym_id"][:, None, :]) + b_same_residue = torch.eq(row["residue_index"][:, :, None], col["residue_index"][:, None, :]) + b_same_entity = torch.eq(row["entity_id"][:, :, None], col["entity_id"][:, None, :]) + + d_residue = row["residue_index"][:, :, None] - col["residue_index"][:, None, :] + + # Cyclic period adjustment. + # The serial code guards with torch.any(feats["cyclic_period"] > 0) + # over the full tensor, but in CP each rank only sees its local row + # and column shards. Rather than broadcasting a flag across ranks, + # we unconditionally apply the correction — the torch.where with + # fallback period=10000 makes it a no-op when no token has a + # positive cyclic period (round(d/10000) == 0 for typical d). + if self.cyclic_pos_enc: + period = torch.where( + col["cyclic_period"] > 0, + col["cyclic_period"], + torch.zeros_like(col["cyclic_period"]) + 10000, + ) + d_residue = (d_residue - period[:, None, :] * torch.round(d_residue / period[:, None, :])).long() + # cyclic_pos_enc=False (V2 only): skip cyclic correction entirely + + d_residue = torch.clip(d_residue + self.r_max, 0, 2 * self.r_max) + d_residue = torch.where(b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1) + a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2) + + d_token = torch.clip( + row["token_index"][:, :, None] - col["token_index"][:, None, :] + self.r_max, + 0, + 2 * self.r_max, + ) + d_token = torch.where( + b_same_chain & b_same_residue, + d_token, + torch.zeros_like(d_token) + 2 * self.r_max + 1, + ) + a_rel_token = one_hot(d_token, 2 * self.r_max + 2) + + d_chain = torch.clip( + row["sym_id"][:, :, None] - col["sym_id"][:, None, :] + self.s_max, + 0, + 2 * self.s_max, + ) + if self.fix_sym_check: + # V2 path: sentinel when NOT same entity + d_chain = torch.where( + ~b_same_entity, + torch.zeros_like(d_chain) + 2 * self.s_max + 1, + d_chain, + ) + else: + # V1 path: sentinel when same chain + d_chain = torch.where( + b_same_chain, + torch.zeros_like(d_chain) + 2 * self.s_max + 1, + d_chain, + ) + a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2) + + # Concatenate and cast to linear weight dtype + dtype = self.linear_layer.weight.to_local().dtype + features_local = torch.cat( + [ + a_rel_pos.to(dtype), + a_rel_token.to(dtype), + b_same_entity.unsqueeze(-1).to(dtype), + a_rel_chain.to(dtype), + ], + dim=-1, + ) + + # Wrap as DTensor with pair placements for the linear layer + pair_placements = (Shard(0), Shard(1), Shard(2)) + # Pad placements to match mesh ndim (e.g., 3D mesh: dp, cp0, cp1) + if self.device_mesh.ndim > 3: + raise ValueError(f"Expected device mesh ndim <= 3, got {self.device_mesh.ndim}") + full_pair_placements = pair_placements[: self.device_mesh.ndim] + + # Compute global shape and contiguous strides for DTensor.from_local + asym_dt = feats["asym_id"] + B_global = asym_dt.shape[0] + N_global = asym_dt.shape[1] + feat_dim = features_local.shape[-1] + global_shape = (B_global, N_global, N_global, feat_dim) + global_stride = ( + N_global * N_global * feat_dim, + N_global * feat_dim, + feat_dim, + 1, + ) + + features_dt = DTensor.from_local( + features_local.contiguous(), + self.device_mesh, + full_pair_placements, + shape=global_shape, + stride=global_stride, + ) + + return self.linear_layer(features_dt) + + +class AtomAttentionDecoder(Module): + """DTensor AtomAttentionDecoder for window batching. + + Compatible with both Boltz-1x and Boltz-2 serial AtomAttentionDecoder modules. + + This module converts token representations to atom-level position updates + using an AtomTransformer. The window batching variant uses distributed_gather + with global atom-to-token indices instead of torch.bmm with one-hot matrices. + + Key operations: + 1. Transform token repr to atom space via a_to_q_trans (Linear) + 2. Gather from token to atom positions via distributed_gather + 3. Run AtomTransformer (window-batched) + 4. Project atom features to position updates via atom_feat_to_atom_pos_update + """ + + def __init__( + self, + layer: nn.Module, + device_mesh: DeviceMesh, + ): + """Initialize the DTensor-distributed atom attention decoder. + + Parameters + ---------- + layer : nn.Module + The serial AtomAttentionDecoder module to be distributed. + Accepts both Boltz-1x and Boltz-2 versions. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + + Raises + ------ + TypeError + If layer is not a recognized type. + """ + super().__init__() + + if not isinstance(layer, (AtomAttentionDecoderBoltz1, AtomAttentionDecoderBoltz2)): + raise TypeError( + ", ".join( + [ + f"Instance {layer} should have type " + f"{AtomAttentionDecoderBoltz1} or {AtomAttentionDecoderBoltz2}", + f"but instead has type {type(layer)}.", + ] + ) + ) + + self.could_use_model_cache = isinstance(layer, AtomAttentionDecoderBoltz1) + self.attn_window_queries = layer.atom_decoder.attn_window_queries + self.attn_window_keys = layer.atom_decoder.attn_window_keys + validate_window_batching_parameters(self.attn_window_queries, self.attn_window_keys, use_window_batching=True) + + # a_to_q_trans: LinearNoBias(2 * token_s, atom_s) + self.a_to_q_trans = LinearParamsReplicated(layer_local=layer.a_to_q_trans, device_mesh=device_mesh) + + # atom_decoder: DTensor AtomTransformer (window batching) + self.atom_decoder = AtomTransformer(layer=layer.atom_decoder, device_mesh=device_mesh) + + # atom_feat_to_atom_pos_update: + # Boltz-1: always Sequential(LayerNorm, LinearNoBias) -- no post_layer_norm support + # Boltz-2 with transformer_post_layer_norm=False: Sequential(LayerNorm, LinearNoBias) + # Boltz-2 with transformer_post_layer_norm=True: just LinearNoBias (no LayerNorm) + # + # Infer transformer_post_layer_norm from the inner DiffusionTransformerLayer's + # post_lnorm attribute: nn.LayerNorm means True, nn.Identity means False. + # Boltz-1 DiffusionTransformerLayer has no post_lnorm attribute (always False). + first_dtl = layer.atom_decoder.diffusion_transformer.layers[0] + transformer_post_layer_norm = hasattr(first_dtl, "post_lnorm") and not isinstance( + first_dtl.post_lnorm, nn.Identity + ) + + if transformer_post_layer_norm: + # Boltz-2 with transformer_post_layer_norm=True: just LinearNoBias + self.atom_feat_to_atom_pos_update = LinearParamsReplicated( + layer_local=layer.atom_feat_to_atom_pos_update, device_mesh=device_mesh + ) + else: + # Boltz-1, or Boltz-2 with transformer_post_layer_norm=False: + # Sequential(LayerNorm, LinearNoBias) + self.atom_feat_to_atom_pos_update = nn.Sequential( + LayerNormParamsReplicated(layer.atom_feat_to_atom_pos_update[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=layer.atom_feat_to_atom_pos_update[1], device_mesh=device_mesh), + ) + + def forward( + self, + a: DTensor, + q: DTensor, + c: DTensor, + p: DTensor, + feats: dict[str, DTensor], + multiplicity: int = 1, + model_cache: dict[str, dict[str, DTensor]] | None = None, + ) -> DTensor: + """Forward pass for the DTensor-distributed atom attention decoder. + + All tensors use device mesh (dp, cp_axis_0, cp_axis_1). + Placements: Shard(0)=dp batch, Shard(1)=cp atom/token axis, Replicate()=cp_axis_1. + + Parameters + ---------- + a : DTensor + Token representation, shape (B * M, N_tokens, 2 * token_s). + Placements: (Shard(0), Shard(1), Replicate()). + q : DTensor + Atom query representation, shape (B * M, N_atoms_packed, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + c : DTensor + Atom conditioning representation, shape (B * M, N_atoms_packed, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + p : DTensor + Pair representation / pre-computed bias in window-batched format. + - Boltz-1: shape (B, K, W, H, c_z) + - Boltz-2: shape (B, K, W, H, num_heads * depth) + Placements: (Shard(0), Shard(1), Replicate()). + feats : dict[str, DTensor] + Features dict containing: + - "atom_pad_mask": (B, N_atoms_packed), placements per atom_features config + - "atom_to_token_ids_global": (B, N_atoms_packed), placements per atom_features config + multiplicity : int, optional + Number of diffusion samples, by default 1. + model_cache : dict or None, optional + Model cache for inference optimization (V1 internalized path only). + + Returns + ------- + DTensor + Position updates, shape (B * M, N_atoms_packed, 3). + Placements: (Shard(0), Shard(1), Replicate()). + """ + if model_cache is not None and not self.could_use_model_cache: + raise ValueError("model_cache is only supported with V1 AtomAttentionDecoder") + + W = self.attn_window_queries + N = q.shape[1] # N_atoms_packed + + if N % W != 0: + raise ValueError( + f"Packed atom sequence length N={N} must be divisible by window size W={W} " + f"for window batching, but N % W = {N % W}" + ) + + K = N // W + + # Get atom mask (without multiplicity -- matches p's batch dim) + atom_mask = feats["atom_pad_mask"].bool() + + # Get global atom-to-token indices + atom_to_token_ids_global = feats["atom_to_token_ids_global"] + + # Apply multiplicity to indices and mask for gather (must match a's batch dim B*M) + atom_mask_mul = shardwise_repeat_interleave(atom_mask, multiplicity, 0) + atom_to_token_ids_global = shardwise_repeat_interleave(atom_to_token_ids_global, multiplicity, 0) + + # Unflatten to window view for gather: (B*M, N) -> (B*M, K, W) + atom_mask_q = shardwise_unflatten_sharded(atom_mask_mul, axis=1, sizes=(K, W)) + atom_to_token_ids_global_q = shardwise_unflatten_sharded(atom_to_token_ids_global, axis=1, sizes=(K, W)) + + # a_to_q transform and gather with autocast disabled (matching Boltz-2 serial behavior + # at encodersv2.py:544 which uses torch.autocast("cuda", enabled=False) for numerical + # precision in the atom-to-token gather operation) + with torch.autocast("cuda", enabled=False): + # (B*M, N_tokens, 2*token_s) -> (B*M, N_tokens, atom_s) + a_to_q = self.a_to_q_trans(a) + + # Gather from token to atom: (B*M, N_tokens, atom_s) -> (B*M, K, W, atom_s) + # Equivalent to torch.bmm(atom_to_token, a_to_q) in serial code + a_to_q = distributed_gather( + a_to_q, atom_to_token_ids_global_q, axis=1, are_ids_contiguous=True, idx_mask=atom_mask_q + ) + + # Flatten back to atom sequence: (B*M, K, W, atom_s) -> (B*M, N, atom_s) + a_to_q = shardwise_flatten_sharded(a_to_q, start_dim=1, end_dim=2) + + # Add to q + q = elementwise_op(q, a_to_q, ElementwiseOp.SUM) + + # V1 model_cache for decoder transformer + layer_cache = None + if model_cache is not None: + cache_prefix = "atomdecoder" + if cache_prefix not in model_cache: + model_cache[cache_prefix] = {} + layer_cache = model_cache[cache_prefix] + + # Call AtomTransformer with window batching + # multiplicity=1 because multiplicity is already applied to q, c + # mask should NOT have multiplicity applied (to match p's batch dim) + q = self.atom_decoder( + q=q, + c=c, + p=p, + mask=atom_mask, # NO multiplicity - matches p's batch dim + multiplicity=1, # multiplicity already applied to q, c + model_cache=layer_cache, + pair_mask=None, # window batching doesn't use pair_mask + ) + + r_update = self.atom_feat_to_atom_pos_update(q) + return r_update + + +class _MaskPaddingAtoms(torch.autograd.Function): + """Zero out atom representations at padding positions. + + Intersperse padding creates atoms with zero features; ``embed_atom_features`` + maps them to the layer bias, leaking non-zero values into pair features + (``c_to_p_trans_q/k``) and downstream attention. This function multiplies + ``c`` (or any atom-level DTensor with ``(Shard(0), Shard(1), Replicate())``) + by the binary ``atom_pad_mask``, broadcasting along the feature dim. + + The mask is non-differentiable — backward simply re-applies the mask to + the upstream gradient. No collectives are issued. + + Parameters + ---------- + c : DTensor + Atom representations, shape ``(B, N_atoms, D)``. + atom_pad_mask : DTensor + Binary mask, shape ``(B, N_atoms)``. 1 = real atom, 0 = padding. + """ + + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, c: DTensor, atom_pad_mask: DTensor) -> DTensor: + if not isinstance(c, DTensor): + raise TypeError(f"c must be a DTensor, got {type(c)}") + if not isinstance(atom_pad_mask, DTensor): + raise TypeError(f"atom_pad_mask must be a DTensor, got {type(atom_pad_mask)}") + if c.device_mesh != atom_pad_mask.device_mesh: + raise ValueError("c and atom_pad_mask must share the same device_mesh") + + c_local = c.to_local() + mask_local = atom_pad_mask.to_local().to(c_local.dtype).unsqueeze(-1) + result_local = c_local * mask_local + + ctx.save_for_backward(mask_local) + ctx.placements = list(c.placements) + ctx.device_mesh = c.device_mesh + ctx.shape_c = c.shape + ctx.stride_c = c.stride() + + return DTensor.from_local( + result_local, + c.device_mesh, + list(c.placements), + shape=c.shape, + stride=c.stride(), + ) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_output: DTensor): + (mask_local,) = ctx.saved_tensors + grad_local = grad_output.to_local() * mask_local + grad_c = DTensor.from_local( + grad_local, + ctx.device_mesh, + ctx.placements, + shape=ctx.shape_c, + stride=ctx.stride_c, + ) + return grad_c, None + + +def _mask_padding_atoms(c: DTensor, atom_pad_mask: DTensor) -> DTensor: + """Zero out representations at padding atom positions. See :class:`_MaskPaddingAtoms`.""" + return _MaskPaddingAtoms.apply(c, atom_pad_mask) + + +def _atom_encoder( + c: DTensor, + embed_atompair_ref_pos: nn.Module, + embed_atompair_ref_dist: nn.Module, + embed_atompair_mask: nn.Module, + s_to_c_trans: nn.Module | None, + z_to_p_trans: nn.Module | None, + c_to_p_trans_q: nn.Module, + c_to_p_trans_k: nn.Module, + p_mlp: nn.Module, + feats: dict[str, DTensor], + s_trunk: DTensor | None, + z: DTensor | None, + structure_prediction: bool, + W: int, + H: int, +) -> tuple[DTensor, DTensor, DTensor]: + """Shared DTensor pair computation for atom encoding (window batching). + + Migrated from V1x DTensor src_v1/boltz/distributed/model/modules/encoders.py + lines 776-893. Takes pre-embedded c and performs window-batched pair computation. + + This function does NOT manage autocast context. The caller is responsible for + wrapping the call in torch.autocast("cuda", enabled=False) when used from the + V2 AtomEncoder path (matching V2 serial encodersv2.py:297). The V1 path + (AtomAttentionEncoder) should NOT disable autocast, matching V1 serial behavior. + + Parameters + ---------- + c : DTensor + Pre-embedded atom single representation, shape (B, N_atoms, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + embed_atompair_ref_pos : nn.Module + Linear for pair position embedding. + embed_atompair_ref_dist : nn.Module + Linear for pair distance embedding. + embed_atompair_mask : nn.Module + Linear for pair mask embedding. + s_to_c_trans : nn.Module or None + Sequential(LayerNorm, Linear) for token-to-atom conditioning. None if not structure_prediction. + z_to_p_trans : nn.Module or None + Sequential(LayerNorm, Linear) for pair token-to-atom conditioning. None if not structure_prediction. + c_to_p_trans_q : nn.Module + Sequential(ReLU, Linear) for query pair contribution. + c_to_p_trans_k : nn.Module + Sequential(ReLU, Linear) for key pair contribution. + p_mlp : nn.Module + Sequential MLP for pair representation refinement. + feats : dict[str, DTensor] + Must contain: atom_pad_mask, ref_pos, ref_space_uid, atom_to_token_ids_global. + s_trunk : DTensor or None + Token single representation (B, N_tokens, token_s). + Placements: (Shard(0), Shard(1), Replicate()). None if not structure_prediction. + z : DTensor or None + Token pair representation (B, N_tokens, N_tokens, token_z). + Placements: (Shard(0), Shard(1), Shard(2)). None if not structure_prediction. + structure_prediction : bool + Whether to apply token-to-atom conditioning. + W : int + Atoms per window for queries. + H : int + Atoms per window for keys. + + Returns + ------- + tuple[DTensor, DTensor, DTensor] + (q, c, p): + - q: initial c before conditioning, shape (B, N_atoms, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + - c: s_to_c-conditioned atom representation, shape (B, N_atoms, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + - p: atom pair representation, shape (B, K, W, H, atom_z) where K = N_atoms // W. + Placements: (Shard(0), Shard(1), Replicate()). + """ + # Sanity checks: structure_prediction-dependent arguments must be consistently None or non-None. + # When structure_prediction=True, s_to_c_trans, z_to_p_trans, s_trunk, z are all required. + # When structure_prediction=False, they must all be None. + if structure_prediction: + if s_to_c_trans is None or z_to_p_trans is None: + raise ValueError("structure_prediction=True requires s_to_c_trans and z_to_p_trans to be provided.") + if s_trunk is None or z is None: + raise ValueError("structure_prediction=True requires s_trunk and z to be provided.") + else: + if s_to_c_trans is not None or z_to_p_trans is not None: + raise ValueError("structure_prediction=False but s_to_c_trans or z_to_p_trans was provided.") + if s_trunk is not None or z is not None: + raise ValueError("structure_prediction=False but s_trunk or z was provided.") + + N = c.shape[1] + if N % W != 0: + raise ValueError( + f"Sequence length N={N} must be divisible by window size W={W} for window batching, but N % W = {N % W}" + ) + K = N // W + + # Mimic serial code's .float() casts: promote to at least float32 for + # numerical stability, but preserve higher precision (e.g. float64) if available. + compute_dtype = torch.promote_types(c.dtype, torch.float32) + + atom_ref_pos = feats["ref_pos"] + atom_mask_bool = feats["atom_pad_mask"].bool() + atom_uid = feats["ref_space_uid"] + + # Convert atom_ref_pos to query/key views + atom_ref_pos_q = shardwise_unflatten_sharded(atom_ref_pos, axis=1, sizes=(K, W)) + atom_ref_pos_k = convert_single_repr_to_window_batched_key(atom_ref_pos, W, H) + + # Compute distance: d = keys - queries + # shardwise_outer_op computes lhs - rhs = queries - keys, so negate + d = shardwise_outer_op(atom_ref_pos_q, atom_ref_pos_k, axis=2, op=ShardwiseOuterOp.SUBTRACT) + d = scalar_tensor_op(-1.0, d, ElementwiseOp.PROD) + d_norm = shardwise_sum(elementwise_op(d, d, ElementwiseOp.PROD), dim=-1, keepdim=True) + d_norm = scalar_tensor_op(1.0, scalar_tensor_op(1.0, d_norm, ElementwiseOp.SUM), ElementwiseOp.DIV) + + # Compute validity mask + atom_mask_q = shardwise_unflatten_sharded(atom_mask_bool, axis=1, sizes=(K, W)) + atom_mask_k = convert_single_repr_to_window_batched_key(atom_mask_bool, W, H) + atom_uid_q = shardwise_unflatten_sharded(atom_uid, axis=1, sizes=(K, W)) + atom_uid_k = convert_single_repr_to_window_batched_key(atom_uid, W, H) + + # v = (atom_mask_q & atom_mask_k & (atom_uid_q == atom_uid_k)) + mask_and = shardwise_outer_op(atom_mask_q, atom_mask_k, axis=2, op=ShardwiseOuterOp.LOGICAL_AND) + uid_eq = shardwise_outer_op(atom_uid_q, atom_uid_k, axis=2, op=ShardwiseOuterOp.EQUAL) + v = elementwise_op(mask_and, uid_eq, ElementwiseOp.BITAND) + # Serial: (...).float().unsqueeze(-1) -- use compute_dtype instead of .float() + v = shardwise_unsqueeze(v, -1).to(compute_dtype) + + # Compute pair representation p: (B, K, W, H, atom_z) + # TODO: this DTensor native broadcasting elementwise multiplication should be + # replaced with custom autograd.Function + p = embed_atompair_ref_pos(d) * v + p = elementwise_op(p, embed_atompair_ref_dist(d_norm) * v, ElementwiseOp.SUM) + p = elementwise_op(p, embed_atompair_mask(v) * v, ElementwiseOp.SUM) + + q = c + + if structure_prediction: + atom_to_token_ids_global = feats["atom_to_token_ids_global"] + atom_to_token_ids_global_q = shardwise_unflatten_sharded(atom_to_token_ids_global, axis=1, sizes=(K, W)) + + # Token-to-atom gather: s_trunk -> s_to_c -> gather to atom positions + # Serial: s_to_c_trans(s_trunk.float()) -- use compute_dtype + s_to_c = s_to_c_trans(s_trunk.to(compute_dtype)) + s_to_c = distributed_gather( + s_to_c, atom_to_token_ids_global_q, axis=1, are_ids_contiguous=True, idx_mask=atom_mask_q + ) + s_to_c = shardwise_flatten_sharded(s_to_c, start_dim=1, end_dim=2) + # Serial: c = c + s_to_c.to(c) -- cast back to c's dtype + c = elementwise_op(c, s_to_c.to(c.dtype), ElementwiseOp.SUM) + + # Pair token-to-atom gather: z -> z_to_p -> outer gather + # Serial: z_to_p_trans(z.float()) -- use compute_dtype + z_to_p = z_to_p_trans(z.to(compute_dtype)) + atom_to_token_ids_global_k = convert_single_repr_to_window_batched_key(atom_to_token_ids_global, W, H) + z_to_p = distributed_outer_gather( + z_to_p, + atom_to_token_ids_global_q, + atom_to_token_ids_global_k, + axis=1, + are_ids_contiguous=True, + idx_n_mask=atom_mask_q, + idx_m_mask=atom_mask_k, + ) + # Serial: p = p + z_to_p.to(p) -- cast back to p's dtype + p = elementwise_op(p, z_to_p.to(p.dtype), ElementwiseOp.SUM) + + # c_to_p contributions in window-batched form + c_q = shardwise_unflatten_sharded(c, axis=1, sizes=(K, W)) + c_q = c_to_p_trans_q(c_q) + c_k = convert_single_repr_to_window_batched_key(c, W, H) + c_k = c_to_p_trans_k(c_k) + c_qk = shardwise_outer_op(c_q, c_k, axis=2, op=ShardwiseOuterOp.ADD) + + p = elementwise_op(p, c_qk, ElementwiseOp.SUM) + p = elementwise_op(p, p_mlp(p), ElementwiseOp.SUM) + + return q, c, p + + +class AtomEncoder(Module): + """DTensor AtomEncoder for V2 (window batching). + + Wraps V2 serial AtomEncoder. Flat attributes match V2 serial checkpoint keys. + Performs V2-specific feature construction then calls _atom_encoder() for + shared pair computation. + """ + + def __init__( + self, + layer: AtomEncoderBoltz2, + device_mesh: DeviceMesh, + ): + """Initialize the DTensor-distributed atom encoder. + + Parameters + ---------- + layer : AtomEncoderBoltz2 + The serial AtomEncoder module (V2) to be distributed. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + """ + super().__init__() + + if not isinstance(layer, AtomEncoderBoltz2): + raise TypeError( + ", ".join( + [ + f"Instance {layer} should have type {AtomEncoderBoltz2}", + f"but instead has type {type(layer)}.", + ] + ) + ) + + # V2-specific feature construction config. + # use_residue_feats_atoms and use_atom_backbone_feat default to False and are never + # set to True in the Boltz-2 training/inference workflow (structurev2.yaml, boltz2.py). + # Fail early at init time rather than at forward time. + if layer.use_residue_feats_atoms: + raise NotImplementedError( + "DTensor AtomEncoder does not yet support use_residue_feats_atoms=True. " + "This requires DTensor-ified one_hot and token-to-atom gather for residue features. " + "No Boltz-2 workflow currently enables this option." + ) + if layer.use_atom_backbone_feat: + raise NotImplementedError( + "DTensor AtomEncoder does not yet support use_atom_backbone_feat=True. " + "No Boltz-2 workflow currently enables this option." + ) + self.use_no_atom_char = layer.use_no_atom_char + self.use_atom_backbone_feat = layer.use_atom_backbone_feat + self.use_residue_feats_atoms = layer.use_residue_feats_atoms + self.structure_prediction = layer.structure_prediction + self.atoms_per_window_queries = layer.atoms_per_window_queries + self.atoms_per_window_keys = layer.atoms_per_window_keys + validate_window_batching_parameters( + self.atoms_per_window_queries, self.atoms_per_window_keys, use_window_batching=True + ) + + # Atom feature embedding (V2 uses Linear with bias) + self.embed_atom_features = LinearParamsReplicated( + layer_local=layer.embed_atom_features, device_mesh=device_mesh + ) + + # Pair computation modules + self.embed_atompair_ref_pos = LinearParamsReplicated( + layer_local=layer.embed_atompair_ref_pos, device_mesh=device_mesh + ) + self.embed_atompair_ref_dist = LinearParamsReplicated( + layer_local=layer.embed_atompair_ref_dist, device_mesh=device_mesh + ) + self.embed_atompair_mask = LinearParamsReplicated( + layer_local=layer.embed_atompair_mask, device_mesh=device_mesh + ) + + if self.structure_prediction: + self.s_to_c_trans = nn.Sequential( + LayerNormParamsReplicated(layer.s_to_c_trans[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=layer.s_to_c_trans[1], device_mesh=device_mesh), + ) + self.z_to_p_trans = nn.Sequential( + LayerNormParamsReplicated(layer.z_to_p_trans[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=layer.z_to_p_trans[1], device_mesh=device_mesh), + ) + + self.c_to_p_trans_q = nn.Sequential( + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.c_to_p_trans_q[1], device_mesh=device_mesh), + ) + self.c_to_p_trans_k = nn.Sequential( + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.c_to_p_trans_k[1], device_mesh=device_mesh), + ) + self.p_mlp = nn.Sequential( + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.p_mlp[1], device_mesh=device_mesh), + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.p_mlp[3], device_mesh=device_mesh), + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.p_mlp[5], device_mesh=device_mesh), + ) + + def forward( + self, + feats: dict[str, DTensor], + s_trunk: DTensor | None = None, + z: DTensor | None = None, + ) -> tuple[DTensor, DTensor, DTensor]: + """Forward pass for the DTensor-distributed atom encoder. + + Parameters + ---------- + feats : dict[str, DTensor] + Atom features including ref_pos, ref_charge, ref_element, atom_pad_mask, + ref_space_uid, atom_to_token_ids_global, and optionally ref_atom_name_chars. + Atom-level tensors have placements (Shard(0), Shard(1), Replicate()). + s_trunk : DTensor or None + Token single representation (B, N_tokens, token_s). + Placements: (Shard(0), Shard(1), Replicate()). None if not structure_prediction. + z : DTensor or None + Token pair representation (B, N_tokens, N_tokens, token_z). + Placements: (Shard(0), Shard(1), Shard(2)). None if not structure_prediction. + + Returns + ------- + tuple[DTensor, DTensor, DTensor] + (q, c, p): + - q: initial atom representation, shape (B, N_atoms, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + - c: conditioned atom representation, shape (B, N_atoms, atom_s). + Placements: (Shard(0), Shard(1), Replicate()). + - p: atom pair representation, shape (B, K, W, H, atom_z) where K = N_atoms // W. + Placements: (Shard(0), Shard(1), Replicate()). + """ + # V2 feature construction and pair computation run with autocast disabled, + # matching V2 serial AtomEncoder.forward() (encodersv2.py) which wraps + # the entire forward in torch.autocast("cuda", enabled=False). + with torch.autocast("cuda", enabled=False): + atom_feats_list = [ + feats["ref_pos"], + shardwise_unsqueeze(feats["ref_charge"], -1), + feats["ref_element"], + ] + if not self.use_no_atom_char: + # ref_atom_name_chars: (B, N, 4, 64) -> (B, N, 256) + # Dims 2,3 are replicated (not sharded), use shardwise_flatten for non-sharded dims + atom_feats_list.append(shardwise_flatten(feats["ref_atom_name_chars"], start_dim=2, end_dim=3)) + + atom_feats = shardwise_cat(atom_feats_list, dim=-1) + c = self.embed_atom_features(atom_feats) + c = _mask_padding_atoms(c, feats["atom_pad_mask"]) + + q, c, p = _atom_encoder( + c=c, + embed_atompair_ref_pos=self.embed_atompair_ref_pos, + embed_atompair_ref_dist=self.embed_atompair_ref_dist, + embed_atompair_mask=self.embed_atompair_mask, + s_to_c_trans=self.s_to_c_trans if self.structure_prediction else None, + z_to_p_trans=self.z_to_p_trans if self.structure_prediction else None, + c_to_p_trans_q=self.c_to_p_trans_q, + c_to_p_trans_k=self.c_to_p_trans_k, + p_mlp=self.p_mlp, + feats=feats, + s_trunk=s_trunk, + z=z, + structure_prediction=self.structure_prediction, + W=self.atoms_per_window_queries, + H=self.atoms_per_window_keys, + ) + return q, c, p + + +class AtomAttentionEncoder(Module): + """DTensor AtomAttentionEncoder for window batching. + + Compatible with both Boltz-1x and Boltz-2 serial AtomAttentionEncoder modules. + + V1 (internalized_AtomEncoder=True): + The V1 serial AtomAttentionEncoder is monolithic — it contains the AtomEncoder + logic (feature embedding, pair computation) within itself. This DTensor class + mirrors that by holding all embed/pair modules and calling _atom_encoder() in + forward to compute q/c/p from raw features. + + V2 (internalized_AtomEncoder=False): + The V2 architecture splits atom encoding into a separate AtomEncoder class. + This DTensor class receives pre-computed q/c/bias from the upstream DTensor + AtomEncoder and only holds r_to_q, AtomTransformer, and atom_to_token_trans. + + Common operations (both V1 and V2): + 1. Apply multiplicity via shardwise_repeat_interleave + 2. Apply r_to_q conditioning (if structure_prediction) + 3. Run AtomTransformer (window-batched) + 4. Atom-to-token scatter via distributed_scatter_reduce with autocast disabled + """ + + def __init__( + self, + layer: AtomAttentionEncoderBoltz1 | AtomAttentionEncoderBoltz2, + device_mesh: DeviceMesh, + ): + """Initialize the DTensor-distributed atom attention encoder. + + Parameters + ---------- + layer : AtomAttentionEncoderBoltz1 | AtomAttentionEncoderBoltz2 + The serial AtomAttentionEncoder module to be distributed. + Accepts both Boltz-1x and Boltz-2 versions. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + + Raises + ------ + TypeError + If layer is not a recognized type. + """ + super().__init__() + + if not isinstance(layer, (AtomAttentionEncoderBoltz1, AtomAttentionEncoderBoltz2)): + raise TypeError( + ", ".join( + [ + f"Instance {layer} should have type " + f"{AtomAttentionEncoderBoltz1} or {AtomAttentionEncoderBoltz2}", + f"but instead has type {type(layer)}.", + ] + ) + ) + + # V1's serial AtomAttentionEncoder is monolithic: it internalizes the AtomEncoder + # logic (feature embedding + pair computation) within the same class. In contrast, + # V2 splits this into a separate AtomEncoder class whose outputs (q, c, bias) are + # passed into AtomAttentionEncoder from outside. + # + # When internalized_AtomEncoder is True (V1), this DTensor class holds all the + # embed/pair modules and calls _atom_encoder() in forward to compute q/c/p from + # raw features. When False (V2), q/c/bias are expected as pre-computed inputs + # from the upstream DTensor AtomEncoder. + self.internalized_AtomEncoder = isinstance(layer, AtomAttentionEncoderBoltz1) + self.structure_prediction = layer.structure_prediction + + # V1 only: holds all pair computation modules for _atom_encoder() + # (V2 delegates these to the separate DTensor AtomEncoder class) + if self.internalized_AtomEncoder: + self.atoms_per_window_queries = layer.atoms_per_window_queries + self.atoms_per_window_keys = layer.atoms_per_window_keys + validate_window_batching_parameters( + self.atoms_per_window_queries, self.atoms_per_window_keys, use_window_batching=True + ) + + # V1 atom feature embedding (LinearNoBias) + self.embed_atom_features = LinearParamsReplicated( + layer_local=layer.embed_atom_features, device_mesh=device_mesh + ) + + # Pair computation modules (same structure as DTensor AtomEncoder) + self.embed_atompair_ref_pos = LinearParamsReplicated( + layer_local=layer.embed_atompair_ref_pos, device_mesh=device_mesh + ) + self.embed_atompair_ref_dist = LinearParamsReplicated( + layer_local=layer.embed_atompair_ref_dist, device_mesh=device_mesh + ) + self.embed_atompair_mask = LinearParamsReplicated( + layer_local=layer.embed_atompair_mask, device_mesh=device_mesh + ) + + if self.structure_prediction: + self.s_to_c_trans = nn.Sequential( + LayerNormParamsReplicated(layer.s_to_c_trans[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=layer.s_to_c_trans[1], device_mesh=device_mesh), + ) + self.z_to_p_trans = nn.Sequential( + LayerNormParamsReplicated(layer.z_to_p_trans[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=layer.z_to_p_trans[1], device_mesh=device_mesh), + ) + + self.c_to_p_trans_q = nn.Sequential( + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.c_to_p_trans_q[1], device_mesh=device_mesh), + ) + self.c_to_p_trans_k = nn.Sequential( + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.c_to_p_trans_k[1], device_mesh=device_mesh), + ) + self.p_mlp = nn.Sequential( + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.p_mlp[1], device_mesh=device_mesh), + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.p_mlp[3], device_mesh=device_mesh), + nn.ReLU(), + LinearParamsReplicated(layer_local=layer.p_mlp[5], device_mesh=device_mesh), + ) + + # V2 only: r_to_q_trans (V1 also has it but under structure_prediction) + if self.structure_prediction: + self.r_to_q_trans = LinearParamsReplicated(layer_local=layer.r_to_q_trans, device_mesh=device_mesh) + + # Common: AtomTransformer (window batching) + self.atom_encoder = AtomTransformer(layer=layer.atom_encoder, device_mesh=device_mesh) + + # Common: atom_to_token_trans Sequential(LinearNoBias, ReLU) + # Strip nn.ReLU from serial sequential and wrap the Linear + self.atom_to_token_trans = nn.Sequential( + LinearParamsReplicated(layer_local=layer.atom_to_token_trans[0], device_mesh=device_mesh), + nn.ReLU(), + ) + + def forward( + self, + feats: dict[str, DTensor], + q: DTensor | None = None, + c: DTensor | None = None, + atom_enc_bias: DTensor | None = None, + s_trunk: DTensor | None = None, + z: DTensor | None = None, + r: DTensor | None = None, + multiplicity: int = 1, + model_cache: dict[str, dict[str, DTensor]] | None = None, + ) -> tuple[DTensor, DTensor, DTensor, DTensor]: + """Forward pass for the DTensor-distributed atom attention encoder. + + For V2: q, c, atom_enc_bias must be provided (from upstream DTensor AtomEncoder). + For V1: s_trunk, z are used to compute q, c, p from raw features via _atom_encoder(). + + Parameters + ---------- + feats : dict[str, DTensor] + Atom features. Must contain atom_pad_mask, atom_to_token_ids_global, + atom_to_token_local_onehot. For V1 also: ref_pos, ref_space_uid, + ref_charge, ref_element, ref_atom_name_chars. + q : DTensor or None + Pre-computed atom query representation (V2 only), shape (B, N_atoms, atom_s). + c : DTensor or None + Pre-computed atom conditioning representation (V2 only), shape (B, N_atoms, atom_s). + atom_enc_bias : DTensor or None + Pre-computed pairwise bias (V2 only), shape (B, K, W, H, num_heads * depth). + s_trunk : DTensor or None + Token single representation (V1 only), shape (B, N_tokens, token_s). + z : DTensor or None + Token pair representation (V1 only), shape (B, N_tokens, N_tokens, token_z). + r : DTensor or None + Atom positions for r_to_q conditioning (both V1 and V2), + shape (B * multiplicity, N_atoms, 3) for V2 or (B * multiplicity, N_atoms, 3) for V1. + multiplicity : int, optional + Number of diffusion samples, by default 1. + model_cache : dict or None, optional + Model cache for inference optimization (V1 internalized path only). + + Returns + ------- + tuple[DTensor, DTensor, DTensor, DTensor] + (a, q, c, p) where: + - a: Token representation, shape (B * multiplicity, N_tokens_local, token_s) + - q: Atom query after transformer, shape (B * multiplicity, N_atoms, atom_s) + - c: Atom conditioning, shape (B * multiplicity, N_atoms, atom_s) + - p: Pair representation / bias (B, K, W, H, atom_z or num_heads*depth) + """ + atom_mask = feats["atom_pad_mask"].bool() + + # model_cache is a V1-only feature (internalized path) + if model_cache is not None and not self.internalized_AtomEncoder: + raise ValueError("model_cache is only supported with internalized AtomEncoder (V1 path)") + + # V1 model_cache: cache q/c/p from _atom_encoder on first call + layer_cache = None + if model_cache is not None: + cache_prefix = "atomencoder" + if cache_prefix not in model_cache: + model_cache[cache_prefix] = {} + layer_cache = model_cache[cache_prefix] + + if self.internalized_AtomEncoder: + if model_cache is None or len(layer_cache) == 0: + # First call or no cache: compute q, c, p from raw features + atom_feats = shardwise_cat( + [ + feats["ref_pos"], + shardwise_unsqueeze(feats["ref_charge"], -1), + shardwise_unsqueeze(atom_mask, -1), + feats["ref_element"], + shardwise_flatten(feats["ref_atom_name_chars"], start_dim=2, end_dim=3), + ], + dim=-1, + ) + c = self.embed_atom_features(atom_feats) + c = _mask_padding_atoms(c, atom_mask) + + q, c, p = _atom_encoder( + c=c, + embed_atompair_ref_pos=self.embed_atompair_ref_pos, + embed_atompair_ref_dist=self.embed_atompair_ref_dist, + embed_atompair_mask=self.embed_atompair_mask, + s_to_c_trans=self.s_to_c_trans if self.structure_prediction else None, + z_to_p_trans=self.z_to_p_trans if self.structure_prediction else None, + c_to_p_trans_q=self.c_to_p_trans_q, + c_to_p_trans_k=self.c_to_p_trans_k, + p_mlp=self.p_mlp, + feats=feats, + s_trunk=s_trunk, + z=z, + structure_prediction=self.structure_prediction, + W=self.atoms_per_window_queries, + H=self.atoms_per_window_keys, + ) + + if model_cache is not None: + layer_cache["q"] = q + layer_cache["c"] = c + layer_cache["p"] = p + else: + # Subsequent calls: use cached q/c/p + q = layer_cache["q"] + c = layer_cache["c"] + p = layer_cache["p"] + else: + # V2: q, c, bias provided by upstream DTensor AtomEncoder + if q is None or c is None: + raise ValueError("V2 AtomAttentionEncoder requires pre-computed q and c from upstream AtomEncoder.") + p = atom_enc_bias # may be None if not structure_prediction + + # ==================================================================== + # Common: Apply multiplicity and r_to_q conditioning + # ==================================================================== + if self.structure_prediction: + q = shardwise_repeat_interleave(q, multiplicity, 0) + + if self.internalized_AtomEncoder: + # V1: r_to_q_trans takes (B*M, N, 10) = concat([r, zeros(B*M, N, 7)]) + r_zeros_shape = list(r.shape) + r_zeros_shape[-1] = 7 # replace last dim 3 → 7 + r_zeros_dt = dtensor_zeros( + r_zeros_shape, + device_mesh=r.device_mesh, + placements=r.placements, + dtype=r.dtype, + requires_grad=False, + ) + r_input = shardwise_cat([r, r_zeros_dt], dim=-1) + r_to_q = self.r_to_q_trans(r_input) + else: + # V2: r_to_q_trans takes (B*M, N, 3) directly + r_to_q = self.r_to_q_trans(r) + + q = elementwise_op(q, r_to_q, ElementwiseOp.SUM) + + c = shardwise_repeat_interleave(c, multiplicity, 0) + + # ==================================================================== + # Common: Run AtomTransformer (window batching) + # ==================================================================== + # multiplicity=1 because multiplicity is already applied to q, c + # mask should NOT have multiplicity applied (to match p's batch dim) + q = self.atom_encoder( + q=q, + c=c, + p=p, + mask=atom_mask, + multiplicity=1, # Multiplicity must be 1 as otherwise meaningless for DTensor window batching code + model_cache=None, + pair_mask=None, + ) + + # ==================================================================== + # Common: Atom-to-token scatter aggregation + # ==================================================================== + # Equivalent to serial's: atom_to_token_mean.T @ q_to_a (mean aggregation) + # Uses distributed_scatter_reduce with "mean" reduction instead of bmm. + with torch.autocast("cuda", enabled=False): + q_to_a = self.atom_to_token_trans(q) + + atom_to_token_ids_global = feats["atom_to_token_ids_global"] + atom_mask_bool = atom_mask.bool() + + # Apply multiplicity to scatter indices and mask + atom_to_token_ids_global_mul = shardwise_repeat_interleave(atom_to_token_ids_global, multiplicity, 0) + atom_mask_bool_mul = shardwise_repeat_interleave(atom_mask_bool, multiplicity, 0) + + # n_tokens_per_shard from atom_to_token_local_onehot (see plan notes) + n_tokens_per_shard = feats["atom_to_token_local_onehot"].to_local().shape[2] + + # Both serial and distributed now compute exact mean (sum / count). + # The serial code previously used biased mean: atom_to_token / (count + 1e-6), + # which was fixed to use count.clamp(min=1) for parity. + a = distributed_scatter_reduce( + n_tokens_per_shard, + 1, + atom_to_token_ids_global_mul, + q_to_a, + "mean", + idx_mask=atom_mask_bool_mul, + are_ids_contiguous=True, + ) + + return a, q, c, p + + +class FourierEmbedding(Module): + """DTensor FourierEmbedding for Context Parallelism. + + Wraps a frozen nn.Linear projection with LinearParamsReplicated. + Compatible with both Boltz-1x and Boltz-2 serial FourierEmbedding modules. + """ + + def __init__( + self, + layer: FourierEmbeddingBoltz1 | FourierEmbeddingBoltz2, + device_mesh: DeviceMesh, + ): + """Initialize the Fourier Embedding layer. + + Parameters + ---------- + layer : FourierEmbeddingBoltz1 | FourierEmbeddingBoltz2 + The serial Fourier embedding layer. + device_mesh : DeviceMesh + The device mesh. + + """ + super().__init__() + assert isinstance( + layer, (FourierEmbeddingBoltz1, FourierEmbeddingBoltz2) + ), f"Expected FourierEmbeddingBoltz1 or FourierEmbeddingBoltz2, got {type(layer)}" + self.device_mesh = device_mesh + self.proj = LinearParamsReplicated(layer.proj, self.device_mesh) + if self.proj.weight.requires_grad or self.proj.bias.requires_grad: + raise ValueError("Linear layer in FourierEmbedding should not have trainable parameters") + + def forward(self, times: DTensor) -> DTensor: + """Forward pass of the Fourier embedding layer. + + Parameters + ---------- + times : DTensor + The times tensor with shape (B,). + Placements: (Shard(0), Replicate(), Replicate()). + + Returns + ------- + DTensor + The Fourier embedding tensor with shape (B, D). + Placements: (Shard(0), Replicate(), Replicate()). + + """ + expected_placements = (Shard(0), Replicate(), Replicate()) + if times.placements != expected_placements: + raise ValueError(f"Times tensor has incorrect placements: {times.placements} != {expected_placements}") + + if times.ndim != 1: + raise ValueError(f"Times tensor should have shape (B,) but got {times.shape}") + + times = shardwise_unsqueeze(times, dim=1) + rand_proj = self.proj(times) + return single_tensor_op(2 * pi * rand_proj, ElementwiseOp.COS) + + +class SingleConditioning(Module): + """DTensor SingleConditioning for Context Parallelism. + + Compatible with both Boltz-1x and Boltz-2 serial SingleConditioning modules. + Handles V2's ``disable_times`` flag: when the serial module was constructed with + ``disable_times=True``, the fourier_embed / norm_fourier / fourier_to_single + child modules are absent, and the time-conditioning branch is skipped. + """ + + def __init__( + self, + layer: SingleConditioningBoltz1 | SingleConditioningBoltz2, + device_mesh: DeviceMesh, + ): + """Initialize the single conditioning layer with DTensor API. + + Parameters + ---------- + layer : SingleConditioningBoltz1 | SingleConditioningBoltz2 + The serial single conditioning layer. + device_mesh : DeviceMesh + The device mesh. + + """ + super().__init__() + assert isinstance( + layer, (SingleConditioningBoltz1, SingleConditioningBoltz2) + ), f"Expected SingleConditioningBoltz1 or SingleConditioningBoltz2, got {type(layer)}" + self.device_mesh = device_mesh + + self.norm_single = LayerNormParamsReplicated(layer.norm_single, self.device_mesh) + self.single_embed = LinearParamsReplicated(layer.single_embed, self.device_mesh) + + # V1 always has fourier time embedding; V2 conditionally creates it based on disable_times. + self.disable_times = getattr(layer, "disable_times", False) + if not self.disable_times: + self.fourier_embed = FourierEmbedding(layer.fourier_embed, self.device_mesh) + self.norm_fourier = LayerNormParamsReplicated(layer.norm_fourier, self.device_mesh) + self.fourier_to_single = LinearParamsReplicated(layer.fourier_to_single, self.device_mesh) + + self.transitions = nn.ModuleList([]) + for serial_transition in layer.transitions: + transition = Transition( + layer=serial_transition, + device_mesh=self.device_mesh, + ) + self.transitions.append(transition) + + def forward(self, times: DTensor, s_trunk: DTensor, s_inputs: DTensor) -> tuple[DTensor, DTensor | None]: + """Forward pass of the single conditioning layer. + + Parameters + ---------- + times : DTensor + The times tensor with shape (B,). + Placements: (Shard(0), Replicate(), Replicate()). + s_trunk : DTensor + The trunk single representation tensor with shape (B, N, D). + Placements: (Shard(0), Shard(1), Replicate()). + s_inputs : DTensor + The inputs single representation tensor with shape (B, N, D). + Placements: (Shard(0), Shard(1), Replicate()). + + Returns + ------- + tuple[DTensor, DTensor | None] + s : DTensor with shape (B, N, 2*D) and placements (Shard(0), Shard(1), Replicate()). + normed_fourier : DTensor with shape (B, D_fourier) and placements + (Shard(0), Replicate(), Replicate()), or None if disable_times. + + """ + expected_placements_times = (Shard(0), Replicate(), Replicate()) + if times.placements != expected_placements_times: + raise ValueError( + f"Times tensor has incorrect placements: {times.placements} != {expected_placements_times}" + ) + expected_placements_s = (Shard(0), Shard(1), Replicate()) + if s_trunk.placements != expected_placements_s: + raise ValueError( + f"s_trunk tensor has incorrect placements: {s_trunk.placements} != {expected_placements_s}" + ) + if s_inputs.placements != expected_placements_s: + raise ValueError( + f"s_inputs tensor has incorrect placements: {s_inputs.placements} != {expected_placements_s}" + ) + + s = shardwise_cat([s_trunk, s_inputs], dim=-1) + s = self.single_embed(self.norm_single(s)) + + normed_fourier: DTensor | None = None + if not self.disable_times: + fourier_embed = self.fourier_embed(times) + normed_fourier = self.norm_fourier(fourier_embed) + fourier_to_single = self.fourier_to_single(normed_fourier) + # fourier_to_single: (B, D) with (S(0), R, R) — broadcast-add to s: (B, N, 2D) with (S(0), S(1), R) + s = replicate_op(s, fourier_to_single, dim_to_unsqueeze_rhs=1, op=ReplicateOp.ADD) + + for transition in self.transitions: + s = elementwise_op(transition(s), s, ElementwiseOp.SUM) + + return s, normed_fourier + + +class PairwiseConditioning(Module): + """DTensor PairwiseConditioning for Context Parallelism. + + Compatible with both Boltz-1x and Boltz-2 serial PairwiseConditioning modules. + """ + + def __init__( + self, + layer: PairwiseConditioningBoltz1 | PairwiseConditioningBoltz2, + device_mesh: DeviceMesh, + ): + """Initialize the pairwise conditioning layer with DTensor API. + + Parameters + ---------- + layer : PairwiseConditioningBoltz1 | PairwiseConditioningBoltz2 + The serial pairwise conditioning layer. + device_mesh : DeviceMesh + The device mesh. + + """ + super().__init__() + assert isinstance( + layer, (PairwiseConditioningBoltz1, PairwiseConditioningBoltz2) + ), f"Expected PairwiseConditioningBoltz1 or PairwiseConditioningBoltz2, got {type(layer)}" + self.device_mesh = device_mesh + + self.dim_pairwise_init_proj = nn.Sequential( + LayerNormParamsReplicated(layer.dim_pairwise_init_proj[0], self.device_mesh), + LinearParamsReplicated(layer.dim_pairwise_init_proj[1], self.device_mesh), + ) + + self.transitions = nn.ModuleList([]) + for serial_transition in layer.transitions: + transition = Transition( + layer=serial_transition, + device_mesh=self.device_mesh, + ) + self.transitions.append(transition) + + def forward(self, z_trunk: DTensor, token_rel_pos_feats: DTensor) -> DTensor: + """Forward pass of the pairwise conditioning layer. + + Parameters + ---------- + z_trunk : DTensor + The trunk pair representation tensor with shape (B, N, N, D). + Placements: (Shard(0), Shard(1), Shard(2)). + token_rel_pos_feats : DTensor + The token relative position features tensor with shape (B, N, N, D_rel). + Placements: (Shard(0), Shard(1), Shard(2)). + + Returns + ------- + DTensor + The conditioned pair representation tensor with shape (B, N, N, D). + Placements: (Shard(0), Shard(1), Shard(2)). + + """ + expected_placements = (Shard(0), Shard(1), Shard(2)) + if z_trunk.placements != expected_placements: + raise ValueError(f"z_trunk tensor has incorrect placements: {z_trunk.placements} != {expected_placements}") + if token_rel_pos_feats.placements != expected_placements: + raise ValueError( + f"token_rel_pos_feats tensor has incorrect placements:" + f" {token_rel_pos_feats.placements} != {expected_placements}" + ) + + z = shardwise_cat([z_trunk, token_rel_pos_feats], dim=-1) + z = self.dim_pairwise_init_proj(z) + + for transition in self.transitions: + z = elementwise_op(transition(z), z, ElementwiseOp.SUM) + + return z diff --git a/src/boltz/distributed/model/modules/transformers.py b/src/boltz/distributed/model/modules/transformers.py new file mode 100644 index 000000000..6784cff91 --- /dev/null +++ b/src/boltz/distributed/model/modules/transformers.py @@ -0,0 +1,864 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-compatible transformer modules for Context Parallelism. + +Compatible with both Boltz-1x and Boltz-2 serial modules. Supports both +window-batching (AttentionPairBiasShardwise) and ring attention (AttentionPairBias), +dispatched via the ``ring_comm`` parameter at construction time. +""" + +from functools import partial +from typing import Callable, Union + +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.nn import Module, ModuleList +from torch.utils.checkpoint import checkpoint + +from boltz.distributed.comm import AttentionPairBiasComm +from boltz.distributed.model.layers.attention import AttentionPairBias as AttentionPairBiasRing +from boltz.distributed.model.layers.attention import AttentionPairBiasShardwise +from boltz.distributed.model.layers.cat_and_chunk import shardwise_chunk +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op +from boltz.distributed.model.layers.flatten_and_unflatten import ( + shardwise_flatten_sharded, + shardwise_unflatten_sharded, +) +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.sigmoid_gate import sigmoid_gate +from boltz.distributed.model.layers.swiglu import SwiGLU as SwiGLUWithDTensor +from boltz.distributed.model.layers.utils import convert_single_repr_window_batched_query_to_key +from boltz.distributed.model.modules.utils import ( + extract_checkpointing_config, + get_cpu_offload_context, + validate_window_batching_parameters, +) +from boltz.model.layers.attention import AttentionPairBias as AttentionPairBiasBoltz1 +from boltz.model.modules.transformers import AdaLN as AdaLNBoltz1 +from boltz.model.modules.transformers import AtomTransformer as AtomTransformerBoltz1 +from boltz.model.modules.transformers import ConditionedTransitionBlock as ConditionedTransitionBlockBoltz1 +from boltz.model.modules.transformers import DiffusionTransformer as DiffusionTransformerBoltz1 +from boltz.model.modules.transformers import DiffusionTransformerLayer as DiffusionTransformerLayerBoltz1 +from boltz.model.modules.transformersv2 import AdaLN as AdaLNBoltz2 +from boltz.model.modules.transformersv2 import AtomTransformer as AtomTransformerBoltz2 +from boltz.model.modules.transformersv2 import ConditionedTransitionBlock as ConditionedTransitionBlockBoltz2 +from boltz.model.modules.transformersv2 import DiffusionTransformer as DiffusionTransformerBoltz2 +from boltz.model.modules.transformersv2 import DiffusionTransformerLayer as DiffusionTransformerLayerBoltz2 + + +class AdaLN(Module): + """Adaptive Layer Normalization for DTensor. + + Compatible with both Boltz-1x and Boltz-2 serial AdaLN modules. + + Both versions have identical child modules: a_norm (LayerNorm), + s_norm (LayerNorm), s_scale (Linear), s_bias (Linear, no bias). + """ + + def __init__( + self, + ada_layer_norm: nn.Module, + device_mesh: DeviceMesh, + ): + """Initialize the DTensor-distributed adaptive layer normalization. + + Parameters + ---------- + ada_layer_norm : nn.Module + The serial AdaLN module to be distributed. Accepts both + boltz.model.modules.transformers.AdaLN (Boltz-1x) and + boltz.model.modules.transformersv2.AdaLN (Boltz-2). + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + + Raises + ------ + TypeError + If ada_layer_norm is not an instance of AdaLNBoltz1 or AdaLNBoltz2. + """ + super().__init__() + + # (1) Set non-module, non-parameter attributes + self.device_mesh: DeviceMesh = device_mesh + + # (2) Sanity checks + if not isinstance(ada_layer_norm, (AdaLNBoltz1, AdaLNBoltz2)): + raise TypeError( + ", ".join( + [ + f"Instance {ada_layer_norm} should have type {AdaLNBoltz1} or {AdaLNBoltz2}", + f"but instead has type {type(ada_layer_norm)}.", + ] + ) + ) + if not isinstance(self.device_mesh, DeviceMesh): + raise TypeError( + ", ".join( + [ + f"Instance device_mesh should have type {DeviceMesh}", + f"but instead has type {type(self.device_mesh)}.", + ] + ) + ) + + # (3) Initialize child modules explicitly + self.a_norm = LayerNormParamsReplicated(ada_layer_norm.a_norm, device_mesh=device_mesh) + self.s_norm = LayerNormParamsReplicated(ada_layer_norm.s_norm, device_mesh=device_mesh) + self.s_scale = LinearParamsReplicated(layer_local=ada_layer_norm.s_scale, device_mesh=device_mesh) + self.s_bias = LinearParamsReplicated(layer_local=ada_layer_norm.s_bias, device_mesh=device_mesh) + + def forward(self, a: DTensor, s: DTensor) -> DTensor: + """Forward pass for the DTensor-distributed adaptive layer normalization. + + All tensors use device mesh (dp, cp_axis_0, cp_axis_1). + Placements: (Shard(0), Shard(1), Replicate()) — batch over dp, sequence/window + index over cp_axis_0, features replicated over cp_axis_1. + + Parameters + ---------- + a : DTensor + The input tensor, shape (B, N, dim) or (B*M, K, W, dim) for window batching. + Placements: (Shard(0), Shard(1), Replicate()). + s : DTensor + The conditioning tensor, shape (B, N, dim_single_cond) or + (B*M, K, W, dim_single_cond) for window batching. + Placements: (Shard(0), Shard(1), Replicate()). + + Returns + ------- + DTensor + The output tensor, same shape and placements as a. + """ + a: DTensor = self.a_norm(a) + s: DTensor = self.s_norm(s) + + gate_input: DTensor = self.s_scale(s) + a: DTensor = sigmoid_gate(x=a, g=gate_input) + b: DTensor = self.s_bias(s) + c: DTensor = elementwise_op(a, b, op=ElementwiseOp.SUM) + + return c + + +class ConditionedTransitionBlock(Module): + """Conditioned Transition Block for DTensor. + + Compatible with both Boltz-1x and Boltz-2 serial ConditionedTransitionBlock modules. + + Both versions have identical child modules: adaln, swish_gate (Sequential of + LinearNoBias + SwiGLU), a_to_b, b_to_a, output_projection (Sequential of + Linear + Sigmoid). The Sigmoid is stripped and replaced by sigmoid_gate + in the forward pass for DTensor compatibility. + """ + + def __init__( + self, + conditioned_trans_block: nn.Module, + device_mesh: DeviceMesh, + ): + """Initialize the DTensor-distributed conditioned transition block. + + Parameters + ---------- + conditioned_trans_block : nn.Module + The serial ConditionedTransitionBlock module to be distributed. + Accepts both Boltz-1x and Boltz-2 versions. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + + Raises + ------ + TypeError + If conditioned_trans_block is not a recognized type. + """ + super().__init__() + + if not isinstance( + conditioned_trans_block, (ConditionedTransitionBlockBoltz1, ConditionedTransitionBlockBoltz2) + ): + raise TypeError( + ", ".join( + [ + f"Instance {conditioned_trans_block} should have type " + f"{ConditionedTransitionBlockBoltz1} or {ConditionedTransitionBlockBoltz2}", + f"but instead has type {type(conditioned_trans_block)}.", + ] + ) + ) + + self.adaln = AdaLN( + ada_layer_norm=conditioned_trans_block.adaln, + device_mesh=device_mesh, + ) + self.swish_gate = nn.Sequential( + LinearParamsReplicated( + layer_local=conditioned_trans_block.swish_gate[0], + device_mesh=device_mesh, + ), + SwiGLUWithDTensor(), + ) + + self.a_to_b = LinearParamsReplicated( + layer_local=conditioned_trans_block.a_to_b, + device_mesh=device_mesh, + ) + self.b_to_a = LinearParamsReplicated( + layer_local=conditioned_trans_block.b_to_a, + device_mesh=device_mesh, + ) + + # Strip the sigmoid from output_projection - sigmoid operation is handled + # via sigmoid_gate in the forward pass for DTensor compatibility. + # Preserves the parameter initialization from the serial module. + self.output_projection = nn.Sequential( + LinearParamsReplicated( + layer_local=conditioned_trans_block.output_projection[0], + device_mesh=device_mesh, + ), + ) + + def forward( + self, + a: DTensor, + s: DTensor, + ) -> DTensor: + """Forward pass for the DTensor-distributed conditioned transition block. + + All tensors use placements (Shard(0), Shard(1), Replicate()) on mesh (dp, cp_axis_0, cp_axis_1). + + Parameters + ---------- + a : DTensor + The input tensor, shape (B, N, dim) or (B*M, K, W, dim). + Placements: (Shard(0), Shard(1), Replicate()). + s : DTensor + The conditioning tensor, shape (B, N, dim_single_cond) or + (B*M, K, W, dim_single_cond). + Placements: (Shard(0), Shard(1), Replicate()). + + Returns + ------- + DTensor + The output tensor, same shape and placements as a. + """ + a: DTensor = self.adaln(a, s) + c: DTensor = self.swish_gate(a) + b: DTensor = self.a_to_b(a) + b: DTensor = elementwise_op(c, b, op=ElementwiseOp.PROD) + a: DTensor = sigmoid_gate(x=self.b_to_a(b), g=self.output_projection[0](s)) + return a + + +class DiffusionTransformerLayer(Module): + """Diffusion Transformer Layer for DTensor. + + Compatible with both Boltz-1x and Boltz-2 serial DiffusionTransformerLayer modules. + + Supports two attention modes, dispatched by ``ring_comm``: + - **Window-batched** (``ring_comm=None``): Uses ``AttentionPairBiasShardwise``. + Input z is 5D ``(B, K, W, H, D)`` with ``(S(0), S(1), R)`` placements. + - **Ring attention** (``ring_comm`` provided): Uses ``AttentionPairBias``. + Input z is 4D ``(B, N, N, D)`` with ``(S(0), S(1), S(2))`` placements. + + Config flags (apply_initial_norm, compute_pair_bias, use_model_cache) are + auto-detected from the serial module's AttentionPairBias attributes for both + attention types. + """ + + def __init__( + self, + diff_transformer_layer: nn.Module, + device_mesh: DeviceMesh, + ring_comm: AttentionPairBiasComm | None = None, + ): + """Initialize the DTensor-distributed diffusion transformer layer. + + Parameters + ---------- + diff_transformer_layer : nn.Module + The serial DiffusionTransformerLayer module to be distributed. + Accepts both Boltz-1x and Boltz-2 versions. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + ring_comm : AttentionPairBiasComm or None, optional + Ring communication object. When provided, uses ring attention + (AttentionPairBias); when None, uses window-batched attention + (AttentionPairBiasShardwise). Default None. + + Raises + ------ + TypeError + If diff_transformer_layer is not a recognized type. + """ + super().__init__() + + if not isinstance(diff_transformer_layer, (DiffusionTransformerLayerBoltz1, DiffusionTransformerLayerBoltz2)): + raise TypeError( + ", ".join( + [ + f"Instance {diff_transformer_layer} should have type " + f"{DiffusionTransformerLayerBoltz1} or {DiffusionTransformerLayerBoltz2}", + f"but instead has type {type(diff_transformer_layer)}.", + ] + ) + ) + + self.adaln = AdaLN( + ada_layer_norm=diff_transformer_layer.adaln, + device_mesh=device_mesh, + ) + + # Auto-detect V1/V2 config flags for the attention module + serial_attn = diff_transformer_layer.pair_bias_attn + is_boltz1 = isinstance(serial_attn, AttentionPairBiasBoltz1) + apply_initial_norm = getattr(serial_attn, "initial_norm", False) + compute_pair_bias = True if is_boltz1 else getattr(serial_attn, "compute_pair_bias", True) + use_model_cache = is_boltz1 + + if ring_comm is not None: + # Ring attention (all-to-all) — used for token-level transformer + self.pair_bias_attn = AttentionPairBiasRing( + attn_pair_bias=serial_attn, + device_mesh=device_mesh, + ring_comm=ring_comm, + apply_initial_norm=apply_initial_norm, + compute_pair_bias=compute_pair_bias, + use_model_cache=use_model_cache, + ) + else: + # Window-batched attention — used for atom-level transformer + self.pair_bias_attn = AttentionPairBiasShardwise( + attn_pair_bias=serial_attn, + device_mesh=device_mesh, + apply_initial_norm=apply_initial_norm, + compute_pair_bias=compute_pair_bias, + use_model_cache=use_model_cache, + ) + + # Track attention mode for forward dispatch + self.use_window_batching = isinstance(self.pair_bias_attn, AttentionPairBiasShardwise) + + # In DiffusionTransformerLayer, output_projection_linear is a class attribute. + # output_projection wraps it with Sigmoid, which is replaced by sigmoid_gate + # in the forward pass for DTensor compatibility. + self.output_projection_linear = LinearParamsReplicated( + layer_local=diff_transformer_layer.output_projection_linear, + device_mesh=device_mesh, + ) + self.output_projection = nn.Sequential(self.output_projection_linear) + + self.transition = ConditionedTransitionBlock( + conditioned_trans_block=diff_transformer_layer.transition, + device_mesh=device_mesh, + ) + + # Handle post_layer_norm (Boltz-2 only) + self.post_lnorm = None + if hasattr(diff_transformer_layer, "post_lnorm") and not isinstance( + diff_transformer_layer.post_lnorm, nn.Identity + ): + self.post_lnorm = LayerNormParamsReplicated( + diff_transformer_layer.post_lnorm, + device_mesh=device_mesh, + ) + + def forward( + self, + a: DTensor, + s: DTensor, + z: DTensor, + mask: Union[DTensor, None] = None, + to_keys: Union[Callable[[DTensor], DTensor], None] = None, + multiplicity: int = 1, + layer_cache: Union[dict[str, dict[str, DTensor]], None] = None, + pair_mask: Union[DTensor, None] = None, + ) -> DTensor: + """Forward pass for the DTensor-distributed diffusion transformer layer. + + Supports two modes: + - Window-batched: a/s are 4D (B*M, K, W, D), z is 5D (B, K, W, H, D), + mask is 3D (B, K, W). Uses to_keys for query→key conversion. + - Ring attention: a/s are 3D (B*M, N, D), z is 4D (B, N, N, D), + mask is 2D (B, N). Uses multiplicity for batch expansion. + + Parameters + ---------- + a : DTensor + The input tensor. + s : DTensor + The conditioning tensor. + z : DTensor + The pair representation / pre-computed bias tensor. + mask : DTensor or None, optional + The mask tensor. + to_keys : Callable or None, optional + Function to transform tensors from query space to key space. + Used by AttentionPairBiasShardwise for window batching. + multiplicity : int, optional + The multiplicity (number of diffusion samples), by default 1. + layer_cache : dict or None, optional + Cache for storing projected z during diffusion rollout. + pair_mask : DTensor or None, optional + The pair mask tensor. + + Returns + ------- + DTensor + The output tensor, same shape and placements as a. + """ + b: DTensor = self.adaln(a, s) + + if self.use_window_batching: + if multiplicity != 1: + raise NotImplementedError( + "DiffusionTransformerLayer: window batching mode does not need multiplicity " + "but use memory-efficient algorithm to avoid having to explicitly apply multiplicity. " + "Multiplicity must be 1 in this mode." + ) + # Window-batched attention: to_keys converts query→key space + b: DTensor = self.pair_bias_attn( + s=b, + z=z, + mask=mask, + to_keys=to_keys, + model_cache=layer_cache, + ) + else: + # Ring attention: uses multiplicity and pair_mask + b: DTensor = self.pair_bias_attn( + s=b, + z=z, + mask=mask, + multiplicity=multiplicity, + model_cache=layer_cache, + pair_mask=pair_mask, + ) + + b: DTensor = sigmoid_gate(g=self.output_projection[0](s), x=b) + + # Residual connections + a: DTensor = elementwise_op(a, b, op=ElementwiseOp.SUM) + c: DTensor = self.transition(a, s) + a: DTensor = elementwise_op(a, c, op=ElementwiseOp.SUM) + + # Optional post layer norm (Boltz-2 only) + if self.post_lnorm is not None: + a = self.post_lnorm(a) + + return a + + +class DiffusionTransformer(Module): + """Multi-layer DiffusionTransformer for DTensor. + + Compatible with both Boltz-1x and Boltz-2 serial DiffusionTransformer modules. + + Key difference: Boltz-2 splits the bias across layers (last dim = num_heads * L), + while Boltz-1 passes the same z to all layers (each layer projects independently). + + Boltz-2's pair_bias_attn=False is not supported (dead code in serial). + """ + + def __init__( + self, + diff_transformer: nn.Module, + device_mesh: DeviceMesh, + ring_comm: AttentionPairBiasComm | None = None, + ): + """Initialize the DTensor-distributed multi-layer diffusion transformer. + + Parameters + ---------- + diff_transformer : nn.Module + The serial DiffusionTransformer module to be distributed. + Accepts both Boltz-1x and Boltz-2 versions. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + ring_comm : AttentionPairBiasComm or None, optional + Ring communication object. When provided, uses ring attention; + when None, uses window-batched attention. Default None. + + Raises + ------ + TypeError + If diff_transformer is not a recognized type. + NotImplementedError + If Boltz-2 serial module has pair_bias_attn=False. + """ + super().__init__() + + if not isinstance(diff_transformer, (DiffusionTransformerBoltz1, DiffusionTransformerBoltz2)): + raise TypeError( + ", ".join( + [ + f"Instance {diff_transformer} should have type " + f"{DiffusionTransformerBoltz1} or {DiffusionTransformerBoltz2}", + f"but instead has type {type(diff_transformer)}.", + ] + ) + ) + + # Boltz-2: raise if pair_bias_attn=False (dead code in serial, not supported) + if isinstance(diff_transformer, DiffusionTransformerBoltz2): + if not getattr(diff_transformer, "pair_bias_attn", True): + raise NotImplementedError( + "DTensor DiffusionTransformer does not support pair_bias_attn=False. " + "This is dead code in the serial Boltz-2 implementation." + ) + + # Detect Boltz-2 bias-splitting mode: + # Boltz-2 DiffusionTransformer receives bias with last dim = num_heads * L + # and splits it across layers. Boltz-1 passes the same z to all layers. + self.split_bias_across_layers = isinstance(diff_transformer, DiffusionTransformerBoltz2) + + # Track attention mode for forward dispatch + self.use_window_batching = ring_comm is None + + # Detect activation checkpointing and CPU offloading. + # + # Boltz-1x: fairscale checkpoint_wrapper replaces each layer's forward + # method, so we inspect per-layer. + # + # Boltz-2: the parent DiffusionTransformer stores + # ``activation_checkpointing`` as a module-level attribute and handles + # it in its own ``forward()``. The individual layers are plain modules. + # We check the parent attribute as a fallback when per-layer detection + # yields False. + activation_checkpointing = set() + cpu_offloading = set() + for serial_layer in diff_transformer.layers: + has_ckpt, has_offload = extract_checkpointing_config(serial_layer) + activation_checkpointing.add(has_ckpt) + cpu_offloading.add(has_offload) + + if len(activation_checkpointing) > 1: + raise ValueError( + "All layers must have the same activation checkpointing configuration but got different values: ", + activation_checkpointing, + ) + if len(cpu_offloading) > 1: + raise ValueError( + "All layers must have the same CPU offloading configuration but got different values: ", + cpu_offloading, + ) + + layer_level_ckpt = activation_checkpointing.pop() if activation_checkpointing else False + parent_level_ckpt = getattr(diff_transformer, "activation_checkpointing", False) + self.activation_checkpointing = layer_level_ckpt or parent_level_ckpt + self.cpu_offloading = cpu_offloading.pop() if cpu_offloading else False + + self.layers = ModuleList( + [ + DiffusionTransformerLayer( + diff_transformer_layer=layer, + device_mesh=device_mesh, + ring_comm=ring_comm, + ) + for layer in diff_transformer.layers + ] + ) + + def forward( + self, + a: DTensor, + s: DTensor, + z: DTensor, + mask: Union[DTensor, None] = None, + to_keys: Union[Callable[[DTensor], DTensor], None] = None, + multiplicity: int = 1, + model_cache: Union[dict[str, dict[str, DTensor]], None] = None, + pair_mask: Union[DTensor, None] = None, + ) -> DTensor: + """Forward pass for the DTensor-distributed multi-layer diffusion transformer. + + Supports two modes: + - Window-batched: a/s 4D, z 5D, mask 3D, uses to_keys. + - Ring attention: a/s 3D, z 4D, mask 2D, uses multiplicity. + + Parameters + ---------- + a : DTensor + The input tensor. + s : DTensor + The conditioning tensor. + z : DTensor + The pair representation / pre-computed bias tensor. + mask : DTensor or None, optional + The mask tensor. + to_keys : Callable or None, optional + Function to transform tensors from query space to key space (window-batched). + multiplicity : int, optional + The multiplicity (number of diffusion samples), by default 1. + model_cache : dict or None, optional + Cache for storing projected z during diffusion rollout. + pair_mask : DTensor or None, optional + The pair mask tensor (ring attention only). + + Returns + ------- + DTensor + The output tensor, same shape and placements as a. + """ + if self.split_bias_across_layers and len(self.layers) > 1: + # Boltz-2: split z last dim across layers + L = len(self.layers) + if z.shape[-1] % L != 0: + raise ValueError( + f"Boltz-2 bias last dimension ({z.shape[-1]}) must be evenly divisible by " + f"the number of layers ({L}). The Boltz-2 architecture guarantees this because " + f"the upstream bias construction (DiffusionConditioning and InputEmbedder) " + f"produces z.shape[-1] = num_heads * depth by design." + ) + # Window-batched: z is 5D (B, K, W, H, heads*L) → L chunks of (B, K, W, H, heads) + # Ring attention: z is 4D (B, N, N, heads*L) → L chunks of (B, N, N, heads) + # Both modes use the same shardwise_chunk operation. + z_chunks = shardwise_chunk(z, chunks=L, dim=-1) + else: + z_chunks = None # Boltz-1: same z for all layers, or single layer + + for i, layer in enumerate(self.layers): + layer_cache = None + if model_cache is not None: + prefix_cache = "layer_" + str(i) + if prefix_cache not in model_cache: + model_cache[prefix_cache] = {} + layer_cache = model_cache[prefix_cache] + + z_i = z_chunks[i] if z_chunks is not None else z + + if self.activation_checkpointing and self.training: + if self.cpu_offloading: + with get_cpu_offload_context(optimized=True): + a = checkpoint( + layer, + a, + s, + z_i, + mask, + to_keys, + multiplicity, + layer_cache, + pair_mask, + use_reentrant=False, + ) + else: + a = checkpoint( + layer, + a, + s, + z_i, + mask, + to_keys, + multiplicity, + layer_cache, + pair_mask, + use_reentrant=False, + ) + else: + a = layer( + a, + s, + z_i, + mask=mask, + to_keys=to_keys, + multiplicity=multiplicity, + layer_cache=layer_cache, + pair_mask=pair_mask, + ) + return a + + +class AtomTransformer(Module): + """AtomTransformer for DTensor (window batching). + + Compatible with both Boltz-1x and Boltz-2 serial AtomTransformer modules. + + Reshapes single repr (B, N, D) -> window-batched (B, K, W, D) using + shardwise_unflatten_sharded, delegates to DiffusionTransformer, then + flattens back using shardwise_flatten_sharded. + + Unlike the serial version which flattens (B, K) -> (B*K), the DTensor version + keeps B and K as separate axes since both are sharded on the device mesh. + """ + + def __init__( + self, + layer: nn.Module, + device_mesh: DeviceMesh, + ): + """Initialize the DTensor-distributed atom transformer. + + Parameters + ---------- + layer : nn.Module + The serial AtomTransformer module to be distributed. + Accepts both Boltz-1x and Boltz-2 versions. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + + Raises + ------ + TypeError + If layer is not a recognized type. + """ + super().__init__() + + if not isinstance(layer, (AtomTransformerBoltz1, AtomTransformerBoltz2)): + raise TypeError( + ", ".join( + [ + f"Instance {layer} should have type {AtomTransformerBoltz1} or {AtomTransformerBoltz2}", + f"but instead has type {type(layer)}.", + ] + ) + ) + + validate_window_batching_parameters(layer.attn_window_queries, layer.attn_window_keys, use_window_batching=True) + + self.attn_window_queries = layer.attn_window_queries + self.attn_window_keys = layer.attn_window_keys + self.diffusion_transformer = DiffusionTransformer( + diff_transformer=layer.diffusion_transformer, + device_mesh=device_mesh, + ) + + def forward( + self, + q: DTensor, + c: DTensor, + p: DTensor, + mask: Union[DTensor, None] = None, + multiplicity: int = 1, + model_cache: Union[dict[str, dict[str, DTensor]], None] = None, + pair_mask: Union[DTensor, None] = None, + to_keys: None = None, + ) -> DTensor: + """Forward pass for the DTensor-distributed atom transformer (window batching). + + All tensors use device mesh (dp, cp_axis_0, cp_axis_1). + Placements: Shard(0)=dp batch, Shard(1)=cp atom/window axis, Replicate()=cp_axis_1. + Internally reshapes q/c from (B*M, N, D) to (B*M, K, W, D) via shardwise_unflatten_sharded, + delegates to DiffusionTransformer, then flattens back. + + Parameters + ---------- + q : DTensor + Query single representation, shape (B*M, N, dim) where N = K * W. + Placements: (Shard(0), Shard(1), Replicate()). + c : DTensor + Conditioning single representation, shape (B*M, N, dim_single_cond). + Placements: (Shard(0), Shard(1), Replicate()). + p : DTensor + Pair representation in window-batched format. + - Boltz-1: shape (B, K, W, H, c_z) + - Boltz-2: shape (B, K, W, H, num_heads * depth) + Placements: (Shard(0), Shard(1), Replicate()). + mask : DTensor or None, optional + The mask tensor, shape (B, N) or (B*M, N). + Placements: (Shard(0), Shard(1), Replicate()). + multiplicity : int, optional + The multiplicity (number of diffusion samples), by default 1. + Must be 1 for window batching mode. + model_cache : dict or None, optional + Cache for storing projected z during diffusion rollout. + pair_mask : DTensor or None, optional + The pair mask tensor. Not supported in window batching mode. + to_keys : None, optional + Not used -- to_keys is constructed internally for window batching. + + Returns + ------- + DTensor + The output tensor, shape (B*M, N, dim). + Placements: (Shard(0), Shard(1), Replicate()). + """ + W = self.attn_window_queries + H = self.attn_window_keys + + if pair_mask is not None: + raise NotImplementedError("pair_mask is not supported in AtomTransformer window batching mode") + + if multiplicity != 1: + raise NotImplementedError( + "AtomTransformer window batching mode uses memory-efficient algorithm " + "to avoid having to explicitly apply multiplicity. Multiplicity must be 1 in this mode." + ) + + if q.shape[1] % W != 0: + raise ValueError(f"q.shape[1] must be divisible by W, but got q.shape[1]={q.shape[1]} and W={W}") + + if c.shape[1] != q.shape[1]: + raise ValueError( + f"c.shape[1] must be equal to q.shape[1], but got c.shape[1]={c.shape[1]} and q.shape[1]={q.shape[1]}" + ) + + if mask is not None and mask.shape[1] != q.shape[1]: + raise ValueError( + f"mask.shape[1] must be equal to q.shape[1], " + f"but got mask.shape[1]={mask.shape[1]} and q.shape[1]={q.shape[1]}" + ) + + B, N, D = q.shape + K = N // W + + # NOTE: p is already in shape (B, K, W, H, D_z) + if p.ndim != 5: + raise ValueError(f"p must have 5 dimensions, but got p.ndim={p.ndim}") + + if p.shape[1:-1] != (K, W, H): + raise ValueError(f"p.shape[1:-1] must be (K, W, H) = {(K, W, H)}, but got p.shape[1:-1]={p.shape[1:-1]}") + + if B % p.shape[0] != 0: + raise ValueError(f"B must be divisible by p.shape[0], but got B={B} and p.shape[0]={p.shape[0]}") + + # Reshape the single repr into window-batched query view: + # (B, N, D) -> (B, K, W, D) + # Unlike the serial version, we don't flatten the resulting (B, K) axes + # since both of them are sharded on the device mesh. + q = shardwise_unflatten_sharded(q, axis=1, sizes=(K, W)) + c = shardwise_unflatten_sharded(c, axis=1, sizes=(K, W)) + if mask is not None: + mask = shardwise_unflatten_sharded(mask, axis=1, sizes=(K, W)) + + to_keys_new = partial( + convert_single_repr_window_batched_query_to_key, W=self.attn_window_queries, H=self.attn_window_keys + ) + + # Main transformer + q = self.diffusion_transformer( + q, + c, + p, + to_keys=to_keys_new, + mask=mask, + multiplicity=multiplicity, + model_cache=model_cache, + pair_mask=pair_mask, + ) + + # Flatten the window-batched query view back to the original single repr view: + # (B, K, W, D) -> (B, N, D) + q = shardwise_flatten_sharded(q, start_dim=1, end_dim=2) + + return q diff --git a/src/boltz/distributed/model/modules/trunkv2.py b/src/boltz/distributed/model/modules/trunkv2.py new file mode 100644 index 000000000..d7d28a744 --- /dev/null +++ b/src/boltz/distributed/model/modules/trunkv2.py @@ -0,0 +1,773 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-based implementation of Boltz-2 trunk modules. + +This module provides distributed implementations of the MSAModule, MSALayer, +InputEmbedder, DistogramModule, BFactorModule, and ContactConditioning for Boltz-2. + +MSAModule and MSALayer use PairformerNoSeq layers for triangle attention +and multiplication. DistogramModule handles 5D output with num_distograms. +BFactorModule and ContactConditioning are Boltz-2–specific head modules. +InputEmbedder wraps the serial InputEmbedder for context parallelism using +distributed AtomEncoder and AtomAttentionEncoder. +""" + +from math import pi +from typing import Dict, Optional, Tuple + +import torch +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.utils.checkpoint import checkpoint + +from boltz.data import const +from boltz.distributed.comm import Ring2DComm, TransposeComm +from boltz.distributed.data.feature.featurizer import pack_atom_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.cat_and_chunk import shardwise_cat +from boltz.distributed.model.layers.clip import clip +from boltz.distributed.model.layers.dropout import apply_dropout_mask_msa_or_pair +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op +from boltz.distributed.model.layers.embedding import EmbeddingParamsReplicated +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_unflatten +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.distributed.model.layers.outer_product_mean import OuterProductMean +from boltz.distributed.model.layers.pair_averaging import PairWeightedAveraging, Ring2DCommPairAveraging +from boltz.distributed.model.layers.pairformer import PairformerNoSeqLayer +from boltz.distributed.model.layers.redistribute_transpose import redistribute_transpose +from boltz.distributed.model.layers.replicate_op import ReplicateOp, replicate_op +from boltz.distributed.model.layers.shardwise_op import shardwise_one_hot +from boltz.distributed.model.layers.squeeze import shardwise_unsqueeze +from boltz.distributed.model.layers.transition import Transition +from boltz.distributed.model.modules.encoders import AtomAttentionEncoder as DistAtomAttentionEncoder +from boltz.distributed.model.modules.encoders import AtomEncoder as DistAtomEncoder +from boltz.distributed.model.modules.utils import get_cpu_offload_context +from boltz.distributed.utils import update_exhaustive_strides +from boltz.model.modules.encodersv2 import FourierEmbedding as SerialFourierEmbedding +from boltz.model.modules.trunkv2 import BFactorModule as SerialBFactorModule +from boltz.model.modules.trunkv2 import ContactConditioning as SerialContactConditioning +from boltz.model.modules.trunkv2 import DistogramModule as SerialDistogramModule +from boltz.model.modules.trunkv2 import InputEmbedder as SerialInputEmbedder +from boltz.model.modules.trunkv2 import MSALayer as SerialMSALayer +from boltz.model.modules.trunkv2 import MSAModule as SerialMSAModule + + +class InputEmbedder(nn.Module): + """DTensor InputEmbedder for Boltz-2 using window-batching atom attention. + + Wraps the serial Boltz-2 ``InputEmbedder`` and distributes its sub-modules + for context parallelism. Atom-level operations delegate to the existing + distributed :class:`AtomEncoder` and :class:`AtomAttentionEncoder` (window- + batching variants) while token-level projections use + :class:`LinearParamsReplicated` / :class:`EmbeddingParamsReplicated`. + + Atom feature packing (via :func:`pack_atom_features`) is internalized in + :meth:`forward` so that callers only need to pass the raw distributed atom + features. The packed features are discarded after the forward pass. + """ + + _KEYS_ATOM_FEATURES_PACKED = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + } + + def __init__( + self, + module: SerialInputEmbedder, + device_mesh: DeviceMesh, + ) -> None: + """Initialize the distributed InputEmbedder for Boltz-2. + + Parameters + ---------- + module : SerialInputEmbedder + The serial InputEmbedder containing weights and configuration. + device_mesh : DeviceMesh + Device mesh defining the distributed computation topology. + """ + super().__init__() + + if not isinstance(module, SerialInputEmbedder): + raise TypeError(f"Expected SerialInputEmbedder, got {type(module)}") + + self.add_method_conditioning = module.add_method_conditioning + self.add_modified_flag = module.add_modified_flag + self.add_cyclic_flag = module.add_cyclic_flag + self.add_mol_type_feat = module.add_mol_type_feat + + # Atom-level modules -- delegate to existing distributed implementations + self.atom_encoder = DistAtomEncoder(layer=module.atom_encoder, device_mesh=device_mesh) + self.atoms_per_window_queries = self.atom_encoder.atoms_per_window_queries + + # atom_enc_proj_z: Sequential(LayerNorm, Linear) projects pair repr to + # attention bias. Operates on the last (replicated) dim of p so + # LayerNormParamsReplicated + LinearParamsReplicated work directly. + self.atom_enc_proj_z = nn.Sequential( + LayerNormParamsReplicated(module.atom_enc_proj_z[0], device_mesh=device_mesh), + LinearParamsReplicated(layer_local=module.atom_enc_proj_z[1], device_mesh=device_mesh), + ) + + self.atom_attention_encoder = DistAtomAttentionEncoder( + layer=module.atom_attention_encoder, device_mesh=device_mesh + ) + + # Token-level projections (replicated parameters, local ops) + self.res_type_encoding = LinearParamsReplicated(layer_local=module.res_type_encoding, device_mesh=device_mesh) + self.msa_profile_encoding = LinearParamsReplicated( + layer_local=module.msa_profile_encoding, device_mesh=device_mesh + ) + + # Optional conditioning modules + if self.add_method_conditioning: + self.method_conditioning_init = EmbeddingParamsReplicated( + module.method_conditioning_init, device_mesh=device_mesh + ) + if self.add_modified_flag: + self.modified_conditioning_init = EmbeddingParamsReplicated( + module.modified_conditioning_init, device_mesh=device_mesh + ) + if self.add_cyclic_flag: + self.cyclic_conditioning_init = LinearParamsReplicated( + layer_local=module.cyclic_conditioning_init, device_mesh=device_mesh + ) + if self.add_mol_type_feat: + self.mol_type_conditioning_init = EmbeddingParamsReplicated( + module.mol_type_conditioning_init, device_mesh=device_mesh + ) + + def forward(self, feats: dict[str, DTensor], affinity: bool = False) -> DTensor: + """Forward pass for the distributed InputEmbedder. + + Atom feature packing/unpacking is internalized: ``pack_atom_features`` + is called here to convert the raw distributed atom features into packed + format for window-batching. The packed features are discarded after + the atom encoder and atom attention encoder calls. + + Parameters + ---------- + feats : dict[str, DTensor] + Input features (token-level and atom-level). + affinity : bool, optional + When True, use ``profile_affinity`` / ``deletion_mean_affinity`` + instead of ``profile`` / ``deletion_mean``. Defaults to False. + + Returns + ------- + DTensor + Token-level single representation ``s`` with shape + ``(B, N_tokens, token_s)`` and placement + ``(Shard(0), Shard(1), Replicate())``. + """ + # Token-level features + # res_type is integer one-hot; cast to the layer's weight dtype so it + # works with any precision (float32, float64, bfloat16, etc.). + res_type = feats["res_type"].to(self.res_type_encoding.weight.dtype) + if affinity: + profile = feats["profile_affinity"] + deletion_mean = shardwise_unsqueeze(feats["deletion_mean_affinity"], -1) + else: + profile = feats["profile"] + deletion_mean = shardwise_unsqueeze(feats["deletion_mean"], -1) + + # Pack atom features for window batching. + # pack_atom_features converts pad_and_scatter output (with per-shard + # trailing padding) into packed format with global atom-to-token indices. + # The packed features are discarded after the forward pass. + feats_packed = pack_atom_features(feats, self._KEYS_ATOM_FEATURES_PACKED, self.atoms_per_window_queries) + + # Atom encoding: produces (q, c, p). + # AtomEncoder.forward() internally wraps its body in + # torch.autocast("cuda", enabled=False), matching the serial code. + q, c, p = self.atom_encoder(feats_packed) + + # Project pair representation to attention bias + atom_enc_bias = self.atom_enc_proj_z(p) + + # Atom attention encoding: produces (a, q, c, p) + a, _, _, _ = self.atom_attention_encoder( + feats=feats_packed, + q=q, + c=c, + atom_enc_bias=atom_enc_bias, + ) + + # Token-level embedding: sum of atom attention output + learned projections + profile_cat = shardwise_cat([profile, deletion_mean], dim=-1) + + s = elementwise_op(a, self.res_type_encoding(res_type), ElementwiseOp.SUM) + s = elementwise_op(s, self.msa_profile_encoding(profile_cat), ElementwiseOp.SUM) + + # Optional conditioning + if self.add_method_conditioning: + s = elementwise_op(s, self.method_conditioning_init(feats["method_feature"]), ElementwiseOp.SUM) + if self.add_modified_flag: + s = elementwise_op(s, self.modified_conditioning_init(feats["modified"]), ElementwiseOp.SUM) + if self.add_cyclic_flag: + cyclic = feats["cyclic_period"].to(self.cyclic_conditioning_init.weight.dtype) + cyclic = clip(cyclic, max_val=1.0) + cyclic = shardwise_unsqueeze(cyclic, -1) + s = elementwise_op(s, self.cyclic_conditioning_init(cyclic), ElementwiseOp.SUM) + if self.add_mol_type_feat: + s = elementwise_op(s, self.mol_type_conditioning_init(feats["mol_type"]), ElementwiseOp.SUM) + + return s + + +class MSALayer(nn.Module): + """Distributed MSA layer for Boltz-2 using DTensor. + + This is the Boltz-2 version of MSALayer which uses PairformerNoSeqLayer + for triangle operations instead of individual triangle multiplication and + attention layers. + + Input/Output Placements: + - z: (Shard(0), Shard(1), Shard(2)) - Pair representation + - m: (Shard(0), Shard(1), Shard(2)) - MSA representation + - token_mask: (Shard(0), Shard(1), Shard(2)) - Token pair mask + - msa_mask: (Shard(0), Shard(1), Shard(2)) - MSA mask + + Communication: + - PairWeightedAveraging: Ring communication for MSA-to-pair + - OuterProductMean: Ring communication for outer product + - PairformerNoSeqLayer: Triangle operations with ring communication + """ + + def __init__( + self, + layer: SerialMSALayer, + dist_manager: DistributedManager, + ) -> None: + """Initialize the distributed MSALayer for Boltz-2. + + Parameters + ---------- + layer : SerialMSALayer + The serial MSA layer containing weights and configuration to be distributed. + dist_manager : DistributedManager + Distributed manager defining the distributed computation topology and groups. + """ + super().__init__() + self.dist_manager = dist_manager + self.device_mesh = dist_manager.device_mesh_subgroups + + # Store dropout rates + self.msa_dropout = layer.msa_dropout + + # Create communication objects for distributed computation + ring_comm_2d_outer_product = Ring2DComm( + self.dist_manager.group["cp"], + self.dist_manager.subgroups["cp"][0], + self.dist_manager.layout_subgroups["cp"], + ) + + ## PWA Implementation with Ring2DCommPairAveraging + ring_comm_2d_pair_avg = Ring2DCommPairAveraging( + self.dist_manager.group["cp"], + self.dist_manager.subgroups["cp"][0], + self.dist_manager.layout_subgroups["cp"], + ) + self.pair_weighted_averaging = PairWeightedAveraging( + layer.pair_weighted_averaging, self.device_mesh, ring_comm_2d_pair_avg + ) + + # Map serial layers to distributed versions + self.msa_transition = Transition(layer.msa_transition, self.device_mesh) + + # Map PairformerNoSeqLayer to distributed version + self.pairformer_layer = PairformerNoSeqLayer(layer.pairformer_layer, dist_manager) + assert self.pairformer_layer.no_seq, ( + f"Expected no_seq=True for PairformerNoSeqLayer, " f"got no_seq={self.pairformer_layer.no_seq}" + ) + + self.outer_product_mean = OuterProductMean( + layer.outer_product_mean, self.device_mesh, ring_comm_2d_outer_product + ) + + def forward( + self, + z: DTensor, + m: DTensor, + token_mask: DTensor, + msa_mask: DTensor, + ) -> Tuple[DTensor, DTensor]: + """Perform the forward pass. + + Parameters + ---------- + z : DTensor + The pair representation with placement (Shard(0), Shard(1), Shard(2)) + m : DTensor + The MSA representation with placement (Shard(0), Shard(1), Shard(2)) + token_mask : DTensor + The token pair mask with placement (Shard(0), Shard(1), Shard(2)) + msa_mask : DTensor + The MSA mask with placement (Shard(0), Shard(1), Shard(2)) + + Returns + ------- + Tuple[DTensor, DTensor] + The updated pair representation and MSA representation. + """ + # Communication to MSA stack + m = elementwise_op( + m, + apply_dropout_mask_msa_or_pair( + self.pair_weighted_averaging(m, z, token_mask), self.msa_dropout, self.training + ), + ElementwiseOp.SUM, + ) + m = elementwise_op(m, self.msa_transition(m), ElementwiseOp.SUM) + + # Communication to pairwise stack via outer product + z = elementwise_op(z, self.outer_product_mean(m, msa_mask), ElementwiseOp.SUM) + + # Compute pairwise stack using PairformerNoSeqLayer + # Note: PairformerNoSeqLayer returns updated z directly (no residual connection needed) + z = self.pairformer_layer(z=z, pair_mask=token_mask) + + return z, m + + +class MSAModule(nn.Module): + """Distributed MSA module for Boltz-2 using DTensor. + + This is the Boltz-2 version of MSAModule which uses PairformerNoSeqLayer + for triangle operations within each MSA layer. + + Input/Output Placements: + - z: (Shard(0), Shard(1), Shard(2)) - Pair representation + - emb: (Shard(0), Replicate(), Shard(1)) - Single representation + - msa: (Shard(0), Shard(1), Shard(2)) - MSA sequences + - msa_mask: (Shard(0), Shard(1), Shard(2)) - MSA mask + - token_pair_pad_mask: (Shard(0), Shard(1), Shard(2)) - Pair mask + + Output: + - z: (Shard(0), Shard(1), Shard(2)) - Updated pair representation + + Communication: + - replicate_op: Broadcast emb to MSA dimension + - PairWeightedAveraging: Ring communication + - OuterProductMean: Ring communication + - PairformerNoSeqLayer: Triangle operations with ring communication + """ + + def __init__( + self, + module: SerialMSAModule, + dist_manager: DistributedManager, + cpu_offloading: bool = False, + ) -> None: + """Initialize the distributed MSAModule for Boltz-2. + + Parameters + ---------- + module : SerialMSAModule + The serial MSA module containing weights and configuration to be distributed. + dist_manager : DistributedManager + Distributed manager defining the distributed computation topology and groups. + cpu_offloading : bool, optional + Whether to offload checkpoint-boundary activations to CPU when + activation checkpointing is enabled. This is a distributed-only + option (the serial Boltz-2 MSAModule does not support it). + Defaults to False. + """ + super().__init__() + self.dist_manager = dist_manager + self.device_mesh = dist_manager.device_mesh_subgroups + + # Store attributes from the serial module + self.msa_blocks = module.msa_blocks + self.msa_dropout = module.msa_dropout + self.z_dropout = module.z_dropout + self.use_paired_feature = module.use_paired_feature + self.subsample_msa = module.subsample_msa + self.num_subsampled_msa = module.num_subsampled_msa + + # CP does not support MSA subsampling at module/layer level; require serial config to have it disabled. + if self.subsample_msa: + raise NotImplementedError( + "Subsampling MSA at module level is not supported with context parallelism. " + "The serial MSAModule must be built with subsample_msa=False." + ) + + # Activation checkpointing is read from the serial module. + self.activation_checkpointing = getattr(module, "activation_checkpointing", False) + # CPU offloading is a distributed-only option (the serial Boltz-2 MSAModule + # does not have this flag). When enabled together with activation + # checkpointing, checkpoint-boundary activations are moved to CPU during + # the forward pass and restored on the backward pass, trading extra + # CPU<->GPU transfers for reduced GPU memory. + self.cpu_offloading = cpu_offloading + + # Map serial projections to distributed versions + # Note: s_proj and msa_proj are LinearParamsReplicated since they operate on + # features that will be broadcast/replicated + self.s_proj = LinearParamsReplicated(module.s_proj, self.device_mesh) + self.msa_proj = LinearParamsReplicated(module.msa_proj, self.device_mesh) + + # Map MSA layers to distributed versions + self.layers = nn.ModuleList() + for serial_layer in module.layers: + self.layers.append(MSALayer(serial_layer, dist_manager)) + + def forward( + self, + z: DTensor, + emb: DTensor, + feats: Dict[str, DTensor], + ) -> DTensor: + """Perform the forward pass. + + Parameters + ---------- + z : DTensor + The pairwise embeddings with placement (Shard(0), Shard(1), Shard(2)) + emb : DTensor + The input embeddings with placement (Shard(0), Replicate(), Shard(1)) + feats : Dict[str, DTensor] + Input features as DTensors + + Returns + ------- + DTensor + The output pairwise embeddings. + """ + # Expected placements + expected_msa_placement = (Shard(0), Shard(1), Shard(2)) + expected_emb_placement = (Shard(0), Replicate(), Shard(1)) + + # Sanity check for emb placement + if emb.placements != expected_emb_placement: + raise ValueError(f"Expected emb placement {expected_emb_placement}, but got {emb.placements}") + + # Sanity check for z placement + if z.placements != expected_msa_placement: + raise ValueError(f"Expected z placement {expected_msa_placement}, but got {z.placements}") + + # Load relevant features – apply one-hot encoding to match the serial + # MSAModule (see src/boltz/model/modules/trunkv2.py), then cast from + # integer to the working dtype so it can be concatenated with the other + # floating-point features. + msa = feats["msa"] + msa = shardwise_one_hot(msa, num_classes=const.num_tokens).to(dtype=z.dtype) + has_deletion = shardwise_unsqueeze(feats["has_deletion"], -1) + deletion_value = shardwise_unsqueeze(feats["deletion_value"], -1) + msa_mask = feats["msa_mask"] + token_mask = feats["token_pair_pad_mask"] + + # Compute MSA embeddings + feats_to_cat = [msa, has_deletion, deletion_value] + if self.use_paired_feature: + is_paired = shardwise_unsqueeze(feats["msa_paired"], -1) + feats_to_cat.append(is_paired) + + # Sanity check for feature DTensor placements + for feat in feats_to_cat: + if feat.placements != expected_msa_placement: + raise ValueError(f"Expected MSA feature placement {expected_msa_placement}, but got {feat.placements}") + + # Concatenate MSA features + m = shardwise_cat(feats_to_cat, dim=-1) + + # Compute input projections + m = self.msa_proj(m) + emb_proj = self.s_proj(emb) + + # Use DTensor replicate_op to add emb to MSA + # emb_proj has placement (Shard(0), Replicate(), Shard(1)) + # m has placement (Shard(0), Shard(1), Shard(2)) + # We need to broadcast emb_proj along the MSA dimension (dim=1) + m = replicate_op(m, emb_proj, 1, op=ReplicateOp.ADD) + + # Perform MSA blocks. + # When activation_checkpointing is enabled, saved activations are recomputed + # during the backward pass. When cpu_offloading is additionally enabled, + # the checkpoint-boundary tensors are moved to CPU (module-level offloading) + # via get_cpu_offload_context, reducing GPU memory at the cost of extra + # CPU<->GPU transfers. + if self.activation_checkpointing and self.training: + if self.cpu_offloading: + with get_cpu_offload_context(optimized=True): + for i in range(self.msa_blocks): + z, m = checkpoint(self.layers[i], z, m, token_mask, msa_mask, use_reentrant=False) + else: + for i in range(self.msa_blocks): + z, m = checkpoint(self.layers[i], z, m, token_mask, msa_mask, use_reentrant=False) + else: + for i in range(self.msa_blocks): + z, m = self.layers[i](z, m, token_mask, msa_mask) + + return z + + +class DistogramModule(nn.Module): + """Distogram Module using DTensor for Boltz-2. + + This module wraps a serial DistogramModule and adds DTensor-based + context parallelism support with num_distograms for 5D output. + + The 4D->5D reshape uses shardwise_unflatten to avoid exposing DTensor + native operations to the autograd graph (which would cause problematic + all-gather operations during backward). + """ + + def __init__( + self, + module: SerialDistogramModule, + dist_manager: DistributedManager, + distogram_comm: Optional[TransposeComm] = None, + ) -> None: + """Initialize the DTensor-based distogram module. + + Parameters + ---------- + module : SerialDistogramModule + Serial DistogramModule from trunkv2 to be distributed. + Must have num_distograms and num_bins attributes. + dist_manager : DistributedManager + Distributed manager for device mesh and process groups. + distogram_comm : TransposeComm, optional + Communication object for CP transpose operations. + Default is None for serial mode. + """ + super().__init__() + self.dist_manager = dist_manager + self.device_mesh = dist_manager.device_mesh_subgroups + + self.distogram = LinearParamsReplicated(module.distogram, device_mesh=self.device_mesh) + self.distogram_comm = distogram_comm + + self.num_distograms = module.num_distograms + self.num_bins = module.num_bins + + def forward(self, z: DTensor) -> DTensor: + """Perform the forward pass. + + Parameters + ---------- + z : DTensor + The pairwise embeddings as DTensor with shape [B, N, N, token_z]. + Sharded along dimensions 1 and 2 for CP. + + Returns + ------- + DTensor + The predicted distogram with shape [B, N, N, num_distograms, num_bins]. + Maintains the same sharding as input along dimensions 1 and 2. + """ + x: DTensor = redistribute_transpose( + z, + transpose_comm=self.distogram_comm, + output_placements=(Shard(0), Shard(1), Shard(2)), + dim0=1, + dim1=2, + ) + y: DTensor = elementwise_op(z, x, ElementwiseOp.SUM) + + output_4d: DTensor = self.distogram(y) + + output_5d: DTensor = shardwise_unflatten( + output_4d, + dim=3, + sizes=(self.num_distograms, self.num_bins), + ) + + return output_5d + + +class BFactorModule(nn.Module): + """DTensor BFactorModule for Boltz-2. + + Wraps a serial BFactorModule's ``nn.Linear`` with ``LinearParamsReplicated`` + so that parameter gradients are correctly all-reduced across the CP mesh. + + The forward pass is elementwise on the single representation ``s`` + (shape ``[B, N, token_s]``, placements ``(Shard(0), Shard(1), Replicate())``). + No cross-shard communication is required. + """ + + def __init__(self, module: SerialBFactorModule, device_mesh: DeviceMesh) -> None: + """Initialize the distributed BFactorModule. + + Parameters + ---------- + module : SerialBFactorModule + Serial BFactorModule from trunkv2 to be distributed. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + """ + if not isinstance(module, SerialBFactorModule): + raise TypeError(f"Expected SerialBFactorModule, got {type(module)}") + super().__init__() + self.bfactor = LinearParamsReplicated(module.bfactor, device_mesh=device_mesh) + self.num_bins = module.num_bins + + def forward(self, s: DTensor) -> DTensor: + """Predict per-token B-factor histogram. + + Parameters + ---------- + s : DTensor + Single representation, shape ``[B, N, token_s]``, + placements ``(Shard(0), Shard(1), Replicate())``. + + Returns + ------- + DTensor + Predicted B-factor logits, shape ``[B, N, num_bins]``, + same placements as input. + """ + return self.bfactor(s) + + +class ContactConditioning(nn.Module): + """DTensor ContactConditioning for Boltz-2. + + Wraps the serial ContactConditioning module for context parallelism. + All operations are elementwise on the last dimension of pair features + ``(B, N, N, *)``, so no cross-shard communication is needed. + + The serial FourierEmbedding's ``proj`` is frozen (no grad), so we keep + it as a plain ``nn.Linear`` and operate on local tensor shards directly + via ``to_local()`` / ``DTensor.from_local()``. The trainable ``encoder`` + linear is wrapped with ``LinearParamsReplicated`` for correct gradient + all-reduce. + """ + + def __init__(self, module: SerialContactConditioning, device_mesh: DeviceMesh) -> None: + """Initialize the distributed ContactConditioning. + + Parameters + ---------- + module : SerialContactConditioning + Serial ContactConditioning from trunkv2. + device_mesh : DeviceMesh + The device mesh for distributed tensor operations. + """ + if not isinstance(module, SerialContactConditioning): + raise TypeError(f"Expected SerialContactConditioning, got {type(module)}") + if const.contact_conditioning_info["UNSPECIFIED"] != 0: + raise ValueError( + f"Expected UNSPECIFIED index 0, got {const.contact_conditioning_info['UNSPECIFIED']}. " + "ContactConditioning forward slices cc[:,:,:,0:1] for UNSPECIFIED." + ) + if const.contact_conditioning_info["UNSELECTED"] != 1: + raise ValueError( + f"Expected UNSELECTED index 1, got {const.contact_conditioning_info['UNSELECTED']}. " + "ContactConditioning forward slices cc[:,:,:,1:2] for UNSELECTED." + ) + super().__init__() + self.device_mesh = device_mesh + + if not isinstance(module.fourier_embedding, SerialFourierEmbedding): + raise TypeError(f"Expected SerialFourierEmbedding, got {type(module.fourier_embedding)}") + self.fourier_embedding = module.fourier_embedding + if self.fourier_embedding.proj.weight.requires_grad or ( + self.fourier_embedding.proj.bias is not None and self.fourier_embedding.proj.bias.requires_grad + ): + raise ValueError("FourierEmbedding proj should not have trainable parameters") + + self.encoder = LinearParamsReplicated(module.encoder, device_mesh=device_mesh) + + all_replicate = [Replicate()] * device_mesh.ndim + self.encoding_unspecified = nn.Parameter( + distribute_tensor(module.encoding_unspecified.data, device_mesh, all_replicate), + requires_grad=module.encoding_unspecified.requires_grad, + ) + self.encoding_unselected = nn.Parameter( + distribute_tensor(module.encoding_unselected.data, device_mesh, all_replicate), + requires_grad=module.encoding_unselected.requires_grad, + ) + + self.cutoff_min = module.cutoff_min + self.cutoff_max = module.cutoff_max + + def forward(self, feats: dict[str, DTensor]) -> DTensor: + """Compute contact conditioning pairwise embeddings. + + Parameters + ---------- + feats : dict[str, DTensor] + Must contain: + - ``contact_conditioning``: shape ``(B, N, N, num_contact_types)``, + placements ``(Shard(0), Shard(1), Shard(2))`` + - ``contact_threshold``: shape ``(B, N, N)``, + placements ``(Shard(0), Shard(1), Shard(2))`` + + Returns + ------- + DTensor + Contact conditioning embeddings, shape ``(B, N, N, token_z)``, + placements ``(Shard(0), Shard(1), Shard(2))``. + """ + cc_dt: DTensor = feats["contact_conditioning"] + ct_dt: DTensor = feats["contact_threshold"] + + cc_local = cc_dt.to_local() + ct_local = ct_dt.to_local() + + ct_norm = (ct_local - self.cutoff_min) / (self.cutoff_max - self.cutoff_min) + ct_flat = ct_norm.flatten() + fourier_flat = torch.cos(2 * pi * self.fourier_embedding.proj(ct_flat.unsqueeze(-1))) + ct_fourier = fourier_flat.reshape(ct_norm.shape + (-1,)) + + cc_features = cc_local[:, :, :, 2:] + combined = torch.cat( + [cc_features, ct_norm.unsqueeze(-1), ct_fourier], + dim=-1, + ) + + combined_shape = cc_dt.shape[:-1] + (combined.shape[-1],) + combined_contiguous = combined.contiguous() + combined_stride = update_exhaustive_strides( + combined_contiguous.shape, combined_contiguous.stride(), combined_shape + ) + combined_dt = DTensor.from_local( + combined_contiguous, + self.device_mesh, + cc_dt.placements, + shape=combined_shape, + stride=combined_stride, + ) + encoded_dt = self.encoder(combined_dt) + + # Native DTensor arithmetic is used here despite the general rule + # against implicit DTensor ops on differentiable paths (CLAUDE.md). + # This is safe because all multiplications are Replicate × Shard(0,1,2) + # = local elementwise — no hidden all-gathers. The resulting + # Partial(Sum) gradients for the encoding params are reduced to + # Replicate by on_after_backward. + unspec_flag = cc_dt[:, :, :, 0:1] # (B, N, N, 1) + unsel_flag = cc_dt[:, :, :, 1:2] # (B, N, N, 1) + mask_factor = 1.0 - (unspec_flag + unsel_flag) # (B, N, N, 1) + + result = ( + encoded_dt * mask_factor + self.encoding_unspecified * unspec_flag + self.encoding_unselected * unsel_flag + ) + return result diff --git a/src/boltz/distributed/model/modules/utils.py b/src/boltz/distributed/model/modules/utils.py new file mode 100644 index 000000000..22bfbcb1d --- /dev/null +++ b/src/boltz/distributed/model/modules/utils.py @@ -0,0 +1,943 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Distributed model module utilities. + +This module provides: + +- precision-related helpers for DTensor-based distributed training/inference +- DTensor checkpoint conversion helpers used by context-parallel strategy code +""" + +import os +from contextlib import contextmanager +from enum import Enum +from typing import Any, Mapping, Optional + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.model.layers.replicate_op import ReplicateOp, replicate_op +from boltz.distributed.utils import ( + LayoutRightMap, + all_reduce_weighted_mean, + create_and_broadcast_tensor_into_placements, + create_distributed_randn, +) +from boltz.model.modules.utils import random_rotations + + +def get_cpu_offload_hooks(optimized: bool = True): + """Return hooks for moving tensors to CPU asynchronously during activation checkpointing. + + Handles both regular ``torch.Tensor`` and ``DTensor`` by offloading the + underlying local shards. When *optimized* is True, a dedicated CUDA stream + and pinned memory are used for true asynchronous transfers. + + Parameters + ---------- + optimized : bool, optional + Use an optimised async implementation with a dedicated CUDA stream and + pinned memory. Defaults to True. + + Returns + ------- + tuple[Callable, Callable] + A ``(pack_hook, unpack_hook)`` pair for use with + ``torch.autograd.graph.saved_tensors_hooks``. + """ + offload_stream = torch.cuda.Stream() if optimized else None + + def pack_hook(tensor: Tensor): + orig_cls = tensor.__class__ + is_dtensor = isinstance(tensor, DTensor) + + if is_dtensor: + local_tensor = tensor.to_local() + metadata = (tensor.device_mesh, tensor.placements, tensor.shape, tensor.stride()) + else: + local_tensor = tensor + metadata = None + + if local_tensor.is_cuda: + if optimized: + with torch.cuda.stream(offload_stream): + cpu_tensor = torch.empty( + local_tensor.shape, dtype=local_tensor.dtype, device="cpu", pin_memory=True + ) + cpu_tensor.copy_(local_tensor, non_blocking=True) + return (cpu_tensor, orig_cls, metadata) + else: + return (local_tensor.to("cpu", non_blocking=True), orig_cls, metadata) + + return tensor + + def unpack_hook(pack_data): + if not isinstance(pack_data, tuple): + return pack_data + + cpu_tensor, cls, metadata = pack_data + + if optimized: + with torch.cuda.stream(offload_stream): + gpu_tensor = cpu_tensor.to("cuda", non_blocking=True) + torch.cuda.current_stream().wait_stream(offload_stream) + else: + gpu_tensor = cpu_tensor.to("cuda", non_blocking=True) + + if cls is DTensor: + device_mesh, placements, shape, stride = metadata + return DTensor.from_local(gpu_tensor, device_mesh, placements, shape=shape, stride=stride) + + return gpu_tensor + + return pack_hook, unpack_hook + + +def get_cpu_offload_context(optimized: bool = True): + """Return a context manager that offloads checkpoint-boundary tensors to CPU. + + When used together with ``torch.utils.checkpoint.checkpoint``, saved + activations (the *plateau*) are moved to CPU inside the context and + transparently restored to GPU on the backward pass. + + Parameters + ---------- + optimized : bool, optional + Use the optimised async offloading path. Defaults to True. + + Returns + ------- + torch.autograd.graph.saved_tensors_hooks + A context manager wrapping the pack/unpack hooks. + """ + pack, unpack = get_cpu_offload_hooks(optimized=optimized) + return torch.autograd.graph.saved_tensors_hooks(pack, unpack) + + +def extract_checkpointing_config(layer: torch.nn.Module) -> tuple[bool, bool]: + """Extract activation checkpointing configuration from a single layer. + + Detects if the layer has been wrapped with fairscale's checkpoint_wrapper, + which replaces the forward method with a functools.partial object. + + Parameters + ---------- + layer : nn.Module + A single layer module that may have been wrapped with checkpoint_wrapper. + + Returns + ------- + tuple[bool, bool] + (activation_checkpointing, cpu_offloading): + - activation_checkpointing: True if the layer has checkpointing enabled + - cpu_offloading: True if checkpointing is configured to offload to CPU + + """ + import functools + + forward_func = getattr(layer.forward, "func", None) + if ( + isinstance(layer.forward, functools.partial) + and forward_func is not None + and forward_func.__name__ == "_checkpointed_forward" + ): + cpu_offloading = layer.forward.args[-1] + return True, cpu_offloading + + return False, False + + +def has_dtensors(obj: Any) -> bool: + """Recursively check whether an object contains any DTensors. + + Args: + obj: Value to inspect. Supported container recursion includes dict/list/tuple. + + Returns: + ``True`` when at least one DTensor is present, otherwise ``False``. + """ + if isinstance(obj, DTensor): + return True + if isinstance(obj, dict): + return any(has_dtensors(value) for value in obj.values()) + if isinstance(obj, (list, tuple)): + return any(has_dtensors(value) for value in obj) + return False + + +def convert_dtensors_to_tensors(obj: Any) -> Any: + """Recursively convert DTensors to plain tensors. + + For ``Replicate``-only placements, ``to_local()`` returns the full + global tensor with no communication. For any ``Shard``/``Partial`` + placement, this function uses ``full_tensor()`` so checkpoints keep + global tensor semantics and remain topology-portable. + + Args: + obj: Value potentially containing DTensors. + + Returns: + Input structure with all DTensors replaced by plain tensors. + """ + if isinstance(obj, DTensor): + if all(isinstance(placement, Replicate) for placement in obj.placements): + return obj.to_local() + return obj.full_tensor() + if isinstance(obj, dict): + # Keep collective ordering deterministic across ranks when nested + # sharded DTensors are serialized. + keys_sorted = sorted(obj.keys(), key=repr) + return {key: convert_dtensors_to_tensors(obj[key]) for key in keys_sorted} + if isinstance(obj, list): + return [convert_dtensors_to_tensors(value) for value in obj] + if isinstance(obj, tuple): + return tuple(convert_dtensors_to_tensors(value) for value in obj) + return obj + + +def convert_distributed_checkpoint_to_serial_state_dict(checkpoint: Mapping[str, Any]) -> dict[str, Any]: + """Extract and convert a distributed checkpoint state dict to serial tensors. + + Args: + checkpoint: Mapping containing at least a ``"state_dict"`` entry. + + Returns: + A plain ``dict`` where any DTensor entries are converted to plain tensors. + + Raises: + KeyError: If ``"state_dict"`` is missing from ``checkpoint``. + TypeError: If ``checkpoint["state_dict"]`` is not mapping-like. + """ + if "state_dict" not in checkpoint: + raise KeyError("Checkpoint does not contain 'state_dict'") + + state_dict = checkpoint["state_dict"] + if not isinstance(state_dict, Mapping): + raise TypeError("'state_dict' must be a mapping") + + converted = convert_dtensors_to_tensors(state_dict) + if not isinstance(converted, dict): + return dict(converted) + return converted + + +def _convert_serial_value_to_template_layout(value: Any, template_value: Any) -> Any: + """Convert one checkpoint value to match a template value layout/device/dtype. + + Handles four tensor-to-tensor cases plus a non-tensor passthrough: + + * **DTensor → DTensor**: validate shape and stride, return as-is. + * **Tensor → DTensor**: validate shape, distribute to template's mesh/placements. + * **DTensor → Tensor**: unwrap via ``to_local()``, cast to template device/dtype. + * **Tensor → Tensor**: cast to template device/dtype. + * **Non-tensor**: return unchanged. + """ + # --- Common validation for any tensor-to-tensor conversion --------------- + both_tensors = isinstance(value, torch.Tensor) and isinstance(template_value, torch.Tensor) + if both_tensors: + if tuple(value.shape) != tuple(template_value.shape): + raise ValueError( + f"Value shape {tuple(value.shape)} does not match template shape {tuple(template_value.shape)}" + ) + if tuple(value.stride()) != tuple(template_value.stride()): + raise ValueError( + f"Value stride {tuple(value.stride())} does not match template stride {tuple(template_value.stride())}" + ) + + # --- DTensor template ---------------------------------------------------- + if isinstance(template_value, DTensor): + if isinstance(value, DTensor): + return value + if not isinstance(value, torch.Tensor): + raise TypeError(f"Expected tensor value for DTensor template, got {type(value)}") + + value = value.to(device=template_value.device_mesh.device_type, dtype=template_value.dtype) + if all(isinstance(p, Replicate) for p in template_value.placements): + # All ranks load the same checkpoint, so the value is already + # identical across ranks. from_local avoids the redundant + # all-gather that distribute_tensor would trigger. + return DTensor.from_local( + value, + device_mesh=template_value.device_mesh, + placements=template_value.placements, + shape=value.shape, + stride=value.stride(), + ) + return distribute_tensor( + value, + device_mesh=template_value.device_mesh, + placements=template_value.placements, + ) + + # --- Plain tensor template ----------------------------------------------- + if isinstance(template_value, torch.Tensor): + if isinstance(value, DTensor): + value = value.to_local() + if isinstance(value, torch.Tensor): + return value.to(device=template_value.device, dtype=template_value.dtype) + + # --- Fallback: unwrap DTensor or pass through ---------------------------- + if isinstance(value, DTensor): + return value.to_local() + return value + + +def convert_serial_checkpoint_to_distributed_state_dict( + checkpoint: Mapping[str, Any], + strict: bool = False, + state_dict_template: Optional[Mapping[str, Any]] = None, +) -> dict[str, Any]: + """Convert a serial checkpoint state dict to match a distributed state template. + + This helper intentionally works from an explicit ``state_dict_template`` rather + than constructing a full distributed model, so strategy tests can run without the + full CP model stack. + + Args: + checkpoint: Mapping containing a serial ``"state_dict"``. + strict: Enforce key parity between serial state and template when ``True``. + state_dict_template: A mapping (typically ``lightning_module.state_dict()``) + that defines desired output layout/type per key. + + Returns: + A new state dict aligned to ``state_dict_template``. + + Raises: + KeyError: If required checkpoint fields are missing, or strict key parity fails. + TypeError: If ``checkpoint["state_dict"]`` is not mapping-like. + ValueError: If ``state_dict_template`` is not provided. + """ + if "state_dict" not in checkpoint: + raise KeyError("Checkpoint does not contain 'state_dict'") + if state_dict_template is None: + raise ValueError("state_dict_template is required to convert serial checkpoint to distributed layout") + + state_dict = checkpoint["state_dict"] + if not isinstance(state_dict, Mapping): + raise TypeError("'state_dict' must be a mapping") + + template_keys = set(state_dict_template.keys()) + state_keys = set(state_dict.keys()) + missing_keys = template_keys - state_keys + extra_keys = state_keys - template_keys + if strict and (missing_keys or extra_keys): + msg = "State-dict keys do not match template keys." + if missing_keys: + msg += f" Missing keys: {sorted(missing_keys)}." + if extra_keys: + msg += f" Extra keys: {sorted(extra_keys)}." + raise KeyError(msg) + + converted_state: dict[str, Any] = {} + for key, template_value in state_dict_template.items(): + if key not in state_dict: + continue + converted_state[key] = _convert_serial_value_to_template_layout(state_dict[key], template_value) + + if not strict: + for key in extra_keys: + converted_state[key] = convert_dtensors_to_tensors(state_dict[key]) + + return converted_state + + +def validate_window_batching_parameters( + attn_window_queries: Optional[int], attn_window_keys: Optional[int], use_window_batching: bool +) -> None: + """Validates parameters for window batching in attention mechanisms. + + Args: + attn_window_queries: Size of the query window. Must be a positive even integer if provided. + attn_window_keys: Size of the key window. Must be a positive integer if provided. + use_window_batching: Whether window batching is enabled. + + Raises: + ValueError: If ``attn_window_queries`` and ``attn_window_keys`` are not both None or both not None. + ValueError: If ``use_window_batching`` is True but ``attn_window_queries`` is None. + ValueError: If ``attn_window_queries`` is not a positive even integer. + ValueError: If ``attn_window_keys`` is not a positive integer. + ValueError: If ``attn_window_keys`` is not divisible by ``attn_window_queries // 2``. + """ + if (attn_window_queries is None) != (attn_window_keys is None): + raise ValueError("attn_window_queries and attn_window_keys must be either both None or both not None") + + if (attn_window_queries is None) == use_window_batching: + raise ValueError( + f"attn_window_queries and attn_window_keys must be None if use_window_batching is False, otherwise they must be not None, but got attn_window_queries={attn_window_queries}, attn_window_keys={attn_window_keys} and use_window_batching={use_window_batching}" + ) + + if attn_window_queries is not None: + if not isinstance(attn_window_queries, int) or attn_window_queries <= 0: + raise ValueError("attn_window_queries must be a positive integer") + + if attn_window_queries % 2 != 0: + raise ValueError("attn_window_queries must be even") + + if attn_window_keys is not None: + if not isinstance(attn_window_keys, int) or attn_window_keys <= 0: + raise ValueError("attn_window_keys must be a positive integer") + + if attn_window_keys % (attn_window_queries // 2) != 0: + raise ValueError("attn_window_keys must be divisible by attn_window_queries // 2") + + +class Precision(Enum): + """Precision modes for model computation.""" + + BF16 = "BF16" + BF16_MIXED = "BF16_MIXED" + FP16 = "FP16" + TF32 = "TF32" + FP32 = "FP32" + FP64 = "FP64" + + +class SDPAWithBiasBackend(Enum): + """Scaled dot-product attention with bias backend implementations.""" + + REFERENCE = "reference" + TORCH_SDPA_EFFICIENT_ATTENTION = "torch_sdpa_efficient_attention" + TORCH_FLEX_ATTN = "torch_flex_attn" + + +class TriAttnBackend(Enum): + """Triangle attention backend implementations (for distributed triangular attention).""" + + REFERENCE = "reference" + CUEQ = "cueq" + TRIFAST = "trifast" + CUEQ_FWD_TRIFAST_BWD = "cueq_fwd_trifast_bwd" + + +class SetTriAttnBackend: + """Callable that sets ``triattn_backend`` on every :class:`PairformerLayer` in a model. + + Designed for use with :meth:`torch.nn.Module.apply`:: + + from boltz.distributed.model.modules.utils import SetTriAttnBackend, TriAttnBackend + model.apply(SetTriAttnBackend(TriAttnBackend.CUEQ)) + + ``MSALayer`` is **not** targeted directly because it contains a + ``PairformerNoSeqLayer`` child which is reached by the recursive + ``apply`` traversal. + """ + + def __init__(self, triattn_backend: TriAttnBackend) -> None: + # Lazy import: PairformerLayer imports TriAttnBackend from this module, + # so a top-level import would create a circular dependency. + from boltz.distributed.model.layers.pairformer import PairformerLayer + + valid = ( + TriAttnBackend.REFERENCE, + TriAttnBackend.CUEQ, + TriAttnBackend.TRIFAST, + TriAttnBackend.CUEQ_FWD_TRIFAST_BWD, + ) + if triattn_backend not in valid: + raise ValueError(f"triattn_backend must be one of {valid} but got {triattn_backend}") + self.triattn_backend = triattn_backend + self.supported_module_types = (PairformerLayer,) + + def __call__(self, module: torch.nn.Module) -> None: + if not isinstance(module, self.supported_module_types): + return + if not hasattr(module, "triattn_backend"): + raise AttributeError( + f"Module {type(module).__name__} should but does not have a 'triattn_backend' attribute" + ) + module.triattn_backend = self.triattn_backend + + +class SetAttnPairBiasBackend: + """Callable that sets ``sdpa_with_bias_backend`` on every :class:`AttentionPairBias` in a model. + + Designed for use with :meth:`torch.nn.Module.apply`:: + + from boltz.distributed.model.modules.utils import SDPAWithBiasBackend, SetAttnPairBiasBackend + model.apply(SetAttnPairBiasBackend(SDPAWithBiasBackend.TORCH_FLEX_ATTN)) + + Only ``REFERENCE`` and ``TORCH_FLEX_ATTN`` are valid for ring-attention + :class:`AttentionPairBias`; see the validation in its ``__init__``. + """ + + def __init__(self, sdpa_with_bias_backend: SDPAWithBiasBackend) -> None: + # Lazy import: attention.py imports from this module, so a top-level + # import would create a circular dependency. + from boltz.distributed.model.layers.attention import AttentionPairBias + + valid = (SDPAWithBiasBackend.REFERENCE, SDPAWithBiasBackend.TORCH_FLEX_ATTN) + if sdpa_with_bias_backend not in valid: + raise ValueError(f"sdpa_with_bias_backend must be one of {valid} but got {sdpa_with_bias_backend}") + self.sdpa_with_bias_backend = sdpa_with_bias_backend + self._target_type = AttentionPairBias + + def __call__(self, module: torch.nn.Module) -> None: + if not isinstance(module, self._target_type): + return + if not hasattr(module, "sdpa_with_bias_backend"): + raise AttributeError( + f"Module {type(module).__name__} should but does not have a " f"'sdpa_with_bias_backend' attribute" + ) + module.sdpa_with_bias_backend = self.sdpa_with_bias_backend + + +class SetAttnPairBiasShardwiseBackend: + """Callable that sets ``sdpa_with_bias_backend`` on every :class:`AttentionPairBiasShardwise` in a model. + + Designed for use with :meth:`torch.nn.Module.apply`:: + + from boltz.distributed.model.modules.utils import SDPAWithBiasBackend, SetAttnPairBiasShardwiseBackend + model.apply(SetAttnPairBiasShardwiseBackend(SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION)) + + All three ``SDPAWithBiasBackend`` members are valid for window-batched + :class:`AttentionPairBiasShardwise`. + """ + + def __init__(self, sdpa_with_bias_backend: SDPAWithBiasBackend) -> None: + # Lazy import: attention.py imports from this module, so a top-level + # import would create a circular dependency. + from boltz.distributed.model.layers.attention import AttentionPairBiasShardwise + + valid = ( + SDPAWithBiasBackend.REFERENCE, + SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION, + SDPAWithBiasBackend.TORCH_FLEX_ATTN, + ) + if sdpa_with_bias_backend not in valid: + raise ValueError(f"sdpa_with_bias_backend must be one of {valid} but got {sdpa_with_bias_backend}") + self.sdpa_with_bias_backend = sdpa_with_bias_backend + self._target_type = AttentionPairBiasShardwise + + def __call__(self, module: torch.nn.Module) -> None: + if not isinstance(module, self._target_type): + return + if not hasattr(module, "sdpa_with_bias_backend"): + raise AttributeError( + f"Module {type(module).__name__} should but does not have a " f"'sdpa_with_bias_backend' attribute" + ) + module.sdpa_with_bias_backend = self.sdpa_with_bias_backend + + +class OffloadActvCkptToCPU: + """Callable that enables ``cpu_offloading`` on selected distributed module types. + + Designed for use with :meth:`torch.nn.Module.apply`:: + + from boltz.distributed.model.modules.utils import OffloadActvCkptToCPU + model.apply(OffloadActvCkptToCPU(["DiffusionTransformer", "PairformerModule"])) + + Each targeted module must already have ``activation_checkpointing = True``; + a :class:`ValueError` is raised otherwise. + + Parameters + ---------- + module_types : set[str] + Subset of ``{"DiffusionTransformer", "MSAModule", "PairformerModule"}``. + """ + + def __init__(self, module_types: set[str]) -> None: + from boltz.distributed.model.layers.pairformer import PairformerModule + from boltz.distributed.model.modules.transformers import DiffusionTransformer + from boltz.distributed.model.modules.trunkv2 import MSAModule + + valid_map: dict[str, type] = { + "DiffusionTransformer": DiffusionTransformer, + "MSAModule": MSAModule, + "PairformerModule": PairformerModule, + } + module_types = set(module_types) + invalid = module_types - valid_map.keys() + if invalid: + raise ValueError( + f"Invalid module type(s) {sorted(invalid)} for OffloadActvCkptToCPU. " + f"Valid types: {sorted(valid_map)}" + ) + if not module_types: + raise ValueError("module_types must be non-empty") + self._target_types = tuple(valid_map[n] for n in sorted(module_types)) + + def __call__(self, module: torch.nn.Module) -> None: + if not isinstance(module, self._target_types): + return + for attr in ("activation_checkpointing", "cpu_offloading"): + if not hasattr(module, attr): + raise AttributeError(f"Module {type(module).__name__} should but does not have a '{attr}' attribute") + if not module.activation_checkpointing: + raise ValueError( + f"Cannot enable cpu_offloading on {type(module).__name__} because " + f"activation_checkpointing is not enabled. Enable it first " + f"(e.g. model.msa_args/pairformer_args/score_model_args" + f".activation_checkpointing=true)." + ) + module.cpu_offloading = True + + +PRECISION_TO_DTYPE = { + Precision.BF16: torch.bfloat16, + Precision.BF16_MIXED: torch.bfloat16, + Precision.FP16: torch.float16, + Precision.TF32: torch.float32, + Precision.FP32: torch.float32, + Precision.FP64: torch.float64, +} + + +DTYPE_TO_PRECISION = { + # no BF16-MIXED mapping as it's only relevant for training (mostly) + # also, this util dict is only used to look up dtype-specific attention INF values + # in the tests + torch.bfloat16: Precision.BF16, + torch.float16: Precision.FP16, + torch.float32: Precision.FP32, + torch.float64: Precision.FP64, +} + + +PRECISION_TO_LIGHTNING = { + Precision.BF16: "bf16-true", + Precision.BF16_MIXED: "bf16-mixed", + Precision.FP16: "fp16-true", + Precision.TF32: "32", + Precision.FP32: "32", + Precision.FP64: "64", +} + + +@contextmanager +def setup_tf32_env(precision: Precision): + """Context manager to setup TF32 environment based on precision setting. + + This context manager temporarily modifies TF32 settings for CUDA operations + and automatically restores the original settings when exiting the context. + + Args: + precision (Precision): Target precision mode + + Example: + >>> with setup_tf32_env(Precision.TF32): + ... # TF32 is enabled for this block + ... result = model(input_tensor) + >>> # Original TF32 settings are restored + + >>> with setup_tf32_env(Precision.FP32): + ... # TF32 is explicitly disabled for pure FP32 + ... result = model(input_tensor) + + Note: + This affects both CUDA matrix operations and cuDNN operations. + The original environment is always restored, even if an exception occurs. + """ + # Store original TF32 settings + original_env = os.environ.get("NVIDIA_TF32_OVERRIDE", None) + original_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32 + original_cudnn_tf32 = torch.backends.cudnn.allow_tf32 + + # Setup TF32 environment based on precision + use_tf32 = precision == Precision.TF32 + + if use_tf32: + os.environ["NVIDIA_TF32_OVERRIDE"] = "1" + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + elif precision == Precision.FP32: + # Explicitly disable TF32 for pure FP32 + os.environ["NVIDIA_TF32_OVERRIDE"] = "0" + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + # For fp64 or other precisions, leave TF32 settings unchanged + + try: + yield + finally: + # Restore original TF32 settings + if original_env is not None: + os.environ["NVIDIA_TF32_OVERRIDE"] = original_env + else: + os.environ.pop("NVIDIA_TF32_OVERRIDE", None) + torch.backends.cuda.matmul.allow_tf32 = original_matmul_tf32 + torch.backends.cudnn.allow_tf32 = original_cudnn_tf32 + + +def create_and_broadcast_random_rotation( + shape: tuple[int, ...], + device_mesh: DeviceMesh, + dtype: torch.dtype = torch.float32, +) -> Tensor: + """Create a distributed random rotation matrix. + + Parameters + ---------- + shape : tuple[int, ...] + Shape of the random rotation matrix, e.g. (B, 3, 3). + device_mesh : DeviceMesh + The device mesh for DTensor operations. + dtype : torch.dtype, optional + The dtype of the random rotation matrix. Defaults to torch.float32. + + Returns + ------- + Tensor + The local random rotation matrix tensor (shard of the batch dim). + + """ + + def create_rand_rot_fn(shape_local, dtype, device): + return random_rotations(shape_local[0], dtype=dtype, device=device) + + tensor_local = create_and_broadcast_tensor_into_placements( + shape=shape, + create_local_fn=create_rand_rot_fn, + device_mesh=device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + dtype=dtype, + ) + return tensor_local + + +@torch.no_grad() +def randomly_rotate( + coords: DTensor, + second_coords: Optional[DTensor] = None, + return_roto: bool = False, +) -> tuple[DTensor, Optional[DTensor], Optional[DTensor]]: + """Randomly rotate coordinates using DTensor operations. Does not support backward. + + Parameters + ---------- + coords : DTensor + The coordinates to rotate, shape (B, N, 3). + Placements: (Shard(0), Shard(1), Replicate()). + second_coords : Optional[DTensor], optional + Optional second coordinates to rotate with the same rotation matrix. + return_roto : bool, optional + Whether to return the rotation matrix. + + Returns + ------- + tuple[DTensor, Optional[DTensor], Optional[DTensor]] + Rotated coords, rotated second_coords (if provided), rotation matrix (if requested). + + """ + if coords.requires_grad: + raise ValueError("randomly_rotate does not support backward pass but got coords.requires_grad is True") + + device_mesh = coords.device_mesh + placements = coords.placements + + if placements != (Shard(0), Shard(1), Replicate()): + raise ValueError(f"Expected placements (Shard(0), Shard(1), Replicate()), got {placements}") + + if second_coords is not None and second_coords.placements != (Shard(0), Shard(1), Replicate()): + raise ValueError( + f"Expected second_coords placements (Shard(0), Shard(1), Replicate()), got {second_coords.placements}" + ) + + size_batch = coords.shape[0] + + # Create random rotation matrix (local shard) + R_local = create_and_broadcast_random_rotation( + shape=(size_batch, 3, 3), + device_mesh=device_mesh, + dtype=coords.to_local().dtype, + ) + + # Apply rotation using einsum on local tensors + coords_local = coords.to_local() + coords_rotated_local = torch.einsum("bmd,bds->bms", coords_local, R_local) + + coords_rotated = DTensor.from_local( + coords_rotated_local, + device_mesh=device_mesh, + placements=placements, + shape=coords.shape, + stride=coords.stride(), + ) + + # Handle second_coords + second_coords_rotated = None + if second_coords is not None: + second_coords_local = second_coords.to_local() + second_coords_rotated_local = torch.einsum("bmd,bds->bms", second_coords_local, R_local) + second_coords_rotated = DTensor.from_local( + second_coords_rotated_local, + device_mesh=device_mesh, + placements=placements, + shape=second_coords.shape, + stride=second_coords.stride(), + ) + + # Return rotation matrix if requested + roto = None + if return_roto: + shape_roto = (size_batch, 3, 3) + stride_roto = LayoutRightMap(shape_roto).strides + roto = DTensor.from_local( + R_local, + device_mesh, + (Shard(0), Replicate(), Replicate()), + shape=shape_roto, + stride=stride_roto, + ) + + return coords_rotated, second_coords_rotated, roto + + +def center_random_augmentation( + atom_coords: DTensor, + atom_mask: DTensor, + s_trans: float = 1.0, + augmentation: bool = True, + centering: bool = True, + return_second_coords: bool = False, + second_coords: Optional[DTensor] = None, + return_roto: bool = False, +) -> tuple[DTensor, DTensor, DTensor] | tuple[DTensor, DTensor] | DTensor: + """Center and randomly augment coordinates using DTensor operations. Does not support backward. + + Parameters + ---------- + atom_coords : DTensor + Atom coordinates, shape (B, N, 3). + Placements: (Shard(0), Shard(1), Replicate()). + atom_mask : DTensor + Atom mask, shape (B, N). + Placements: (Shard(0), Shard(1), Replicate()). + s_trans : float, optional + Translation scale factor, by default 1.0. + augmentation : bool, optional + Whether to add random rotation + translation, by default True. + centering : bool, optional + Whether to center coordinates to zero mean, by default True. + return_second_coords : bool, optional + Whether to return transformed second coordinates, by default False. + second_coords : Optional[DTensor], optional + Second coordinates to apply the same transformation. + return_roto : bool, optional + Whether to return the rotation matrix, by default False. + + Returns + ------- + DTensor | tuple[DTensor, ...] + Augmented coordinates, and optionally second coords and rotation matrix. + + """ + if atom_coords.requires_grad: + raise ValueError("center_random_augmentation does not support backward pass") + + if second_coords is not None and second_coords.requires_grad: + raise ValueError("center_random_augmentation does not support backward pass for second_coords") + + if return_roto and not augmentation: + raise ValueError("cannot return rotation matrix when augmentation is False") + + device_mesh = atom_coords.device_mesh + input_placements = atom_coords.placements + if input_placements != (Shard(0), Shard(1), Replicate()): + raise ValueError(f"Expected placements (Shard(0), Shard(1), Replicate()), got {input_placements}") + + cp_axis_0_group = device_mesh.get_group("cp_axis_0") + cp_axis_1_group = device_mesh.get_group("cp_axis_1") + cp_axis_1_rank = device_mesh.get_local_rank("cp_axis_1") + + if centering: + # Compute mean on cp_axis_1 rank 0, then broadcast to all column ranks + if cp_axis_1_rank == 0: + atom_coords_local = atom_coords.to_local().requires_grad_(False) + atom_mask_local = atom_mask.to_local().requires_grad_(False) + + atom_mean_local = all_reduce_weighted_mean( + atom_mask_local.unsqueeze(-1), + atom_coords_local, + group_reduce=cp_axis_0_group, + dim=1, + ) + else: + atom_coords_local = atom_coords.to_local() + atom_mean_local = torch.empty_like(atom_coords_local[:, 0, :]) + + # Broadcast mean across cp_axis_1 + dist.broadcast(atom_mean_local, dist.get_global_rank(cp_axis_1_group, 0), cp_axis_1_group) + + shape_atom_mean_global = (atom_coords.shape[0], atom_coords.shape[-1]) + stride_atom_mean_global = LayoutRightMap(shape_atom_mean_global).strides + atom_mean = DTensor.from_local( + atom_mean_local, + device_mesh=device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + shape=shape_atom_mean_global, + stride=stride_atom_mean_global, + ) + + atom_coords = replicate_op(atom_coords, atom_mean, 1, ReplicateOp.SUB) + if second_coords is not None: + second_coords = replicate_op(second_coords, atom_mean, 1, ReplicateOp.SUB) + + if augmentation: + atom_coords, second_coords, roto = randomly_rotate( + atom_coords, + second_coords=second_coords, + return_roto=return_roto, + ) + + # Generate and apply random translation + batch_size = atom_coords.shape[0] + random_trans = create_distributed_randn( + (batch_size, 1, 3), + device_mesh=device_mesh, + placements=(Shard(0), Replicate(), Replicate()), + dtype=atom_coords.dtype, + scale=s_trans, + ) + + with torch.no_grad(): + atom_coords_local = atom_coords.to_local() + random_trans_local = random_trans.to_local() + atom_coords = DTensor.from_local( + atom_coords_local + random_trans_local, + device_mesh=device_mesh, + placements=input_placements, + shape=atom_coords.shape, + stride=atom_coords.stride(), + ) + + if second_coords is not None: + with torch.no_grad(): + second_coords_local = second_coords.to_local() + second_coords = DTensor.from_local( + second_coords_local + random_trans_local, + device_mesh=device_mesh, + placements=input_placements, + shape=second_coords.shape, + stride=second_coords.stride(), + ) + + if return_second_coords and return_roto: + return atom_coords, second_coords, roto + elif return_second_coords: + return atom_coords, second_coords + elif return_roto: + return atom_coords, roto + else: + return atom_coords diff --git a/src/boltz/distributed/model/optim/__init__.py b/src/boltz/distributed/model/optim/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/model/optim/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/model/optim/ema.py b/src/boltz/distributed/model/optim/ema.py new file mode 100644 index 000000000..4b5c1728a --- /dev/null +++ b/src/boltz/distributed/model/optim/ema.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-aware EMA callback for distributed context-parallel training. + +This module provides :class:`DistributedEMA`, a drop-in subclass of the +base :class:`~boltz.model.optim.ema.EMA` callback that handles the +DTensor ↔ plain-tensor conversions required when model parameters live +on a distributed device mesh. + +Key differences from the base EMA callback: + +* **Save** – EMA shadow weights (which are DTensors during training) are + converted to plain local tensors before being written to the checkpoint, + keeping checkpoints portable across topologies. +* **Load** – Plain-tensor EMA weights from a checkpoint are re-distributed + to match the model's current DTensor placements before the first + training step. +* **Weight swap** (for validation with EMA) – The model-weight backup + preserves DTensor metadata instead of moving to CPU (which strips + DTensor placement information). +""" + +from __future__ import annotations + +import logging +from typing import Any + +import torch +from pytorch_lightning import LightningModule, Trainer +from torch.distributed.tensor import DTensor, Replicate + +from boltz.distributed.model.modules.utils import convert_dtensors_to_tensors +from boltz.model.optim.ema import EMA + +logger = logging.getLogger(__name__) + + +def _assert_replicate_only(tensor: DTensor, name: str) -> None: + """Raise if *tensor* has any non-Replicate placements. + + EMA uses in-place ``.data`` arithmetic which bypasses DTensor dispatch, + so it is only correct when every placement is ``Replicate``. + """ + for placement in tensor.placements: + if not isinstance(placement, Replicate): + raise ValueError( + f"DistributedEMA requires all placements to be Replicate for " + f"correct in-place arithmetic, but parameter '{name}' has " + f"placement {placement!r}. Shard/Partial placements would " + f"silently produce incorrect EMA updates." + ) + + +class DistributedEMA(EMA): + """DTensor-aware Exponential Moving Average callback. + + For models whose parameters are plain :class:`torch.Tensor` objects this + callback is behaviourally identical to :class:`EMA`. When parameters are + :class:`~torch.distributed.tensor.DTensor` instances (e.g. under context + parallelism), the additional conversions described in the module docstring + are applied transparently. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._placements_validated: bool = False + + @staticmethod + def _realign_weights_to_model( + weights: dict[str, torch.Tensor], + pl_module: LightningModule, + ) -> dict[str, torch.Tensor]: + """Re-distribute plain-tensor weights to match a model's DTensor layout. + + For every key in *weights* whose corresponding model parameter is a + :class:`DTensor`, the plain tensor is wrapped with the same device-mesh + and placements. Keys that are already DTensors or whose templates are + plain tensors pass through unchanged. + + Args: + weights: Dict mapping parameter names to tensors (plain or DTensor). + pl_module: The LightningModule whose ``state_dict()`` defines the + target DTensor layout. + + Returns: + A new dict with values distributed to match the model layout. + + Raises: + ValueError: If an EMA weight shape does not match the model + template shape, or if a DTensor template has non-Replicate + placements. + """ + model_state = pl_module.state_dict() + ema_keys = set(weights.keys()) + model_keys = set(model_state.keys()) + missing_in_model = ema_keys - model_keys + missing_in_ema = model_keys - ema_keys + if missing_in_model: + logger.warning( + "DistributedEMA: %d EMA key(s) not found in model state_dict (stale checkpoint?): %s", + len(missing_in_model), + sorted(missing_in_model)[:5], + ) + if missing_in_ema: + logger.warning( + "DistributedEMA: %d model key(s) not found in EMA weights (new layers added?): %s", + len(missing_in_ema), + sorted(missing_in_ema)[:5], + ) + + realigned: dict[str, torch.Tensor] = {} + for key, weight in weights.items(): + template = model_state.get(key) + if template is not None and isinstance(template, DTensor) and not isinstance(weight, DTensor): + if tuple(weight.shape) != tuple(template.shape): + raise ValueError( + f"DistributedEMA: shape mismatch for '{key}': " + f"EMA weight has shape {tuple(weight.shape)} but model " + f"template has shape {tuple(template.shape)}" + ) + _assert_replicate_only(template, key) + + weight = weight.to(device=template.device_mesh.device_type, dtype=template.dtype) + # Placements are validated as Replicate above, and all + # ranks hold the same EMA checkpoint — from_local avoids + # the redundant all-gather that distribute_tensor would do. + weight = DTensor.from_local( + weight, + device_mesh=template.device_mesh, + placements=template.placements, + shape=template.shape, + stride=template.stride(), + ) + realigned[key] = weight + + for key in missing_in_ema: + realigned[key] = model_state[key].detach().clone() + + return realigned + + def apply_ema(self, pl_module: LightningModule) -> None: + """Apply EMA update, asserting Replicate placements on DTensors. + + The base EMA class performs arithmetic via ``.data`` which bypasses + DTensor dispatch. This is only correct when all placements are + ``Replicate``. This override validates that invariant on the first + call and then delegates to the base implementation. + """ + # Validate on first call only (placements don't change during training) + if not self._placements_validated: + for k, ema_w in self._ema_weights.items(): + if isinstance(ema_w, DTensor): + _assert_replicate_only(ema_w, f"ema[{k}]") + for k, param in pl_module.state_dict().items(): + if isinstance(param, DTensor): + _assert_replicate_only(param, f"model[{k}]") + self._placements_validated = True + + super().apply_ema(pl_module) + + def on_save_checkpoint( + self, + trainer: Trainer, # noqa: ARG002 + pl_module: LightningModule, # noqa: ARG002 + checkpoint: dict[str, Any], + ) -> None: + """Save EMA state, converting any DTensors to plain tensors.""" + if self.ema_initialized: + checkpoint["ema"] = { + "cur_step": self._cur_step, + "ema_weights": convert_dtensors_to_tensors(self._ema_weights), + } + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # noqa: ARG002 + """Initialise or re-distribute EMA weights to match the model layout. + + * **First run** – clones the model's ``state_dict()`` (which may + contain DTensors). + * **Resumed run** – plain-tensor EMA weights loaded from checkpoint + are redistributed to match the model's DTensor placements. + """ + # Re-validate placements on next apply_ema (topology may have changed). + self._placements_validated = False + + if not self.ema_initialized: + self._ema_weights = {k: p.detach().clone() for k, p in pl_module.state_dict().items()} + else: + # Re-distribute plain-tensor EMA weights to match the model's DTensor layout. + self._ema_weights = self._realign_weights_to_model(self._ema_weights, pl_module) + + # Move to correct device (preserves DTensor placements). + self._ema_weights = {k: p.to(pl_module.device) for k, p in self._ema_weights.items()} + + def replace_model_weights(self, pl_module: LightningModule) -> None: + """Replace model weights with EMA weights, backing up originals to CPU. + + DTensor metadata (device mesh and placements) is stored alongside the + CPU tensors so :meth:`restore_original_weights` can reconstruct the + original DTensor layout without an extra ``state_dict()`` call. + Keeping the backup on CPU instead of on-device frees GPU memory for + the duration of validation. + + The ``inference_mode(False)`` guard is required because Lightning's + ``trainer.predict()`` wraps the entire workflow in + ``torch.inference_mode(True)`` by default (since Lightning ≥2.x). + Operations like ``.detach()``, ``.clone()``, and ``load_state_dict`` + manipulate version counters, which PyTorch ≥2.10 disallows on + inference tensors. + """ + with torch.inference_mode(False): + self._weights_buffer: dict[str, torch.Tensor] = {} + self._weights_dtensor_meta: dict[str, tuple] = {} + for k, p in pl_module.state_dict().items(): + if isinstance(p, DTensor): + self._weights_dtensor_meta[k] = (p.device_mesh, p.placements, p.shape, p.stride()) + self._weights_buffer[k] = p.to_local().detach().clone().cpu() + else: + self._weights_buffer[k] = p.detach().clone().cpu() + ema_weights_to_load = self._realign_weights_to_model(self._ema_weights, pl_module) + pl_module.load_state_dict(ema_weights_to_load, strict=False) + + def restore_original_weights(self, pl_module: LightningModule) -> None: + """Restore original model weights from the CPU backup. + + For keys that were originally DTensors, the plain CPU tensor is + redistributed back to the stored device mesh and placements. + + See :meth:`replace_model_weights` for why the ``inference_mode(False)`` + guard is necessary. + """ + with torch.inference_mode(False): + restored: dict[str, torch.Tensor] = {} + for k, cpu_tensor in self._weights_buffer.items(): + if k in self._weights_dtensor_meta: + mesh, placements, shape, stride = self._weights_dtensor_meta[k] + restored[k] = DTensor.from_local( + cpu_tensor.to(device=mesh.device_type), + device_mesh=mesh, + placements=placements, + shape=shape, + stride=stride, + ) + else: + restored[k] = cpu_tensor.to(device=pl_module.device) + pl_module.load_state_dict(restored, strict=False) + del self._weights_buffer + del self._weights_dtensor_meta diff --git a/src/boltz/distributed/model/validation/__init__.py b/src/boltz/distributed/model/validation/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/distributed/model/validation/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/distributed/model/validation/rcsb.py b/src/boltz/distributed/model/validation/rcsb.py new file mode 100644 index 000000000..900f2b050 --- /dev/null +++ b/src/boltz/distributed/model/validation/rcsb.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Distributed RCSB validator for Boltz-2 with context parallelism. + +Uses diamond inheritance so that: + + DistributedRCSBValidator + ├── DistributedValidator (DTensor gathering + metric overrides) + └── RCSBValidator (process / on_epoch_end flow) + └── Validator (metric storage, serial compute logic) + +MRO: DistributedRCSBValidator → DistributedValidator → RCSBValidator → Validator + +This gives ``DistributedValidator`` precedence for every method it +overrides (``run_model``, ``common_val_step``, ``compute_disto_loss``, …) +while ``RCSBValidator.process()`` and ``RCSBValidator.on_epoch_end()`` +provide the RCSB-specific entry points that delegate to +``self.common_val_step`` / ``self.common_on_epoch_end`` – both of which +resolve to the distributed versions from ``DistributedValidator``. +""" + +from typing import Optional + +import torch +from pytorch_lightning import LightningModule + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.validation.validator import DistributedValidator +from boltz.model.validation.rcsb import RCSBValidator + + +class DistributedRCSBValidator(DistributedValidator, RCSBValidator): + """Distributed RCSB validator with DTensor-aware metric computation. + + Inherits: + * Metric overrides & gathering from :class:`DistributedValidator`. + * ``process`` / ``on_epoch_end`` flow from :class:`RCSBValidator`. + """ + + def __init__( + self, + val_names: list[str], + confidence_prediction: bool = False, + physicalism_metrics: bool = False, + rmsd_metrics: bool = False, + clash_score_metrics: bool = False, + override_val_method: Optional[str] = None, + ) -> None: + """Initialize the distributed RCSB validator. + + Parameters + ---------- + val_names : list[str] + The list of validation names. + confidence_prediction : bool + Whether to predict confidence. + physicalism_metrics : bool + Whether to compute physicalism metrics. + rmsd_metrics : bool + Whether to compute rmsd metrics. + clash_score_metrics : bool + Whether to compute clash score metrics. + override_val_method : Optional[str] + The override validation method. + """ + # Bypass cooperative super().__init__ to avoid MRO signature + # mismatches between DistributedValidator and RCSBValidator. + # Both ultimately initialise the same Validator base; calling it + # directly avoids passing kwargs that the intermediate classes + # don't expect. + DistributedValidator.__init__( + self, + val_names=val_names, + confidence_prediction=confidence_prediction, + physicalism_metrics=physicalism_metrics, + rmsd_metrics=rmsd_metrics, + clash_score_metrics=clash_score_metrics, + override_val_method=override_val_method, + ) + + def process( + self, + model: LightningModule, + batch: dict[str, torch.Tensor], + out: dict[str, torch.Tensor], + idx_dataset: int, + transpose_comm: TransposeComm, + ) -> None: + """Compute features. + + Parameters + ---------- + model : LightningModule + The LightningModule model. + batch : Dict[str, torch.Tensor] + The batch input. + out : Dict[str, torch.Tensor] + The output of the model. + idx_dataset : int + Global dataset index. + transpose_comm : TransposeComm + The transpose communication object. + """ + symmetry_correction = model.val_group_mapper[idx_dataset]["symmetry_correction"] + expand_to_diffusion_samples = symmetry_correction # True # TODO Mateo why is this set to sym correction? + + # For now all was dumped into the common operation in the parent Validator class + self.common_val_step( + model, + batch, + out, + idx_dataset, + expand_to_diffusion_samples=expand_to_diffusion_samples, + transpose_comm=transpose_comm, + ) diff --git a/src/boltz/distributed/model/validation/utils.py b/src/boltz/distributed/model/validation/utils.py new file mode 100644 index 000000000..128cb8e40 --- /dev/null +++ b/src/boltz/distributed/model/validation/utils.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from torch import Tensor +from torch.distributed.tensor import DTensor, Shard +from torch.distributed.tensor.placement_types import Replicate + + +def gather_along_cp(dtensor: DTensor) -> Tensor: + """Gather a DTensor over CP dimensions, keeping the DP shard, then unwrap to a plain Tensor. + + Redistributes CP mesh dimensions (all except dim 0) to Replicate while + preserving Shard(0) on the DP dimension. The returned tensor is the + local DP slice with full spatial extent. + + Parameters + ---------- + dtensor : DTensor + Input distributed tensor with arbitrary placements. + Expects the first mesh dimension to be DP. + + Returns + ------- + Tensor + The CP-gathered plain tensor, local to this DP rank. + """ + expected_mesh_dim_names = ("dp", "cp_axis_0", "cp_axis_1") + if dtensor.device_mesh.mesh_dim_names != expected_mesh_dim_names: + raise ValueError( + "gather_along_cp expects device mesh dim names " + f"{expected_mesh_dim_names}, got {dtensor.device_mesh.mesh_dim_names}." + ) + target_placements = [Shard(0)] + [Replicate()] * (dtensor.device_mesh.ndim - 1) + gathered = dtensor.redistribute(dtensor.device_mesh, target_placements) + return gathered.to_local() diff --git a/src/boltz/distributed/model/validation/validator.py b/src/boltz/distributed/model/validation/validator.py new file mode 100644 index 000000000..68562664e --- /dev/null +++ b/src/boltz/distributed/model/validation/validator.py @@ -0,0 +1,754 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Distributed validator for Boltz-2 with context parallelism. + +Extends the serial :class:`~boltz.model.validation.validator.Validator` to +handle DTensor inputs. Each metric computation method is overridden to +selectively all-gather only the features it needs, then delegate to the +serial/triton implementation on plain tensors. + +Architecture: + DistributedValidator(Validator) + - Overrides: run_model, compute_disto_loss, compute_disto_lddt, + get_lddt_metrics, get_clash_metrics, get_pb_metrics, + get_confidence_metrics, common_val_step, common_on_epoch_end + - Each metric override: gather DTensor features → call serial function → + return plain-tensor results + - Metric storage, update, and epoch-end aggregation remain in the serial + base class operating on plain tensors. +""" + +from collections import defaultdict +from typing import Optional + +import torch +from pytorch_lightning import LightningModule +from torch import Tensor, nn +from torch.distributed.tensor import DTensor +from torchmetrics import MeanMetric + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.model.layers.atom_to_token import ( + reconstruct_atom_to_token_global, + reconstruct_r_set_to_rep_atom_global, + reconstruct_token_to_rep_atom_global, + single_repr_rep_atom_to_token, +) +from boltz.distributed.model.loss.distogram import distogram_loss +from boltz.distributed.model.loss.validation import ( + clash_score, + compute_disto_lddt, + compute_plddt_mae_triton, + get_lddt_metrics, +) +from boltz.distributed.model.validation.utils import gather_along_cp +from boltz.model.loss.validation import ( + compute_pae_mae, + compute_pde_mae, + weighted_minimum_rmsd_single, +) +from boltz.model.validation.validator import Validator + + +class DistributedValidator(Validator): + """Distributed validator that handles DTensor inputs for Boltz-2 CP. + + Overrides metric computation methods from :class:`Validator` to + selectively all-gather DTensor features before calling the serial + implementations. Metric storage and epoch-end aggregation remain in + the serial base class. + """ + + def __init__( + self, + val_names: list[str], + confidence_prediction: bool = False, + physicalism_metrics: bool = False, + rmsd_metrics: bool = False, + clash_score_metrics: bool = False, + override_val_method: Optional[str] = None, + ) -> None: + """Initialize the distributed validator. + + Parameters + ---------- + val_names : list[str] + The list of validation names. + confidence_prediction : bool + Whether to predict confidence. + physicalism_metrics : bool + Whether to compute physicalism metrics. + rmsd_metrics : bool + Whether to compute rmsd metrics. + clash_score_metrics : bool + Whether to compute clash score metrics. + override_val_method : Optional[str] + The override validation method. + """ + super().__init__( + val_names=val_names, + confidence_prediction=confidence_prediction, + physicalism_metrics=physicalism_metrics, + override_val_method=override_val_method, + ) + self.rmsd_metrics = rmsd_metrics + self.clash_score_metrics = clash_score_metrics + + if rmsd_metrics: + self.folding_metrics["rmsd"] = nn.ModuleList([nn.ModuleDict() for _ in range(self.num_val_datasets)]) + for val_idx in range(self.num_val_datasets): + self.folding_metrics["rmsd"][val_idx]["rmsd"] = MeanMetric() + + if clash_score_metrics: + self.folding_metrics["clash_score"] = nn.ModuleList([nn.ModuleDict() for _ in range(self.num_val_datasets)]) + for val_idx in range(self.num_val_datasets): + self.folding_metrics["clash_score"][val_idx]["clash_atoms_count"] = MeanMetric() + self.folding_metrics["clash_score"][val_idx]["clash_atoms_fraction"] = MeanMetric() + + # In our CP code, lightning is blind to our distributed computation context + # so MeanMetric is strictly single device and we should not rely on + # MeanMetric.compute() to get inter-DP all_reduce mean. + for m in self.modules(): + if isinstance(m, MeanMetric): + m.sync_on_compute = False + m._to_sync = False + + def run_model( + self, + model: LightningModule, + batch: dict[str, DTensor], + idx_dataset: int, + ) -> dict[str, DTensor]: + """Compute the forward pass using the distributed model. + + Parameters + ---------- + model : LightningModule + The distributed Boltz2 LightningModule. + batch : dict[str, DTensor] + Batch features as DTensors. + idx_dataset : int + Dataset index. + + Returns + ------- + dict[str, DTensor] + Model outputs as DTensors. + """ + if self.override_val_method is not None: + raise NotImplementedError("Override validation method is not supported for distributed validation") + # from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, scalar_tensor_op + # new_feature = scalar_tensor_op(0.0, batch["method_feature"], ElementwiseOp.PROD) + # new_feature = scalar_tensor_op(self.override_val_method, new_feature, ElementwiseOp.SUM) + # batch["method_feature"] = new_feature + + out = model( + batch, + recycling_steps=model.validation_args.recycling_steps, + num_sampling_steps=model.validation_args.sampling_steps, + diffusion_samples=model.validation_args.diffusion_samples, + run_confidence_sequentially=model.validation_args.get("run_confidence_sequentially", False), + ) + return out + + def compute_disto_loss( + self, + model: LightningModule, + out: dict[str, DTensor], + batch: dict[str, DTensor], + idx_dataset: int, + transpose_comm: TransposeComm, + ) -> Tensor: + """Compute distogram loss using DTensor-native implementation. + + Uses the existing distributed distogram loss which operates directly + on DTensors without requiring all-gather. + """ + val_disto_loss, _ = distogram_loss( + out, batch, comm=transpose_comm, aggregate_distogram=model.aggregate_distogram + ) + return val_disto_loss.to_local() + + def compute_disto_lddt( + self, + model: LightningModule, + batch: dict[str, Tensor], + out: dict[str, Tensor], + idx_dataset: int, + ) -> tuple[dict, dict]: + """Compute distogram lddt.""" + disto_lddt_dict, disto_total_dict = compute_disto_lddt(model, batch, out) + return disto_lddt_dict, disto_total_dict + + def get_lddt_metrics( + self, + *args, + **kwargs, + ) -> None: + """Override the serial get_lddt_metrics method.""" + raise NotImplementedError("DistributedValidator does not need to implement get_lddt_metrics") + + def get_clash_metrics( + self, + batch: dict[str, DTensor], + out: dict[str, DTensor], + batch_gathered: dict[str, Tensor], + out_gathered: dict[str, Tensor], + ) -> tuple[dict, dict]: + """Compute clash metrics by gathering features at global atom dimension. + + Reuses unstripped ``asym_id``, ``atom_pad_mask``, and + ``sample_atom_coords`` from ``batch_gathered``/``out_gathered``. + Gathers ``atom_to_token`` and ``ref_element`` on demand. + Non-sharded features pass through directly from ``batch``. + """ + clash_feats = { + "asym_id": batch_gathered["asym_id"], + "atom_to_token": reconstruct_atom_to_token_global(batch["atom_to_token"]), + "ref_element": gather_along_cp(batch["ref_element"]), + "atom_pad_mask": batch_gathered["atom_pad_mask"], + "connections_edge_index": batch["connections_edge_index"], + "chain_symmetries": batch["chain_symmetries"], + } + clash_out = {"sample_atom_coords": out_gathered["sample_atom_coords"]} + + result = super().get_clash_metrics(clash_feats, clash_out) + del clash_feats, clash_out + return result + + def get_pb_metrics( + self, + batch: dict[str, DTensor], + out: dict[str, DTensor], + batch_gathered: dict[str, Tensor], + out_gathered: dict[str, Tensor], + ) -> tuple[dict, dict]: + """Compute PB metrics by gathering features at global atom dimension. + + Reuses unstripped ``asym_id``, ``mol_type``, and + ``sample_atom_coords`` from ``batch_gathered``/``out_gathered``. + Gathers ``atom_to_token`` on demand. Ligand features are not + sharded and pass through directly from ``batch``. + """ + pb_feats = { + "asym_id": batch_gathered["asym_id"], + "atom_to_token": reconstruct_atom_to_token_global(batch["atom_to_token"]), + "mol_type": batch_gathered["mol_type"], + } + _LIGAND_KEYS = ( + "ligand_edge_index", + "ligand_edge_lower_bounds", + "ligand_edge_upper_bounds", + "ligand_edge_bond_mask", + "ligand_edge_angle_mask", + "ligand_chiral_atom_index", + "ligand_chiral_check_mask", + "ligand_chiral_atom_orientations", + "ligand_stereo_bond_index", + "ligand_stereo_check_mask", + "ligand_stereo_bond_orientations", + "ligand_aromatic_5_ring_index", + "ligand_aromatic_6_ring_index", + "ligand_planar_double_bond_index", + ) + for k in _LIGAND_KEYS: + if k in batch: + pb_feats[k] = batch[k] + pb_out = {"sample_atom_coords": out_gathered["sample_atom_coords"]} + + result = super().get_pb_metrics(pb_feats, pb_out) + del pb_feats, pb_out + return result + + def get_confidence_metrics( + self, + model: LightningModule, + batch: dict[str, Tensor], + out: dict[str, Tensor], + idx_dataset: int, + n_samples: int, + true_coords: Tensor, + true_coords_resolved_mask: Tensor, + expand_to_diffusion_samples: bool, + batch_gathered: dict[str, Tensor], + out_gathered: dict[str, Tensor], + ): + """Compute confidence metrics using triton pLDDT and serial PDE/PAE. + + Uses :func:`compute_plddt_mae_triton` for pLDDT MAE (avoids + materialising the full N_token x N_R_set distance/mask matrices), + and delegates to the serial ``compute_pde_mae`` / ``compute_pae_mae`` + for PDE and PAE. ``token_to_rep_atom`` is deleted after pLDDT + computation to free memory. + """ + + atom_pad_mask_1d = batch_gathered["atom_pad_mask_1d"] + token_pad_mask_1d = batch_gathered["token_pad_mask_1d"] + + # Strip shared features from global to valid-only dimensions. + # true_coords / true_coords_resolved_mask are passed as parameters; + # rebind to stripped local copies so the caller's references stay global. + true_coords = true_coords[..., atom_pad_mask_1d, :] + true_coords_resolved_mask = true_coords_resolved_mask[..., atom_pad_mask_1d] + batch_gathered["mol_type"] = batch_gathered["mol_type"][:, token_pad_mask_1d] + batch_gathered["asym_id"] = batch_gathered["asym_id"][:, token_pad_mask_1d] + batch_gathered["token_pad_mask"] = batch_gathered["token_pad_mask"][:, token_pad_mask_1d] + batch_gathered["atom_pad_mask"] = batch_gathered["atom_pad_mask"][:, atom_pad_mask_1d] + out_gathered["sample_atom_coords"] = out_gathered["sample_atom_coords"][:, atom_pad_mask_1d] + + K = batch["coords"].shape[1] # ensemble dim is not sharded + msg = "Confidence_prediction is not supported for num_ensembles_val > 1" + assert K == 1, msg + + # Reconstruct atom_to_token + batch_gathered["atom_to_token"] = reconstruct_atom_to_token_global(batch["atom_to_token"])[ + :, atom_pad_mask_1d, : + ][:, :, token_pad_mask_1d] + + # Gather token-level confidence output + out_gathered["plddt"] = gather_along_cp(out["plddt"])[:, token_pad_mask_1d] + + # Gather pair-level confidence outputs (pde, pae) + for key in ("pde", "pae"): + out_gathered[key] = gather_along_cp(out[key])[:, token_pad_mask_1d, :][:, :, token_pad_mask_1d] + + # Reconstruct diagonally-sharded mapping matrices + batch_gathered["token_to_rep_atom"] = reconstruct_token_to_rep_atom_global(batch["token_to_rep_atom"])[ + :, token_pad_mask_1d, : + ][:, :, atom_pad_mask_1d] + + r_set_to_rep_atom_gathered = reconstruct_r_set_to_rep_atom_global(batch["r_set_to_rep_atom"]) + # r_set_to_rep_atom_gathered dim (dim 1) may have per-shard padding; strip rows that are all zeros + r_set_to_rep_atom_gathered_valid = r_set_to_rep_atom_gathered.any(dim=-1) # [B, N_R_global], here B is 1 + batch_gathered["r_set_to_rep_atom"] = r_set_to_rep_atom_gathered[:, r_set_to_rep_atom_gathered_valid[0], :][ + :, :, atom_pad_mask_1d + ] + + # Gather ensemble-aware frame features. + frames_idx_gathered = gather_along_cp(batch["frames_idx"]) + if frames_idx_gathered.ndim == 4: + # (B, E=1, T_padded, 3) → squeeze ensemble dim → (B, T_padded, 3) + frames_idx_gathered = frames_idx_gathered.squeeze(1) + batch_gathered["frames_idx"] = frames_idx_gathered[:, token_pad_mask_1d] + + frame_resolved_mask_gathered = gather_along_cp(batch["frame_resolved_mask"]) + if frame_resolved_mask_gathered.ndim == 3: + # (B, E=1, T_padded) → squeeze ensemble dim → (B, T_padded) + frame_resolved_mask_gathered = frame_resolved_mask_gathered.squeeze(1) + batch_gathered["frame_resolved_mask"] = frame_resolved_mask_gathered[:, token_pad_mask_1d] + + mae_plddt_dicts: dict[str, list] = defaultdict(list) + total_mae_plddt_dicts: dict[str, list] = defaultdict(list) + mae_pde_dicts: dict[str, list] = defaultdict(list) + total_mae_pde_dicts: dict[str, list] = defaultdict(list) + mae_pae_dicts: dict[str, list] = defaultdict(list) + total_mae_pae_dicts: dict[str, list] = defaultdict(list) + + if not expand_to_diffusion_samples: + true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat((n_samples, 1)) + + for ensemble_idx in range(K): + if expand_to_diffusion_samples: + true_coords_k = true_coords[:, ensemble_idx] + else: + true_coords_k = true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) + + # pLDDT MAE via triton cdist_lddt (rectangular, per_atom) + mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae_triton( + pred_atom_coords=out_gathered["sample_atom_coords"], + feats=batch_gathered, + true_atom_coords=true_coords_k, + pred_lddt=out_gathered["plddt"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + for key in mae_plddt_dict: + mae_plddt_dicts[key].append(mae_plddt_dict[key]) + total_mae_plddt_dicts[key].append(total_mae_plddt_dict[key]) + + if ensemble_idx == K - 1: + del batch_gathered["r_set_to_rep_atom"], out_gathered["plddt"] + + # PDE MAE via serial implementation + mae_pde_dict, total_mae_pde_dict = compute_pde_mae( + pred_atom_coords=out_gathered["sample_atom_coords"], + feats=batch_gathered, + true_atom_coords=true_coords_k, + pred_pde=out_gathered["pde"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + for key in mae_pde_dict: + mae_pde_dicts[key].append(mae_pde_dict[key]) + total_mae_pde_dicts[key].append(total_mae_pde_dict[key]) + + if ensemble_idx == K - 1: + del batch_gathered["token_to_rep_atom"], out_gathered["pde"] + + # PAE MAE via serial implementation + mae_pae_dict, total_mae_pae_dict = compute_pae_mae( + pred_atom_coords=out_gathered["sample_atom_coords"], + feats=batch_gathered, + true_atom_coords=true_coords_k, + pred_pae=out_gathered["pae"], + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=n_samples, + ) + for key in mae_pae_dict: + mae_pae_dicts[key].append(mae_pae_dict[key]) + total_mae_pae_dicts[key].append(total_mae_pae_dict[key]) + + if ensemble_idx == K - 1: + del ( + batch_gathered["atom_to_token"], + batch_gathered["mol_type"], + batch_gathered["asym_id"], + batch_gathered["frames_idx"], + batch_gathered["frame_resolved_mask"], + batch_gathered["token_pad_mask"], + batch_gathered["atom_pad_mask"], + batch_gathered["token_pad_mask_1d"], + batch_gathered["atom_pad_mask_1d"], + out_gathered["sample_atom_coords"], + out_gathered["pae"], + ) + + # Mean over ensembles + for key in mae_plddt_dicts: + mae_plddt_dicts[key] = torch.stack(mae_plddt_dicts[key], dim=0).mean(dim=0) + total_mae_plddt_dicts[key] = torch.stack(total_mae_plddt_dicts[key], dim=0).mean(dim=0) + + for key in mae_pde_dicts: + mae_pde_dicts[key] = torch.stack(mae_pde_dicts[key], dim=0).mean(dim=0) + total_mae_pde_dicts[key] = torch.stack(total_mae_pde_dicts[key], dim=0).mean(dim=0) + + for key in mae_pae_dicts: + mae_pae_dicts[key] = torch.stack(mae_pae_dicts[key], dim=0).mean(dim=0) + total_mae_pae_dicts[key] = torch.stack(total_mae_pae_dicts[key], dim=0).mean(dim=0) + + return ( + mae_plddt_dicts, + total_mae_plddt_dicts, + mae_pde_dicts, + total_mae_pde_dicts, + mae_pae_dicts, + total_mae_pae_dicts, + ) + + def _dp_all_reduce_metrics(self, dp_group: torch.distributed.ProcessGroup) -> None: + """All-reduce MeanMetric internal states across DP ranks. + + In our CP code, lightning is blind to our distributed computation context + so MeanMetric is strictly single device and we should not rely on + MeanMetric.compute() to get inter-DP all_reduce mean. + + Each DP rank accumulates metrics from its own subset of validation + batches, so ``MeanMetric.mean_value`` (weighted sum) and + ``MeanMetric.weight`` (total weight) are local. Before calling + ``.compute()`` we must sum both across DP ranks so the resulting + ratio is the global weighted mean — matching the dev-v2 + ``_DP_all_reduce_mean`` semantics. + + This also handles the key-synchronisation problem: different DP + ranks may see different molecular modalities, so some + ``MeanMetric`` instances may have ``weight == 0`` on a rank that + never saw that modality. After all-reduce, a metric with + ``weight == 0`` globally will produce ``NaN`` from ``.compute()``, + which the serial ``common_on_epoch_end`` already maps to ``0.0``. + + """ + for m in self.modules(): + if isinstance(m, MeanMetric): + torch.distributed.all_reduce(m.mean_value, op=torch.distributed.ReduceOp.SUM, group=dp_group) + torch.distributed.all_reduce(m.weight, op=torch.distributed.ReduceOp.SUM, group=dp_group) + m._computed = None + elif not isinstance(m, (nn.ModuleDict, nn.ModuleList, DistributedValidator)): + raise ValueError(f"Only support MeanMetric, got {type(m)}") + + def common_on_epoch_end(self, model: LightningModule) -> None: + """Aggregate metrics at epoch end with DP all-reduce. + + 1. All-reduces every ``MeanMetric``'s internal ``mean_value`` and + ``weight`` across DP ranks so that ``.compute()`` returns the + global weighted mean (equivalent to dev-v2's + ``_DP_all_reduce_mean``). + 2. Wraps ``model.log`` with ``sync_dist=False`` to avoid NCCL + collectives inside Lightning's logging (which can deadlock when + validation batches are unevenly distributed across DP ranks). + 3. Delegates to the serial ``common_on_epoch_end`` for the actual + compute / log / reset cycle. + """ + # In our CP code, lightning is blind to our distributed computation context + # so MeanMetric is strictly single device and we should not rely on + # MeanMetric.compute() to get inter-DP all_reduce mean. + self._dp_all_reduce_metrics(model.dp_group) + + original_log = model.log + + def _log_no_sync(*args, **kwargs): + kwargs["sync_dist"] = False + return original_log(*args, **kwargs) + + model.log = _log_no_sync # type: ignore[assignment] + try: + for idx_dataset in range(self.num_val_datasets): + dataset_name_ori = self.val_names[idx_dataset] + dataset_name = "" if dataset_name_ori == "RCSB" else f"__{dataset_name_ori}" + if self.clash_score_metrics: + for m in ("clash_atoms_count", "clash_atoms_fraction"): + val = self.folding_metrics["clash_score"][idx_dataset][m].compute() + val = 0.0 if torch.isnan(val) else val.item() + self.folding_metrics["clash_score"][idx_dataset][m].reset() + model.log(f"val/{m}{dataset_name}", val) + if self.rmsd_metrics: + avg_rmsd = self.folding_metrics["rmsd"][idx_dataset]["rmsd"].compute() + avg_rmsd = 0.0 if torch.isnan(avg_rmsd) else avg_rmsd.item() + self.folding_metrics["rmsd"][idx_dataset]["rmsd"].reset() + model.log(f"val/rmsd{dataset_name}", avg_rmsd) + super().common_on_epoch_end(model) + finally: + model.log = original_log # type: ignore[assignment] + + def common_val_step( + self, + model: LightningModule, + batch: dict[str, DTensor], + out: dict[str, DTensor], + idx_dataset: int, + expand_to_diffusion_samples: bool, + transpose_comm: TransposeComm, + ) -> None: + """Run a common validation step with DTensor inputs. + + Gathers DTensor features once, then delegates metric computation + to the serial base class methods operating on plain tensors. + + Parameters + ---------- + model : LightningModule + The distributed Boltz2 model (used for accessing serial + ``get_true_coordinates`` via the model_serial attribute). + batch : dict[str, DTensor] + Batch features as DTensors. + out : dict[str, DTensor] + Model outputs as DTensors. + idx_dataset : int + Global dataset index. + expand_to_diffusion_samples : bool + Whether to expand coordinates to diffusion samples. + """ + symmetry_correction = model.val_group_mapper[idx_dataset]["symmetry_correction"] + idx_dataset = self.get_local_val_index(model, idx_dataset) + n_samples = model.validation_args.diffusion_samples + + # Compute distogram loss and update metrics + val_disto_loss = self.compute_disto_loss(model, out, batch, idx_dataset, transpose_comm) + self.folding_metrics["disto_loss"][idx_dataset]["disto_loss"].update(val_disto_loss) + + # Compute distogram lddt and update metrics + disto_lddt_dict, disto_total_dict = self.compute_disto_lddt(model, batch, out, idx_dataset) + + token_pad_mask = gather_along_cp(batch["token_pad_mask"].bool()) + atom_pad_mask = gather_along_cp(batch["atom_pad_mask"].bool()) + token_pad_mask_1d = token_pad_mask[0] + atom_pad_mask_1d = atom_pad_mask[0] + + # Get true coords (DTensors) and gather to plain tensors. + # All gathered tensors are kept at global (unstripped) dimensions so + # that lDDT (atom_to_token is global) and clash/PB (atom indices are + # global) can consume them directly. Confidence metrics strip + # internally via the 1D pad masks. + return_dict = self.get_true_coords( + model, + batch, + out, + n_samples, + symmetry_correction, + expand_to_diffusion_samples=expand_to_diffusion_samples, + ) + true_coords = gather_along_cp(return_dict["true_coords"]) + true_coords_resolved_mask = gather_along_cp(return_dict["true_coords_resolved_mask"]) + + mol_type = gather_along_cp(batch["mol_type"]) + asym_id = gather_along_cp(batch["asym_id"]) + sample_atom_coords = gather_along_cp(out["sample_atom_coords"]) + + batch_gathered = { + "token_pad_mask": token_pad_mask, + "atom_pad_mask": atom_pad_mask, + "token_pad_mask_1d": token_pad_mask_1d, + "atom_pad_mask_1d": atom_pad_mask_1d, + "mol_type": mol_type, + "asym_id": asym_id, + } + out_gathered = { + "sample_atom_coords": sample_atom_coords, + } + + # Get lddt metrics (all inputs at global dimensions) + K = batch["coords"].shape[1] + if expand_to_diffusion_samples: + resolved_mask_for_lddt = true_coords_resolved_mask + else: + resolved_mask_for_lddt = true_coords_resolved_mask.squeeze(0) + all_lddt_dict, all_total_dict = get_lddt_metrics( + atom_to_token_dtensor=batch["atom_to_token"], + num_conformers=K, + n_samples=n_samples, + true_coords=true_coords, + true_coords_resolved_mask=resolved_mask_for_lddt, + mol_type=mol_type, + asym_id=asym_id, + sample_atom_coords=sample_atom_coords, + expand_to_diffusion_samples=expand_to_diffusion_samples, + ) + + # Get physical realism metrics + if self.physicalism_metrics: + pair_clash_dict, pair_total_dict = self.get_clash_metrics(batch, out, batch_gathered, out_gathered) + pb_failure_dict, pb_total_dict = self.get_pb_metrics(batch, out, batch_gathered, out_gathered) + else: + pair_clash_dict, pair_total_dict = None, None + pb_failure_dict, pb_total_dict = None, None + + # Filtering based on confidence + if model.confidence_prediction and n_samples > 1: + ( + mae_plddt_dicts, + total_mae_plddt_dicts, + mae_pde_dicts, + total_mae_pde_dicts, + mae_pae_dicts, + total_mae_pae_dicts, + ) = self.get_confidence_metrics( + model, + batch, + out, + idx_dataset, + n_samples, + true_coords, + true_coords_resolved_mask, + expand_to_diffusion_samples, + batch_gathered, + out_gathered, + ) + + # Compute RMSD on gathered plain tensors (first conformer only) + if self.rmsd_metrics: + atom_to_token_global = reconstruct_atom_to_token_global(batch["atom_to_token"]) + + if expand_to_diffusion_samples: + atom_coords_rmsd = true_coords[:, 0] + resolved_mask_rmsd = true_coords_resolved_mask + else: + atom_coords_rmsd = true_coords[0].unsqueeze(0).repeat((n_samples, 1, 1)) + resolved_mask_rmsd = true_coords_resolved_mask.repeat((n_samples, 1)) + + rmsd_val, _, _ = weighted_minimum_rmsd_single( + pred_atom_coords=sample_atom_coords, + atom_coords=atom_coords_rmsd, + atom_mask=resolved_mask_rmsd.float(), + atom_to_token=atom_to_token_global.expand(n_samples, -1, -1) + if atom_to_token_global.shape[0] == 1 and n_samples > 1 + else atom_to_token_global, + mol_type=mol_type.expand(n_samples, -1) if mol_type.shape[0] == 1 and n_samples > 1 else mol_type, + ) + del atom_to_token_global, atom_coords_rmsd, resolved_mask_rmsd + + if self.clash_score_metrics: + clash_cutoff = model.validation_args.get("clash_cutoff") + if clash_cutoff: + coords_repr = gather_along_cp( + single_repr_rep_atom_to_token( + out["sample_atom_coords"], + batch["token_to_rep_atom"], + ) + ) + clash_count, clash_frac = clash_score( + coords_repr=coords_repr, + token_pad_mask=token_pad_mask, + multiplicity=n_samples, + clash_cutoff=clash_cutoff, + ) + self.folding_metrics["clash_score"][idx_dataset]["clash_atoms_count"].update(clash_count.mean()) + self.folding_metrics["clash_score"][idx_dataset]["clash_atoms_fraction"].update(clash_frac.mean()) + + # Update folding metrics + self.update_lddt_rmsd_metrics( + {"contact_conditioning": gather_along_cp(batch["contact_conditioning"]), "coords": batch["coords"]}, + all_lddt_dict, + all_total_dict, + disto_lddt_dict, + disto_total_dict, + idx_dataset, + ) + if self.rmsd_metrics: + self.folding_metrics["rmsd"][idx_dataset]["rmsd"].update(rmsd_val.mean()) + + # Update physical realism metrics + if self.physicalism_metrics: + self.update_physcialism_metrics( + pair_clash_dict, + pair_total_dict, + pb_failure_dict, + pb_total_dict, + idx_dataset, + ) + + # Update confidence metrics + if model.confidence_prediction and n_samples > 1: + # Pass through scalar confidence outputs (already replicated) to plain tensors + confidence_out = {} + for key in ( + "complex_plddt", + "complex_iplddt", + "complex_pde", + "complex_ipde", + "ptm", + "iptm", + "ligand_iptm", + "protein_iptm", + ): + if key in out: + val = out[key] + confidence_out[key] = val.to_local() if isinstance(val, DTensor) else val + + self.update_confidence_metrics( + batch, + confidence_out, + idx_dataset, + n_samples, + all_lddt_dict, + all_total_dict, + mae_plddt_dicts, + total_mae_plddt_dicts, + mae_pde_dicts, + total_mae_pde_dicts, + mae_pae_dicts, + total_mae_pae_dicts, + pair_clash_dict, + pair_total_dict, + pb_failure_dict, + pb_total_dict, + physicalism_metrics=self.physicalism_metrics, + ) diff --git a/src/boltz/distributed/predict.py b/src/boltz/distributed/predict.py new file mode 100644 index 000000000..5912461d8 --- /dev/null +++ b/src/boltz/distributed/predict.py @@ -0,0 +1,500 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Distributed inference entrypoint for Boltz-2 with DTensor context parallelism. + +Adapted from the Boltz-1x-CP ``distributed/main.py::run_predict`` function. +Differences from the Boltz-1 version: + +- Uses :class:`Boltz2` (serial) + :class:`Boltz2Distributed` wrapper. +- Uses :class:`Boltz2InferenceDataModuleDTensor` (v2 featurizer/tokenizer). +- Requires ``mol_dir`` for per-residue CCD molecule files. +- Precision default is ``bf16-mixed`` (matching serial Boltz-2 inference). +""" + +from __future__ import annotations + +import atexit +import warnings +from collections import OrderedDict +from dataclasses import asdict +from datetime import timedelta +from math import isqrt +from pathlib import Path +from typing import Any, Literal, Optional + +import torch +from lightning_fabric.plugins.precision.utils import _convert_fp_tensor +from lightning_utilities import apply_to_collection +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins.precision import HalfPrecision +from pytorch_lightning.strategies import SingleDeviceStrategy + +from boltz.data import const +from boltz.data.types import Manifest +from boltz.data.write.writer import BoltzWriter +from boltz.distributed.data.module.inferencev2 import Boltz2InferenceDataModuleDTensor +from boltz.distributed.data.types import PairMaskMode +from boltz.distributed.data.utils import map_subgroup_mesh_to_cpu +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.triangular_attention import can_run_cueq_triattn_sm100f +from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed +from boltz.distributed.model.modules.utils import ( + PRECISION_TO_LIGHTNING, + Precision, + SDPAWithBiasBackend, + SetAttnPairBiasBackend, + SetAttnPairBiasShardwiseBackend, + SetTriAttnBackend, + TriAttnBackend, + setup_tf32_env, +) +from boltz.main import ( + Boltz2DiffusionParams, + BoltzProcessedInput, + check_inputs, + filter_inputs_structure, + process_inputs, +) +from boltz.model.models.boltz2 import Boltz2 as Boltz2Serial + + +class HalfPrecisionAllowFrozen(HalfPrecision): + """Half precision plugin that handles Boltz-2's frozen dataclasses. + + Lightning's default ``HalfPrecision`` plugin raises ``ValueError`` when + ``apply_to_collection`` encounters a frozen dataclass in the batch. This + subclass overrides ``convert_input`` to pass ``allow_frozen=True``, which + lets Lightning recurse into frozen dataclass fields without trying to + ``setattr`` on the frozen instance. + """ + + def convert_input(self, data: Any) -> Any: + return apply_to_collection( + data, + function=_convert_fp_tensor, + dtype=torch.Tensor, + dst_type=self._desired_input_dtype, + allow_frozen=True, + ) + + +def run_predict( + data: str | Path, + out_dir: str | Path, + mol_dir: str | Path, + checkpoint: str | Path, + size_dp: int = 1, + size_cp: int = 1, + accelerator: str = "gpu", + recycling_steps: int = 3, + sampling_steps: int = 200, + diffusion_samples: int = 1, + max_parallel_samples: Optional[int] = None, + step_scale: float = 1.5, + output_format: Literal["pdb", "mmcif"] = "mmcif", + seed: Optional[int] = None, + max_msa_seqs: int = const.max_msa_seqs, + msa_pad_to_max_seqs: bool = False, + input_format: str = "preprocessed", + timeout_nccl: Optional[float] = 30, + timeout_gloo: Optional[float] = 30, + precision: Precision = Precision.BF16_MIXED, + atoms_per_window_queries_keys: tuple[Optional[int], Optional[int]] = (32, 128), + pair_mask_mode: PairMaskMode = PairMaskMode.NONE, + local_batch_size: int = 1, + num_ensembles: int = 1, + extra_callbacks: Optional[list[Callback]] = None, + confidence_prediction: bool = True, + write_full_pae: bool = False, + use_templates: bool = True, + triattn_backend: TriAttnBackend = TriAttnBackend.CUEQ, + sdpa_with_bias_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.TORCH_FLEX_ATTN, + sdpa_with_bias_shardwise_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.TORCH_FLEX_ATTN, + auto_pad_tokens_for_sm100f: bool = True, + cuda_memory_profile: bool = False, + override: bool = False, + max_data_retries: int = 5, +) -> None: + """Run distributed Boltz-2 structure prediction with DTensor context parallelism. + + Parameters + ---------- + data : str or Path + Path to the input data. For ``input_format="preprocessed"``, a + directory containing manifest.json, structures/, msa/. For + ``input_format="config_files"``, a YAML/FASTA file or directory + of such files. + out_dir : str or Path + Output directory for predictions. + mol_dir : str or Path + Directory containing per-residue CCD molecule pickle files. + checkpoint : str or Path + Path to the Boltz-2 model checkpoint. + size_dp : int + Number of data-parallel ranks. + size_cp : int + Total number of context-parallel ranks (must be a perfect square). + accelerator : str + Device accelerator ("gpu" or "cpu"). + recycling_steps : int + Number of recycling iterations for the trunk. + sampling_steps : int + Number of diffusion denoising steps. + diffusion_samples : int + Number of independent diffusion samples per input. + max_parallel_samples : int or None + Max diffusion samples to run in parallel (None = all at once). + step_scale : float + Step scale for the diffusion schedule. + output_format : str + Output structure format ("pdb" or "mmcif"). + seed : int or None + Random seed for reproducibility. + max_msa_seqs : int + Maximum number of MSA sequences. + msa_pad_to_max_seqs : bool + Whether to pad MSA to max_msa_seqs. + input_format : str + Input data format ("preprocessed" or "config_files"). + timeout_nccl : float or None + NCCL timeout in minutes. + timeout_gloo : float or None + Gloo timeout in minutes. + precision : Precision + Model precision mode. + atoms_per_window_queries_keys : tuple + (queries, keys) window sizes for atom attention batching. + pair_mask_mode : PairMaskMode + Pair mask mode (NONE = window batching). + local_batch_size : int + Per-rank batch size. + num_ensembles : int + Number of ensemble members for structure prediction. + extra_callbacks : list[Callback] or None + Additional Lightning callbacks. + confidence_prediction : bool + Whether to run the confidence module (pLDDT, pTM, iPTM, PAE, PDE). + write_full_pae : bool + Whether to write full PAE matrices (requires confidence module). + use_templates : bool + Reserved for future use. Currently ignored — template weights are + always loaded from the checkpoint but the distributed TemplateModule + is not yet implemented, so templates are skipped during forward. + triattn_backend : TriAttnBackend + Backend for distributed triangle attention in PairformerLayer. + sdpa_with_bias_backend : SDPAWithBiasBackend + Backend for ring-attention AttentionPairBias layers. + sdpa_with_bias_shardwise_backend : SDPAWithBiasBackend + Backend for window-batched AttentionPairBiasShardwise layers. + auto_pad_tokens_for_sm100f : bool + When True and the cuEq TriAttn backend is selected with BF16 or + BF16_MIXED precision on an SM100/SM103 GPU, pad token counts so + that each CP shard has a multiple-of-8 sequence length, enabling + the SM100f cuEq TriAttn kernel. + cuda_memory_profile : bool + When True, records CUDA memory history and dumps a per-rank + snapshot pickle to ``out_dir`` at the end of prediction. + override : bool + When True, rerun predictions even if output already exists. + max_data_retries : int + Maximum number of retry attempts when a data sample fails to load. + Set to 0 to raise immediately on the first error. + + """ + atoms_per_window_queries, atoms_per_window_keys = atoms_per_window_queries_keys + + accelerator_to_device = {"cpu": "cpu", "gpu": "cuda"} + if accelerator not in accelerator_to_device: + raise ValueError(f"Accelerator {accelerator!r} not recognised; expected one of {sorted(accelerator_to_device)}") + device_type = accelerator_to_device[accelerator] + + if device_type == "cuda" and not torch.cuda.is_available(): + raise ValueError("accelerator='gpu' requires CUDA, but torch.cuda.is_available() is False") + + if triattn_backend in (TriAttnBackend.CUEQ, TriAttnBackend.TRIFAST) and ( + device_type != "cuda" or not torch.cuda.is_available() + ): + raise ValueError( + f"triattn_backend={triattn_backend.value!r} requires CUDA, " + f"but accelerator={accelerator!r} (device_type={device_type!r}), " + f"torch.cuda.is_available()={torch.cuda.is_available()}" + ) + + if not isinstance(size_cp, int) or size_cp <= 0: + raise TypeError("size_cp must be a positive integer") + if timeout_nccl is not None and timeout_nccl <= 0: + raise TypeError("timeout_nccl must be a positive float") + if timeout_gloo is not None and timeout_gloo <= 0: + raise TypeError("timeout_gloo must be a positive float") + if precision not in PRECISION_TO_LIGHTNING: + raise ValueError(f"Precision {precision} not supported") + + timeout_nccl_td = timedelta(minutes=timeout_nccl) if timeout_nccl is not None else None + timeout_gloo_td = timedelta(minutes=timeout_gloo) if timeout_gloo is not None else None + timeout_by_device = {"cpu": timeout_gloo_td, "cuda": timeout_nccl_td} + + # --- Distributed setup --- + DistributedManager.initialize(device_type=device_type, timeout=timeout_by_device[device_type]) + atexit.register(DistributedManager.cleanup) + dist_manager = DistributedManager() + + if size_dp * size_cp != dist_manager.world_size: + raise ValueError(f"world_size mismatch: {dist_manager.world_size} != size_dp*size_cp ({size_dp}*{size_cp})") + + size_cp_axis = isqrt(size_cp) + if size_cp_axis * size_cp_axis != size_cp: + raise ValueError(f"size_cp must be a perfect square, got {size_cp}") + + # Load checkpoint hparams to read existing pairformer_args / msa_args, + # then ensure the critical V2 flags are set without overriding model + # dimensions (num_heads, num_blocks, msa_s, etc.) with production defaults. + # + # Without the v2 flag, Boltz2Serial creates V1 attention layers (which use + # a different norm layout — 128 extra `norm_s` LayerNorms per pairformer + # layer). Under strict=False this was silent: the V1 norm parameters were + # randomly initialised and the V2 checkpoint weights were simply dropped, + # producing garbage predictions. We now merge the flag from the checkpoint + # and load with strict=True to catch any mismatch. + ckpt_raw = torch.load(str(checkpoint), map_location="cpu", weights_only=False, mmap=True) + ckpt_hp = ckpt_raw.get("hyper_parameters", {}) + del ckpt_raw + + pairformer_args = dict(ckpt_hp.get("pairformer_args", {})) + pairformer_args.setdefault("v2", True) + + msa_args = dict(ckpt_hp.get("msa_args", {})) + msa_args.setdefault("use_paired_feature", True) + + dim_hidden = pairformer_args.get("pairwise_head_width", 32) + sm100f_per_shard_token_multiple = 1 + if auto_pad_tokens_for_sm100f and triattn_backend == TriAttnBackend.CUEQ: + if precision in (Precision.BF16, Precision.BF16_MIXED): + if can_run_cueq_triattn_sm100f(dist_manager.device, torch.bfloat16, 8, dim_hidden, True): + sm100f_per_shard_token_multiple = 8 + + grid_group_sizes: OrderedDict[str, int | tuple[int, ...]] = OrderedDict( + [("dp", size_dp), ("cp", (size_cp_axis, size_cp_axis))] + ) + DistributedManager.create_grid_group(grid_group_sizes) + + DistributedManager.create_group( + "world_cpu", + dist_manager.group_ranks["world"], + backend="gloo", + use_local_synchronization=True, + timeout=timeout_by_device["cpu"], + ) + DistributedManager.create_group( + "cp_cpu", + dist_manager.group_ranks["cp"], + backend="gloo", + use_local_synchronization=True, + timeout=timeout_by_device["cpu"], + ) + device_mesh_cpu = map_subgroup_mesh_to_cpu(dist_manager) + + # --- Data loading --- + torch.set_grad_enabled(False) + if seed is not None: + seed_everything(seed) + + out_dir = Path(out_dir).expanduser() + mol_dir = Path(mol_dir).expanduser() + data = Path(data).expanduser() + + if dist_manager.group_rank["world"] == 0: + processed: Optional[BoltzProcessedInput] = None + try: + out_dir.mkdir(parents=True, exist_ok=True) + + if input_format == "config_files": + results_dir = out_dir / f"boltz_results_{data.stem}" + results_dir.mkdir(parents=True, exist_ok=True) + + data_files = check_inputs(data) + process_inputs( + data=data_files, + out_dir=results_dir, + ccd_path=None, + mol_dir=mol_dir, + use_msa_server=False, + msa_server_url="", + msa_pairing_strategy="greedy", + boltz2=True, + max_msa_seqs=max_msa_seqs, + ) + + manifest = Manifest.load(results_dir / "processed" / "manifest.json") + filtered_manifest = filter_inputs_structure( + manifest=manifest, + outdir=results_dir, + override=override, + ) + processed_dir = results_dir / "processed" + processed = BoltzProcessedInput( + manifest=filtered_manifest, + targets_dir=processed_dir / "structures", + msa_dir=processed_dir / "msa", + constraints_dir=( + (processed_dir / "constraints") if (processed_dir / "constraints").exists() else None + ), + template_dir=((processed_dir / "templates") if (processed_dir / "templates").exists() else None), + extra_mols_dir=((processed_dir / "mols") if (processed_dir / "mols").exists() else None), + ) + elif input_format == "preprocessed": + manifest = Manifest.load(data / "manifest.json") + filtered_manifest = filter_inputs_structure( + manifest=manifest, + outdir=out_dir, + override=override, + ) + processed = BoltzProcessedInput( + manifest=filtered_manifest, + targets_dir=data / "structures", + msa_dir=data / "msa", + constraints_dir=(data / "constraints") if (data / "constraints").exists() else None, + template_dir=(data / "templates") if (data / "templates").exists() else None, + extra_mols_dir=(data / "extra_mols") if (data / "extra_mols").exists() else None, + ) + else: + raise ValueError( + f"Unsupported input_format={input_format!r}; expected 'preprocessed' or 'config_files'" + ) + finally: + torch.distributed.broadcast_object_list([processed], src=0, group=dist_manager.group["world_cpu"]) + else: + processed_recv: list[Optional[BoltzProcessedInput]] = [None] + torch.distributed.broadcast_object_list(processed_recv, src=0, group=dist_manager.group["world_cpu"]) + processed = processed_recv[0] + if processed is None: + raise RuntimeError("Rank 0 failed during input processing; see rank 0 logs for the root cause.") + + out_dir = out_dir / f"boltz_results_{data.stem}" + if dist_manager.group_rank["world"] == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + data_module = Boltz2InferenceDataModuleDTensor( + manifest=processed.manifest, + target_dir=processed.targets_dir, + msa_dir=processed.msa_dir, + mol_dir=mol_dir, + num_workers=0, + device_mesh=dist_manager.device_mesh_subgroups, + device_mesh_cpu=device_mesh_cpu, + constraints_dir=None, + template_dir=processed.template_dir, + extra_mols_dir=processed.extra_mols_dir, + max_msa_seqs=max_msa_seqs, + msa_pad_to_max_seqs=msa_pad_to_max_seqs, + max_data_retries=max_data_retries, + pair_mask_mode=pair_mask_mode, + atoms_per_window_queries=atoms_per_window_queries, + atoms_per_window_keys=atoms_per_window_keys, + local_batch_size=local_batch_size, + num_ensembles=num_ensembles, + per_shard_token_multiple=sm100f_per_shard_token_multiple, + ) + + # --- Model loading --- + predict_args = { + "recycling_steps": recycling_steps, + "sampling_steps": sampling_steps, + "diffusion_samples": diffusion_samples, + "max_parallel_samples": max_parallel_samples, + "write_confidence_summary": confidence_prediction, + "write_full_pae": write_full_pae, + } + + diffusion_params = Boltz2DiffusionParams() + diffusion_params.step_scale = step_scale + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + model_serial: Boltz2Serial = Boltz2Serial.load_from_checkpoint( + str(checkpoint), + strict=True, + predict_args=predict_args, + map_location=dist_manager.device, + diffusion_process_args=asdict(diffusion_params), + pairformer_args=pairformer_args, + msa_args=msa_args, + confidence_prediction=confidence_prediction, + ema=False, + ) + model_serial.eval() + model_distributed = Boltz2Distributed(model_serial, dist_manager).eval() + model_distributed.apply(SetTriAttnBackend(triattn_backend)) + model_distributed.apply(SetAttnPairBiasBackend(sdpa_with_bias_backend)) + model_distributed.apply(SetAttnPairBiasShardwiseBackend(sdpa_with_bias_shardwise_backend)) + + # --- Callbacks --- + callbacks_lst: list[Callback] = [] + if dist_manager.group_rank["cp"] == 0: + name_subdir = f"predictions_dp{dist_manager.group_rank['dp']}_cp{dist_manager.group_rank['cp']}" + pred_writer = BoltzWriter( + data_dir=processed.targets_dir, + output_dir=out_dir / name_subdir, + output_format=output_format, + boltz2=True, + ) + callbacks_lst.append(pred_writer) + if cuda_memory_profile: + from boltz.workflow.utils import CUDAMemoryProfile + + mem_snapshot_path = out_dir / f"cuda_memory_profile_rank{dist_manager.group_rank['world']}.pickle" + callbacks_lst.append(CUDAMemoryProfile(output_path=mem_snapshot_path, max_entries=300000)) + if extra_callbacks is not None: + callbacks_lst.extend(extra_callbacks) + + callbacks: list[Callback] | None = callbacks_lst if callbacks_lst else None + + # --- Trainer --- + strategy = SingleDeviceStrategy(device=dist_manager.device) + devices = size_dp if size_cp == 1 else "auto" + precision_lightning = PRECISION_TO_LIGHTNING.get(precision, "bf16-mixed") + + plugins = None + if precision == Precision.BF16: + plugins = [HalfPrecisionAllowFrozen(precision_lightning)] + precision_lightning = None + + trainer = Trainer( + default_root_dir=out_dir, + strategy=strategy, + callbacks=callbacks, + accelerator=accelerator, + devices=devices, + precision=precision_lightning, + plugins=plugins, + ) + + if dist_manager.group_rank["world"] == 0: + print(f"Boltz-2 distributed inference: precision={precision}, dp={size_dp}, cp={size_cp}") + + with setup_tf32_env(precision): + trainer.predict( + model_distributed, + datamodule=data_module, + return_predictions=False, + ) + DistributedManager.cleanup() diff --git a/src/boltz/distributed/testing/utils.py b/src/boltz/distributed/testing/utils.py new file mode 100644 index 000000000..b2d09ac53 --- /dev/null +++ b/src/boltz/distributed/testing/utils.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from pathlib import Path + +import hydra +import torch +from omegaconf import OmegaConf +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.data.module.trainingv2 import DataConfigV2 + +# Config file paths +ROOT_DIR = Path(__file__).resolve().parents[4] +CONFIG_FILE_BASE = ROOT_DIR / "scripts" / "train" / "configs" / "structurev2.yaml" + + +def create_atom_to_token_dtensor(atom_to_token_global: Tensor, device_mesh: DeviceMesh) -> DTensor: + """Create a distributed tensor for atom_to_token with proper placement. + + Args: + atom_to_token_global: Global atom_to_token tensor of shape (B, n_atoms, n_tokens) + device_mesh: DeviceMesh instance + + Returns: + DTensor: Distributed atom_to_token tensor with placement (Shard(0), Shard(1), Replicate()) + """ + # Get block diagonal chunk of atom_to_token_global + n_atoms, n_tokens = atom_to_token_global.shape[1:] + cp_axis_0_size = device_mesh.get_group("cp_axis_0").size() + + atom_to_token_local = [] + for cp_idx in range(cp_axis_0_size): + start_token_idx = cp_idx * n_tokens // cp_axis_0_size + end_token_idx = (cp_idx + 1) * n_tokens // cp_axis_0_size + start_atom_idx = cp_idx * n_atoms // cp_axis_0_size + end_atom_idx = (cp_idx + 1) * n_atoms // cp_axis_0_size + atom_to_token_local.append(atom_to_token_global[:, start_atom_idx:end_atom_idx, start_token_idx:end_token_idx]) + + atom_to_token_local = torch.cat(atom_to_token_local, dim=1) + + placements = (Shard(dim=0), Shard(dim=1), Replicate()) + atom_to_token_dtensor = distribute_tensor(atom_to_token_local, device_mesh, placements) + return atom_to_token_dtensor + + +def setup_mock_training_datamodule_config(test_data_dir: Path) -> DataConfigV2: + """Setup mock training datamodule configuration by loading and merging config files. + + Args: + test_data_dir: Base path for test data directory + + Returns: + Configured DataConfigV2 instance with test data paths + """ + config_dict = OmegaConf.load(CONFIG_FILE_BASE) + + # Provide temporary valid paths for required string fields before instantiate. + for dataset_cfg in config_dict.data.datasets: + dataset_cfg.target_dir = "." + dataset_cfg.msa_dir = "." + + # Instantiate the configuration + cfg = hydra.utils.instantiate(config_dict) + + data_config = DataConfigV2(**cfg.data) + + # Test data comes from RCSB only; keep a single dataset entry. + if len(data_config.datasets) > 1: + data_config.datasets = [data_config.datasets[0]] + + # Override paths to use the prepared Boltz2 training test dataset layout. + data_config.datasets[0].target_dir = str(test_data_dir) + data_config.datasets[0].msa_dir = str(test_data_dir / "msa") + data_config.datasets[0].split = None + data_config.datasets[0].template_dir = None + data_config.datasets[0].prob = 1.0 + + # Keep tests small and deterministic. + data_config.samples_per_epoch = 4 + data_config.num_workers = 0 + data_config.pin_memory = False + data_config.use_templates = False + + # Enable symmetry features for training to test symmetry feature broadcasting + data_config.return_train_symmetries = True + + return data_config diff --git a/src/boltz/distributed/train.py b/src/boltz/distributed/train.py new file mode 100644 index 000000000..a09304cd9 --- /dev/null +++ b/src/boltz/distributed/train.py @@ -0,0 +1,608 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Distributed training entrypoint for Boltz-2 with DTensor context parallelism. + +This module wires the full distributed training infrastructure: + +- Distributed process-group bootstrap via :class:`DistributedManager` +- :class:`Boltz2Distributed` model wrapping with DTensor CP +- :class:`Boltz2TrainingDataModule` with DTensor feature distribution +- :class:`BoltzContextParallelStrategy` for DTensor checkpoint save/load +- Checkpoint callback with Boltz-2 defaults (``monitor="val/lddt"``, etc.) +- Resume / pretrained-loading plumbing +- Precision configuration (``bf16``, ``bf16-mixed``, ``tf32``, etc.) +- WandB logging and config serialization + +Factory functions :func:`_create_distributed_model` and +:func:`_create_distributed_data_module` are extracted as module-level +functions so tests can monkeypatch them with lightweight smoke +implementations (see ``tests/distributed/test_dtensor_stop_and_go.py``). +""" + +from __future__ import annotations + +import atexit +import os +import random +import string +import sys +import warnings +from collections import OrderedDict +from dataclasses import dataclass +from datetime import timedelta +from math import isqrt +from pathlib import Path +from typing import Any, Optional + +import hydra +import omegaconf +import pytorch_lightning as pl +import torch +from hydra import compose, initialize_config_dir +from hydra.core.global_hydra import GlobalHydra +from omegaconf import OmegaConf +from pytorch_lightning import LightningModule, seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities import rank_zero_only + +try: + from one_logger_utils.pytorch_lightning import hook_trainer_cls # type: ignore[import-untyped] + + _one_logger_available = True +except ImportError: + _one_logger_available = False + +from boltz.data.module.trainingv2 import DataConfigV2 +from boltz.distributed.data.module.trainingv2 import Boltz2TrainingDataModule +from boltz.distributed.data.utils import map_subgroup_mesh_to_cpu +from boltz.distributed.lightning_strategy import BoltzContextParallelStrategy +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed +from boltz.distributed.model.modules.utils import ( + PRECISION_TO_LIGHTNING, + OffloadActvCkptToCPU, + Precision, + SDPAWithBiasBackend, + SetAttnPairBiasBackend, + SetAttnPairBiasShardwiseBackend, + SetTriAttnBackend, + TriAttnBackend, + setup_tf32_env, +) +from boltz.model.layers.attentionv2 import AttentionPairBias as AttentionPairBiasV2 +from boltz.workflow.utils import ( + _DATASET_KEYS_TO_OVERRIDE, + CUDAMemoryProfile, + convert_datasets_dict_to_list_config, +) + + +@dataclass +class DistributedTrainConfig: + """Configuration dataclass for distributed CP training.""" + + data: Any # DataConfigV2 or OmegaConf dict; converted in _create_distributed_data_module + model: LightningModule + output: str + trainer: Optional[dict[str, Any]] = None + parallel_size: Optional[dict[str, Any]] = None + precision: Precision = Precision.FP32 + matmul_precision: Optional[str] = None + find_unused_parameters: Optional[bool] = False # Retained for boltz-2 config compat; unused by CP strategy + save_top_k: Optional[int] = 1 + checkpoint: Optional[dict[str, Any]] = None + resume: Optional[str] = None + pretrained: Optional[str] = None + wandb: Optional[dict[str, Any]] = None + disable_checkpoint: bool = False + debug: bool = False + strict_loading: bool = True + load_confidence_from_trunk: Optional[bool] = False + seed: Optional[int] = None + validation_only: bool = False + v2: bool = True # Retained for structurev2.yaml compat; always True for Boltz-2 + triattn_backend: TriAttnBackend = TriAttnBackend.CUEQ + sdpa_with_bias_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.TORCH_FLEX_ATTN + sdpa_with_bias_shardwise_backend: SDPAWithBiasBackend = SDPAWithBiasBackend.TORCH_FLEX_ATTN + + +def _load_and_merge_config(raw_config: str, args: list[str]) -> omegaconf.DictConfig: + raw_config_dict = omegaconf.OmegaConf.load(raw_config) + if "defaults" in raw_config_dict: + config_path = Path(raw_config) + GlobalHydra.instance().clear() + with initialize_config_dir(config_dir=str(config_path.parent.absolute()), version_base=None): + config_dict = compose(config_name=config_path.stem) + else: + config_dict = raw_config_dict + omegaconf.OmegaConf.set_struct(config_dict, False) + args_dict = omegaconf.OmegaConf.from_dotlist(args) + if ( + "data" in args_dict + and "datasets" in args_dict.data + and "data" in config_dict + and "datasets" in config_dict.data + ): + args_dict["data"]["datasets"] = convert_datasets_dict_to_list_config( + config_dict.data.datasets, + args_dict.data.datasets, + keys_to_override=_DATASET_KEYS_TO_OVERRIDE, + remove_null_datasets=True, + ) + return omegaconf.OmegaConf.merge(config_dict, args_dict) + + +def _parse_precision(value: Any) -> Precision: + if isinstance(value, Precision): + return value + if isinstance(value, str): + if value in Precision.__members__: + return Precision[value] + for precision in Precision: + if precision.value == value: + return precision + raise ValueError(f"Unsupported precision value: {value!r}") + + +def _parse_backend_enum(value: Any, enum_cls: type) -> Any: + """Parse a string or enum value into the given enum class.""" + if isinstance(value, enum_cls): + return value + if isinstance(value, str): + try: + return enum_cls(value) + except ValueError: + pass + if value in enum_cls.__members__: + return enum_cls[value] + raise ValueError(f"Unsupported {enum_cls.__name__} value: {value!r}. Valid: {[e.value for e in enum_cls]}") + + +def _apply_matmul_precision(matmul_precision: Optional[str]) -> None: + """Apply optional matmul precision setting.""" + if matmul_precision is not None: + torch.set_float32_matmul_precision(matmul_precision) + + +def _create_dist_manager(cfg: DistributedTrainConfig) -> DistributedManager: + trainer_cfg = cfg.trainer or {} + parallel_cfg = cfg.parallel_size or {} + + accelerator = trainer_cfg.get("accelerator", "gpu") + accelerator_to_device_type = {"cpu": "cpu", "gpu": "cuda"} + if accelerator not in accelerator_to_device_type: + raise ValueError( + f"Accelerator {accelerator} is not supported; expected one of {sorted(accelerator_to_device_type)}" + ) + device_type = accelerator_to_device_type[accelerator] + + size_dp = int(parallel_cfg.get("size_dp", 1)) + size_cp = int(parallel_cfg.get("size_cp", 1)) + if size_dp <= 0 or size_cp <= 0: + raise ValueError(f"size_dp and size_cp must be positive; got size_dp={size_dp}, size_cp={size_cp}") + + # If already initialized (e.g. train() called multiple times in the same + # process during testing), validate that the requested topology matches the + # existing singleton and return it. + if DistributedManager.is_initialized(): + existing = DistributedManager() + existing_device_type = existing.device.type + if existing_device_type != device_type: + raise ValueError( + f"DistributedManager already initialized with device_type={existing_device_type!r}, " + f"but this call requests device_type={device_type!r}. " + f"Cannot change device type without cleanup + reinitialization." + ) + existing_dp_size = len(existing.group_ranks.get("dp", [])) + existing_cp_size = len(existing.group_ranks.get("cp", [])) + if existing_dp_size and existing_dp_size != size_dp: + raise ValueError( + f"DistributedManager already initialized with dp group size {existing_dp_size}, " + f"but this call requests size_dp={size_dp}. " + f"Cannot change topology without cleanup + reinitialization." + ) + if existing_cp_size and existing_cp_size != size_cp: + raise ValueError( + f"DistributedManager already initialized with cp group size {existing_cp_size}, " + f"but this call requests size_cp={size_cp}. " + f"Cannot change topology without cleanup + reinitialization." + ) + return existing + + timeout_nccl_minutes = parallel_cfg.get("timeout_nccl") + timeout_gloo_minutes = parallel_cfg.get("timeout_gloo") + if timeout_nccl_minutes is not None and timeout_nccl_minutes <= 0: + raise ValueError("timeout_nccl must be positive when provided") + if timeout_gloo_minutes is not None and timeout_gloo_minutes <= 0: + raise ValueError("timeout_gloo must be positive when provided") + + timeout_nccl = timedelta(minutes=timeout_nccl_minutes) if timeout_nccl_minutes is not None else None + timeout_gloo = timedelta(minutes=timeout_gloo_minutes) if timeout_gloo_minutes is not None else None + timeout_by_device = {"cuda": timeout_nccl, "cpu": timeout_gloo} + + DistributedManager.initialize(device_type=device_type, timeout=timeout_by_device[device_type]) + atexit.register(DistributedManager.cleanup) + dist_manager = DistributedManager() + if not dist_manager.has_dist: + raise RuntimeError( + "DistributedManager did not initialize torch.distributed. " + "Launch this entrypoint under torchrun/slurm with RANK/WORLD_SIZE (or SLURM_* env)." + ) + + if size_dp * size_cp != dist_manager.world_size: + raise ValueError( + f"world_size mismatch: process world_size={dist_manager.world_size}, " + f"expected size_dp*size_cp={size_dp * size_cp}" + ) + + size_cp_axis = isqrt(size_cp) + if size_cp_axis * size_cp_axis != size_cp: + raise ValueError(f"size_cp must be a square integer for 2D CP mesh, got {size_cp}") + + grid_group_sizes: OrderedDict[str, int | tuple[int, ...]] = OrderedDict( + [("dp", size_dp), ("cp", (size_cp_axis, size_cp_axis))] + ) + DistributedManager.create_grid_group(grid_group_sizes) + return dist_manager + + +def _load_pretrained_if_requested( + model_module: LightningModule, + cfg: DistributedTrainConfig, +) -> LightningModule: + """Load pretrained weights into ``model_module`` when ``cfg.pretrained`` is set. + + Returns the model unchanged when no pretrained path is configured or when + resuming from a training checkpoint (``cfg.resume``). + + When ``cfg.load_confidence_from_trunk`` is True, trunk weights (everything + except ``structure_module`` and ``distogram_module``) are duplicated under + the ``confidence_module.`` prefix before loading, so the confidence head + inherits shared encoder parameters. + + Loading uses ``strict=False`` to support reduced-depth fine-tuning (fewer + pairformer layers). ``_validate_checkpoint_architecture`` is called + post-load to guard against silent V1/V2 attention mismatches that + ``strict=False`` would otherwise ignore. + """ + if not cfg.pretrained or cfg.resume: + return model_module + + if cfg.load_confidence_from_trunk: + checkpoint = torch.load(cfg.pretrained, map_location="cpu", weights_only=False) + new_state_dict = {} + for key, value in checkpoint["state_dict"].items(): + if not key.startswith("structure_module") and not key.startswith("distogram_module"): + new_state_dict[f"confidence_module.{key}"] = value + new_state_dict.update(checkpoint["state_dict"]) + checkpoint["state_dict"] = new_state_dict + random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=10)) + temp_path = os.path.join(cfg.output, f".tmp_{random_string}.ckpt") + torch.save(checkpoint, temp_path) + file_path = temp_path + else: + file_path = cfg.pretrained + + hparams = dict(model_module.hparams) + if getattr(model_module, "validate_structure", False) and hasattr(model_module, "validators"): + hparams["validators"] = model_module.validators + + loaded = type(model_module).load_from_checkpoint( + file_path, + map_location="cpu", + strict=False, + **hparams, + ) + if cfg.load_confidence_from_trunk: + os.remove(file_path) + + _validate_checkpoint_architecture(loaded) + return loaded + + +def _validate_checkpoint_architecture(model: LightningModule) -> None: + """Verify the loaded model uses the expected attention implementation. + + When ``load_from_checkpoint`` uses ``strict=False``, mismatched hparams + (e.g. missing ``v2=True`` in ``pairformer_args``) can silently create V1 + attention layers whose extra ``norm_s`` weights are randomly initialized + instead of loaded from the checkpoint. This validation catches that. + """ + pairformer = getattr(model, "pairformer_module", None) + if pairformer is None: + return + layers = getattr(pairformer, "layers", []) + if not layers: + return + layer0_attn = getattr(layers[0], "attention", None) + if layer0_attn is None: + return + if not isinstance(layer0_attn, AttentionPairBiasV2): + raise RuntimeError( + f"Pairformer layer 0 attention is {type(layer0_attn).__module__}." + f"{type(layer0_attn).__name__}, expected AttentionPairBiasV2 " + f"(boltz.model.layers.attentionv2). This usually means " + f"pairformer_args is missing 'v2: true' — pass " + f"pairformer_args=asdict(PairformerArgsV2()) to " + f"load_from_checkpoint." + ) + + +def _cleanup_distributed() -> None: + """Clean up distributed process groups after training. + + Extracted as a module-level function so tests that call ``train()`` + multiple times in the same worker process can monkeypatch it to a + no-op and defer cleanup to the test's own ``finally`` block. + """ + if DistributedManager.is_initialized(): + DistributedManager.cleanup() + + +def _create_distributed_data_module( + data_config: Any, + dist_manager: DistributedManager, +) -> pl.LightningDataModule: + """Construct the distributed Boltz-2 training data module. + + Wraps the serial ``Boltz2TrainingDataModule`` with DTensor context-parallel + distribution. Tests may monkeypatch this function to supply a lightweight + DTensor-producing smoke data module. + + Parameters + ---------- + data_config + Data configuration — either a :class:`DataConfigV2` instance or an + OmegaConf/dict that can be unpacked into one. + dist_manager + Initialized :class:`DistributedManager` with grid groups. + """ + cfg = data_config if isinstance(data_config, DataConfigV2) else DataConfigV2(**data_config) + device_mesh = dist_manager.device_mesh_subgroups + device_mesh_cpu = map_subgroup_mesh_to_cpu(dist_manager) + return Boltz2TrainingDataModule(cfg=cfg, device_mesh=device_mesh, device_mesh_cpu=device_mesh_cpu) + + +def _create_distributed_model( + cfg: DistributedTrainConfig, + dist_manager: DistributedManager, +) -> LightningModule: + """Construct the distributed Boltz-2 model with DTensor CP wrapping. + + Instantiates the serial model from config, loads pretrained weights if + requested, moves to device, and wraps with :class:`Boltz2Distributed`. + Tests may monkeypatch this function to supply a lightweight DTensor-aware + smoke model. + + Parameters + ---------- + cfg + Full training configuration (model, pretrained, strict_loading, etc.). + dist_manager + Initialized :class:`DistributedManager` with grid groups. + """ + model_serial = cfg.model + model_serial = _load_pretrained_if_requested(model_serial, cfg) + model_serial = model_serial.to(dist_manager.device) + dist_model = Boltz2Distributed(model_serial, dist_manager) + if not cfg.strict_loading: + dist_model.strict_loading = False + return dist_model + + +def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912 + """Run distributed training scaffold with strategy/checkpoint wiring.""" + config_dict = _load_and_merge_config(raw_config, args) + if "precision" in config_dict: + config_dict.precision = _parse_precision(config_dict.precision) + for backend_key, backend_cls in ( + ("triattn_backend", TriAttnBackend), + ("sdpa_with_bias_backend", SDPAWithBiasBackend), + ("sdpa_with_bias_shardwise_backend", SDPAWithBiasBackend), + ): + if backend_key in config_dict and isinstance(config_dict[backend_key], str): + config_dict[backend_key] = _parse_backend_enum(config_dict[backend_key], backend_cls) + + cuda_memory_profile_cfg = config_dict.pop("CUDAMemoryProfile", None) + offload_actv_ckpt_cfg = config_dict.pop("OffloadActvCkptToCPU", None) + + cfg = hydra.utils.instantiate(config_dict) + cfg = DistributedTrainConfig(**cfg) + if not cfg.v2: + raise NotImplementedError("DTensor distributed training only supports Boltz-2 (v2=true)") + Path(cfg.output).mkdir(parents=True, exist_ok=True) + _apply_matmul_precision(cfg.matmul_precision) + + dist_manager = _create_dist_manager(cfg) + + # Offset RNG seed by rank and, on resume, by epoch + global_step to avoid + # replaying identical data samples. Boltz's TrainingDataset ignores the + # sampler index, so without a resume-aware offset the RNG repeats itself. + seed_offset = 0 + if cfg.resume: + ckpt_meta = torch.load(cfg.resume, mmap=True, map_location="cpu", weights_only=False) + seed_offset = int(ckpt_meta.get("epoch", 0)) + int(ckpt_meta.get("global_step", 0)) + del ckpt_meta + if cfg.seed is not None: + seed_everything(dist_manager.group_rank["world"] + seed_offset + int(cfg.seed)) + + trainer_cfg = dict(cfg.trainer or {}) + + _EXPECTED_DTENSOR_TRAINER = {"devices": 1, "num_nodes": 1} + for key, expected in _EXPECTED_DTENSOR_TRAINER.items(): + val = trainer_cfg.get(key) + if val is not None and val != expected: + raise ValueError( + f"trainer.{key}={val!r} is incompatible with DTensor context-parallel training " + f"(expected {expected!r}). The distributed topology is managed by " + f"DistributedManager via parallel_size, not by Lightning. " + f"Set parallel_size.size_dp and parallel_size.size_cp instead, " + f"and use trainer.{key}={expected!r} or omit it." + ) + trainer_cfg[key] = expected + + num_workers = getattr(getattr(cfg, "data", None), "num_workers", 0) + if num_workers != 0: + raise ValueError( + f"data.num_workers={num_workers} is not supported in DTensor context-parallel training. " + f"Only num_workers=0 is supported because the DTensor data workflow requires " + f"main-process collation for distributed tensor construction. " + f"Set data.num_workers=0 in your config." + ) + + wandb_cfg = cfg.wandb + if cfg.debug: + wandb_cfg = None + + data_module = _create_distributed_data_module(cfg.data, dist_manager) + model_module = _create_distributed_model(cfg, dist_manager) + + model_module.apply(SetTriAttnBackend(cfg.triattn_backend)) + model_module.apply(SetAttnPairBiasBackend(cfg.sdpa_with_bias_backend)) + model_module.apply(SetAttnPairBiasShardwiseBackend(cfg.sdpa_with_bias_shardwise_backend)) + if offload_actv_ckpt_cfg is not None: + model_module.apply(OffloadActvCkptToCPU(set(offload_actv_ckpt_cfg))) + + if getattr(model_module, "confidence_prediction", False): + model_module.confidence_prediction = False + warnings.warn("Confidence prediction is not supported in distributed training mode") + + steering_args = getattr(model_module, "steering_args", None) + if steering_args is not None: + for attr in ("fk_steering", "guidance_update"): + if getattr(steering_args, attr, False): + setattr(steering_args, attr, False) + warnings.warn("Steering potentials are not supported in distributed training mode") + + callbacks: list[Any] = [] + if not cfg.disable_checkpoint: + # Boltz-2 checkpoint defaults; overridable via the ``checkpoint`` config key. + checkpoint_cfg = dict(cfg.checkpoint or {}) + checkpoint_cfg.setdefault("filename", "{epoch:02d}-{step:05d}") + checkpoint_cfg.setdefault("monitor", "val/lddt") + checkpoint_cfg.setdefault("save_top_k", cfg.save_top_k) + checkpoint_cfg.setdefault("save_last", True) + checkpoint_cfg.setdefault("save_on_train_epoch_end", True) + checkpoint_cfg.setdefault("mode", "max") + checkpoint_cfg.setdefault("every_n_epochs", 1) + callbacks.append(ModelCheckpoint(dirpath=cfg.output, **checkpoint_cfg)) + + if cuda_memory_profile_cfg is not None and cuda_memory_profile_cfg.get("output_path_prefix") is not None: + output_path = cuda_memory_profile_cfg.output_path_prefix + f"_rank{dist_manager.group_rank['world']}.pickle" + memory_profile_kwargs = {k: v for k, v in cuda_memory_profile_cfg.items() if k != "output_path_prefix"} + callbacks.append(CUDAMemoryProfile(output_path=output_path, **memory_profile_kwargs)) + + loggers: list[Any] = [] + if wandb_cfg: + wandb_id = wandb_cfg.get("id") + wandb_resume = "allow" if wandb_id else None + wdb_logger = WandbLogger( + name=wandb_cfg["name"], + group=wandb_cfg["name"], + save_dir=cfg.output, + project=wandb_cfg["project"], + entity=wandb_cfg["entity"], + id=wandb_id, + resume=wandb_resume, + log_model=False, + ) + loggers.append(wdb_logger) + + @rank_zero_only + def save_config_to_wandb() -> None: + config_out = Path(wdb_logger.experiment.dir) / "run.yaml" + with config_out.open("w") as file_handle: + OmegaConf.save(config_dict, file_handle) + wdb_logger.experiment.save(str(config_out)) + + save_config_to_wandb() + + strategy = BoltzContextParallelStrategy(dist_manager=dist_manager) + + if cfg.precision not in PRECISION_TO_LIGHTNING: + raise ValueError(f"Precision {cfg.precision} is not supported") + if trainer_cfg.get("precision") is not None: + raise ValueError( + "Set precision in the top-level config, not inside trainer. " + "The trainer.precision key is superseded by the top-level precision setting." + ) + trainer_cfg["precision"] = PRECISION_TO_LIGHTNING[cfg.precision] + + trainer_kwargs = dict( + default_root_dir=cfg.output, + strategy=strategy, + callbacks=callbacks, + logger=loggers, + enable_checkpointing=not cfg.disable_checkpoint, + reload_dataloaders_every_n_epochs=1, + use_distributed_sampler=False, # distributed data module handles its own sharding + **trainer_cfg, + ) + + if _one_logger_available: + # Compute global batch size for OneLogger compliance. + batch_size = getattr(getattr(data_module, "cfg", None), "batch_size", 1) + one_logger_config: dict[str, Any] = { + "global_batch_size": dist_manager.group["dp"].size() * batch_size, + } + if wandb_cfg: + one_logger_config.update( + { + "name": wandb_cfg["name"], + "group": wandb_cfg["name"], + "save_dir": cfg.output, + "project": wandb_cfg["project"], + "entity": wandb_cfg["entity"], + "log_model": False, + } + ) + HookedTrainer, one_logger_callback = hook_trainer_cls(pl.Trainer, callback_config=one_logger_config) + callbacks.append(one_logger_callback) + trainer = HookedTrainer(**trainer_kwargs) + else: + trainer = pl.Trainer(**trainer_kwargs) + + # Suppress expected Lightning warnings in CP mode. + warnings.filterwarnings( + "ignore", + message="It is recommended to use .* when logging on epoch level in " + "distributed setting to accumulate the metric across devices", + ) + warnings.filterwarnings( + "ignore", + message="The .* does not have many workers which may be a bottleneck. " + "Consider increasing the value of the `num_workers` .* to improve performance.", + ) + + try: + with setup_tf32_env(cfg.precision): + if cfg.validation_only: + trainer.validate(model_module, datamodule=data_module, ckpt_path=cfg.resume) + else: + trainer.fit(model_module, datamodule=data_module, ckpt_path=cfg.resume) + finally: + _cleanup_distributed() + + +if __name__ == "__main__": + train(sys.argv[1], sys.argv[2:]) diff --git a/src/boltz/distributed/utils.py b/src/boltz/distributed/utils.py new file mode 100644 index 000000000..e1769fd59 --- /dev/null +++ b/src/boltz/distributed/utils.py @@ -0,0 +1,1161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Callable, Iterable, List, Self, Sequence + +import numpy as np +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor.device_mesh import DeviceMesh + + +class LayoutMap: + """ + A class representing a mapping between multidimensional indices and flat indices. + This is basing on the C++ std::layout_stride::mapping but adds on top it the mapping + from the flat index to the multidimensional indices. + + Parameters + ---------- + strides : tuple of ints + The strides of the layout. + shape : tuple of ints, optional + The shape of the layout. If not provided, the shape will be inferred from the strides. + + Raises + ------ + ValueError + If the input strides or shape is invalid. + + Attributes + ---------- + numel : int + The total number of elements in the layout. + shape : tuple of ints + The shape of the layout. + """ + + def __init__( + self, + strides: tuple[int, ...], + shape: tuple[int, ...], + offset: int = 0, + ): + """ + Initialize the layout mapping. + + Parameters + ---------- + strides : tuple of ints + The strides of the layout. + shape : tuple of ints + The shape of the layout. + offset : int, optional + The offset of the layout. + + Raises + ------ + ValueError + If the input strides or shape is invalid. + + Notes + ----- + The input strides must be a cumulative product of some permutation of the input shape. + """ + if not all(isinstance(stride, (int, np.int64)) and stride > 0 for stride in strides): + raise ValueError(f"Input strides contain non-integer or negative values: {strides}") + + self._has_negative_shape = any(s < 0 for s in shape) + + if self._has_negative_shape: + raise ValueError(f"Input shape contain negative values: {shape}") + + self._has_zero_shape = any(s == 0 for s in shape) + + if self._has_zero_shape: + # NOTE: the C++ standard does allow the shape to be zero along some axes, which + # is not useful for our usage case. We nonetheless can relax the condition to + # allow zero-sized axes but it requires more testing + raise ValueError(f"Input shape contain zero values: {shape}") + + self._strides = strides + self._n_axes = len(strides) + + if len(shape) != self._n_axes: + raise ValueError(f"Shape {shape} and strides {strides} must have the same length") + + self._shape = shape + self._numel = np.prod(self._shape) + self._offset = offset + + # singleton axes can confound the uniqueness and exhaustiveness check, e.g., + # for layout right of (3, 1, 5), the strides are (5, 5, 1) but direct argsort + # on the strides will give the permuted exhaustive stride of (1, 5, 15), which + # corresponds to the strides of (5, 15, 1) (argsort of (2, 0, 1)), which will + # fail the uniqueness check. This is purely artifact of the stable sorting where + # the single axis can potentially be arbitrarily placed before or after the + # other axes with the same stride. The correct thing to do is to handle the ties + # involving the singleton axes so that we sort by shape if two stride elements are + # tied. + shape_and_strides = np.array( + list(zip(self._shape, self._strides)), + dtype=np.dtype([("shape", int), ("strides", int)]), + ) + argsort_ascend_strides_and_shape = np.argsort(shape_and_strides, order=["strides", "shape"]) + + self.is_unique = self._is_unique(argsort_ascend_strides_and_shape) + self.is_exhaustive = self._is_exhaustive(argsort_ascend_strides_and_shape) + + if not self.is_unique: + raise ValueError(f"Input strides {strides} and shape {shape} do not give unique layout.") + + self._required_span_size = self._compute_required_span_size() + self._argsort_descend_strides = argsort_ascend_strides_and_shape[::-1] + self._argsort_ascend_strides = argsort_ascend_strides_and_shape + + def _compute_required_span_size(self) -> int: + """ + Calculate the minimal span size required to represent the layout, e.g., + by storing the elements in a contiguous piece of memory + + See the C++ std::layout_stride's requirements here: + https://eel.is/c++draft/views.multidim#mdspan.layout.stride.expo-1 + + """ + if self._n_axes == 0: + return 1 + if self._has_zero_shape: + return 0 + return 1 + sum((self._shape[i] - 1) * self._strides[i] for i in range(self._n_axes)) + + def _strides_exhaustive(self, permutation: np.ndarray) -> np.ndarray: + """ + Calculate the expected exhaustive strides for a given permutation. + + For a valid exhaustive layout, the strides should follow a pattern where + strides[p[i]] equals strides[p[i-1]] * shape[p[i-1]] for all i > 0, + and strides[p[0]] equals 1, where p is the permutation of indices. + + Parameters + ---------- + permutation : np.ndarray + Permutation of indices in ascending order of strides. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + A tuple containing: + - The permuted strides + - The expected exhaustive strides for the given permutation + """ + strides = np.array(self._strides) + shape = np.array(self._shape) + shape_permuted = shape[permutation] + strides_permuted = strides[permutation] + shape_shifted = np.concatenate([[1], shape_permuted[:-1]]) + strides_shifted = np.concatenate([[1], strides_permuted[:-1]]) + return strides_permuted, strides_shifted * shape_shifted + + def _is_unique(self, permutation: np.ndarray) -> bool: + """ + Check if the layout mapping is unique, i.e., + if the mapping: n-dimensional index -> flat index is injective. + + This implements the necessary and sufficient condition for a + LayoutMap, i.e., the equivalent std::layout_stride::mapping, to + be unique. + + See the C++ std::layout_stride's requirements here for the + uniqueness condition: + https://eel.is/c++draft/views.multidim#mdspan.layout.stride.cons + + Parameters + ---------- + permutation : np.ndarray + Permutation of indices per the standard's requirement. Note that + the C++ standard doesn't specify what the permutation should be + as soon as it exists. In practice, one can construct a proof by + induction that the permutation can be chosen to be the ascending + ordering of the strides. + + Returns + ------- + bool + True if the layout has unique strides, False otherwise. + """ + if self._n_axes == 0: + return True + strides, strides_exhaustive = self._strides_exhaustive(permutation) + ans = np.all(strides >= strides_exhaustive) + return ans + + def _is_exhaustive(self, permutation: np.ndarray) -> bool: + """ + Check if the layout mapping is exhaustive, i.e., + if the mapping: n-dimensional index -> flat index is surjective. + + This implements the necessary and sufficient condition for a + LayoutMap, i.e., the equivalent std::layout_stride::mapping, to + be exhaustive. + + See the C++ std::layout_stride's requirements here for the + exhaustiveness condition: + https://eel.is/c++draft/views.multidim#mdspan.layout.stride.obs-5.2 + + Parameters + ---------- + permutation : np.ndarray + Permutation of indices per the standard's requirement. Note that + the C++ standard doesn't specify what the permutation should be + as soon as it exists. In practice, one can construct a proof by + induction that the permutation can be chosen to be the ascending + ordering of the strides. + + Returns + ------- + bool + True if the layout is exhaustive, False otherwise. + """ + if self._n_axes == 0: + return True + strides, strides_exhaustive = self._strides_exhaustive(permutation) + ans = np.all(strides == strides_exhaustive) + return ans + + @property + def offset(self) -> int: + """ + Get the offset of the layout. + """ + return self._offset + + @property + def required_span_size(self) -> int: + """ + Get the required span size of the layout. + """ + return self._required_span_size + + @property + def numel(self) -> int: + """ + Get the total number of elements in the layout. + Returns + ------- + int + The total number of elements. + """ + return self._numel + + @property + def shape(self) -> tuple[int, ...]: + """ + Get the shape of the layout. + + Returns + ------- + tuple of ints + The shape of the layout. + """ + return self._shape + + @property + def strides(self) -> tuple[int, ...]: + """ + Get the strides of the layout. + + Returns + ------- + tuple of ints + The strides of the layout. + + Notes + ----- + The strides are a cumulative product of some permutation of the input shape. + """ + return self._strides + + def __call__(self, ids: tuple[int, ...]) -> int: + """ + Get the flat index corresponding to the given multidimensional index. + + Parameters + ---------- + ids : tuple of ints + The multidimensional index. + + Returns + ------- + int + The flat index. + + Raises + ------ + ValueError + If the input index is out of range. + """ + if len(ids) != self._n_axes: + raise ValueError(f"Expected {self._n_axes} elements in ids but got only {len(ids)}") + + if len(ids) == 0: + return self._offset + + if self._shape is not None: + for axis, idx in enumerate(ids): + if idx < 0 or idx >= self._shape[axis]: + raise ValueError( + f"Expected ids to satisfy 0 <= ids[{axis}] <= {self._shape[axis] - 1} " + f"but found ids[{axis}] == {idx}" + ) + return np.dot(ids, self._strides) + self._offset + + def unravel(self, flat_index: int) -> tuple[int, ...]: + """ + Convert a flat index to a multidimensional index. + + Parameters + ---------- + flat_index : int + The flat index. + + Returns + ------- + tuple of ints + The multidimensional index. + + Raises + ------ + TypeError + If the input is not an integer. + ValueError + If the input index is out of range. + """ + if not self.is_unique: + # double check the uniqueness of the layout + raise ValueError(f"Layout is not unique, cannot unravel {flat_index}") + + if not isinstance(flat_index, (int, np.integer)): + raise TypeError(f"Expected arg to be an int, but instead got type {type(flat_index)}") + + remaining = flat_index - self._offset + + if remaining < 0 or remaining >= self._required_span_size: + raise ValueError( + f"Expected flat_index in range [{self._offset}, {self._offset + self._required_span_size - 1}], " + f"but instead got {flat_index}" + ) + + indices = [0] * self._n_axes # Initialize indices + + for i_dim in self._argsort_descend_strides: + stride = self._strides[i_dim] + size = self._shape[i_dim] + indices[i_dim] = (remaining // stride) % size + remaining -= indices[i_dim] * stride + + if remaining != 0: + msg = f"Input flat_index {flat_index} is out of the valid range of span." + if not self.is_exhaustive: + msg += " Given the layout is not exhaustive, the input flat_index can fall into the unmapped region." + raise ValueError(msg) + + return tuple(indices) + + def __getitem__(self, slices: tuple[slice | int, ...]) -> Self: + """ + Create a new LayoutMap by slicing the current layout along specified dimensions. + + This method allows for creating sub-layouts by slicing the original layout, + similar to numpy array slicing. Dimensions can be reduced by using integer + indices, or transformed by using slices with custom start, stop, and step values. + + Parameters + ---------- + slices : tuple[slice | int, ...] or slice or int + Slices to apply along each dimension. Can be a single slice/int or a tuple + of slices/integers. Integer indices collapse the corresponding dimension, + while slices transform the dimension according to start, stop, and step. + If fewer slices are provided than dimensions, remaining dimensions will + be sliced with full-range slices (equivalent to ':'). + + Returns + ------- + Self + A new LayoutMap instance with updated shape, strides and offset reflecting + the applied slicing. + + Raises + ------ + ValueError + If slicing with negative or zero steps, or if start is less than or equal to stop. + TypeError + If slice elements are not of type slice or int. + + Examples + -------- + >>> layout = LayoutMap((12, 4, 1), (2, 3, 4)) + >>> # Slice first dimension from indices 1 to 3 with step 2 + >>> sub_layout = layout[slice(1, 3, 2), :, :] + >>> # Collapse second dimension by selecting index 1 + >>> sub_layout = layout[:, 1, :] + >>> # Slice only first dimension, remaining dimensions use full range + >>> sub_layout = layout[1:] # Equivalent to layout[1:, :, :] + """ + if not isinstance(slices, tuple) and (isinstance(slices, slice) or isinstance(slices, int)): + slices = (slices,) + + # Pad slices with full-range slices if needed + if len(slices) < self._n_axes: + full_slice = slice(None) # This is equivalent to ':' + slices = slices + (full_slice,) * (self._n_axes - len(slices)) + + new_shape = [] + new_strides = [] + new_offset = self.offset + + for axis, s in enumerate(slices): + if isinstance(s, (int, np.int64)): + # Collapse dimension and adjust offset + new_offset += s * self.strides[axis] + elif isinstance(s, slice): + start, stop, step = s.indices(self.shape[axis]) + if step <= 0: + raise ValueError("Unsupported slicing: Negative or zero steps") + if start >= stop: + # NOTE: the start == stop could be supported because we could have + # a layout with shape[i] == 0. But it wouldn't be useful for our usage cases. + raise ValueError("Unsupported slicing: start not smaller than stop") + + # Calculate new dimension length + dim_len = (stop - start + step - 1) // step + dim_len = max(0, dim_len) + + # Update metadata + new_shape.append(dim_len) + new_strides.append(self.strides[axis] * step) + new_offset += start * self.strides[axis] + else: + raise TypeError(f"Unsupported slice type: {type(s)}") + + return LayoutMap(tuple(new_strides), tuple(new_shape), new_offset) + + +class LayoutRightMap(LayoutMap): + """ + A class representing a right-aligned layout mapping. + + Parameters + ---------- + shape : tuple of ints + The shape of the layout. + """ + + def __init__(self, shape: tuple[int, ...]): + """ + Initialize the layout mapping. + + Parameters + ---------- + shape : tuple of ints + The shape of the layout. + """ + strides = np.ones_like(shape) + strides[1:] = shape[:0:-1] + strides = np.cumprod(strides)[::-1] + super().__init__(tuple(strides), shape=shape) + + +class LayoutLeftMap(LayoutMap): + """ + A class representing a left-aligned layout mapping. + + Parameters + ---------- + shape : tuple of ints + The shape of the layout. + """ + + def __init__(self, shape: tuple[int, ...]): + """ + Initialize the layout mapping. + + Parameters + ---------- + shape : tuple of ints + The shape of the layout. + """ + strides = np.ones_like(shape) + strides[1:] = shape[:-1] + strides = np.cumprod(strides) + super().__init__(tuple(strides), shape=shape) + + +def slice_repr_mask( + s: torch.Tensor, + z: torch.Tensor, + mask: torch.Tensor, + pair_mask: torch.Tensor, + n_ranks: int, + layout_group: LayoutMap, +) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Slice the single- and pair-representation tensors and masks along the given coordinates. + + This function slices the input tensors into n_ranks segments along the token dimensions, + with the segment boundaries determined by the layout mapping of the ranks. + + Parameters + ---------- + s : torch.Tensor + Single representation tensor. + z : torch.Tensor + Pair representation tensor. + mask : torch.Tensor + Single representation mask. + pair_mask : torch.Tensor + Pair representation mask. + n_ranks : int + The number of ranks. + layout_group : LayoutMap + The layout mapping of the ranks. + Returns + ------- + tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]] + Sliced single representation tensors, pair representation tensors, single representation masks, + and pair representation masks. + + Raises + ------ + ValueError + If the shape of any input tensor is incompatible with the layout mapping. + """ + if z.shape[-2] != z.shape[-3]: + raise ValueError(f"z is not square tensor in the middle two axes but of shape {z.shape}") + if s.shape[-2] != z.shape[-3]: + raise ValueError(f"Incompatible s shape {s.shape} and z shape {z.shape}") + if mask.shape != s.shape[:-1]: + raise ValueError(f"Incompatible s shape {s.shape} and mask shape {mask.shape}") + if pair_mask.shape != z.shape[:-1]: + raise ValueError(f"Incompatible z shape {z.shape} and pair_mask shape {pair_mask.shape}") + n_tokens = s.shape[-2] + coords = [layout_group.unravel(rank) for rank in range(n_ranks)] + n_ranks_axis = isqrt(n_ranks) + if n_ranks_axis * n_ranks_axis != n_ranks: + raise ValueError(f"Input n_ranks is not a square int: {n_ranks}") + if n_tokens % n_ranks_axis: + raise ValueError( + f"Input tensors size along the token dimensions {n_tokens} not divisible by square root of n_ranks {n_ranks}" + ) + stride = n_tokens // n_ranks_axis + s_slices = [] + z_slices = [] + mask_slices = [] + pair_mask_slices = [] + for i_row, j_col in coords: + i_row_begin = i_row * stride + i_row_end = (i_row + 1) * stride + j_col_begin = j_col * stride + j_col_end = (j_col + 1) * stride + s_slices.append(s[..., i_row_begin:i_row_end, :].contiguous()) + mask_slices.append(mask[..., i_row_begin:i_row_end].contiguous()) + z_slices.append( + z[ + ..., + i_row_begin:i_row_end, + j_col_begin:j_col_end, + :, + ].contiguous() + ) + pair_mask_slices.append( + pair_mask[ + ..., + i_row_begin:i_row_end, + j_col_begin:j_col_end, + ].contiguous() + ) + return s_slices, z_slices, mask_slices, pair_mask_slices + + +def gather_repr( + s_slices: List[torch.Tensor], + z_slices: List[torch.Tensor], + rank: int, + group: torch.distributed.ProcessGroup, + layout_group: LayoutMap, + s_global: torch.Tensor, + z_global: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Gather representation slices from all ranks in the group and fill the global representations. + + This function performs all-gather operations on the representation slices from each rank, + and fills the global representations `s_global` and `z_global`. + The single representation `s_global` is assumed to be sliced along the rows of the grid of process ranks, + i.e., rank[i, :] owns s_global[..., i, :]. + The pair representation `z_global` is sliced along both of its middle axes. + + Parameters + ---------- + s_slices : List[torch.Tensor] + Slices of single representation `s` from each rank of the group. + z_slices : List[torch.Tensor] + Slices of pair representation `z` from each rank of the group. + rank : int + The rank of the current process. + group : torch.distributed.ProcessGroup + The process group for all-gather operation. + layout_group : LayoutMap + The layout mapping of the group. + s_global : torch.Tensor + The global representation `s` to be filled. + z_global : torch.Tensor + The global representation `z` to be filled. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + The gathered global representations `s_global` and `z_global`. + + Raises + ------ + ValueError + If the input tensors have incompatible shapes. + """ + if z_global.shape[-2] != z_global.shape[-3]: + raise ValueError(f"z_global is not square tensor in the middle two axes but of shape {z_global.shape}") + if s_global.shape[-2] != z_global.shape[-3]: + raise ValueError(f"Incompatible s_global shape {s_global.shape} and z_global shape {z_global.shape}") + n_tokens = s_global.shape[-2] + ranks = torch.distributed.get_process_group_ranks(group) + n_ranks = len(ranks) + coords = [layout_group.unravel(rank) for rank in range(n_ranks)] + n_ranks_axis = isqrt(n_ranks) + if n_ranks_axis * n_ranks_axis != n_ranks: + raise ValueError(f"Input n_ranks is not a square int: {n_ranks}") + if n_tokens % n_ranks_axis: + raise ValueError( + f"Input tensors size along the token dimensions {n_tokens} not divisible by square root of n_ranks {n_ranks}" + ) + stride = n_tokens // n_ranks_axis + req_gather_z = torch.distributed.all_gather(z_slices, z_slices[rank], group=group, async_op=True) + req_gather_s = torch.distributed.all_gather(s_slices, s_slices[rank], group=group, async_op=True) + req_gather_s.wait() + for i_rank in range(len(coords)): + i_row, j_col = coords[i_rank] + s_global[..., i_row * stride : (i_row + 1) * stride, :] = s_slices[i_rank] + req_gather_z.wait() + for i_rank in range(len(coords)): + i_row, j_col = coords[i_rank] + z_global[..., i_row * stride : (i_row + 1) * stride, j_col * stride : (j_col + 1) * stride, :] = z_slices[ + i_rank + ] + return s_global, z_global + + +def get_group_rank_from_axial_shift(coord: tuple[int, ...], axis: int, delta: int, layout_group: LayoutMap) -> int: + """ + Get the rank of a process after shifting its coordinates along an axis. + + Parameters + ---------- + coord : tuple of ints + The current coordinates of the process in the group layout. + axis : int + The axis along which to shift the coordinates. + delta : int + The amount to shift the coordinates by (can be positive or negative). + layout_group : LayoutMap + The layout mapping of the process group. + + Returns + ------- + int + The rank of the process after shifting its coordinates. + + Raises + ------ + ValueError + If the coordinates are incompatible with the layout shape or if the axis is out of range. + """ + if len(coord) != len(layout_group.shape): + raise ValueError(f"Incompatible coord {coord} and layout_group shape {layout_group.shape}") + if axis >= len(coord): + raise ValueError(f"Axis {axis} is out of range for coord {coord}") + coord_shifted = list(coord) + coord_shifted[axis] = (coord_shifted[axis] + delta) % layout_group.shape[axis] + return layout_group(coord_shifted) + + +def all_reduce_weighted_mean( + weights, + values, + group_reduce: torch.distributed.ProcessGroup, + dim: int | tuple[int, ...] = -1, + eps: float = 0.0, +) -> Tensor: + """Perform distributed weighted sum operation. + + Args: + weights: weights tensor + values: values tensor + group_reduce: process group for reduction + dim: dimension to perform weighted mean operation on; default is -1 + eps: epsilon value to avoid division by zero; default is 0.0 + + Returns: + Tensor: weighted mean of values + """ + values_local = torch.sum(weights * values, dim=dim) + values_work = torch.distributed.all_reduce( + values_local, + op=torch.distributed.ReduceOp.SUM, + group=group_reduce, + async_op=True, + ) + weights_local = torch.sum(weights, dim=dim) + eps + torch.distributed.all_reduce(weights_local, op=torch.distributed.ReduceOp.SUM, group=group_reduce) + values_work.wait() + return values_local / weights_local + + +def tiled_softmax_attention_update( + o_chunk: torch.Tensor, + lse_m_chunk: torch.Tensor, + amax_chunk: torch.Tensor | None, + o: torch.Tensor | None = None, + lse_m: torch.Tensor | None = None, + amax: torch.Tensor | None = None, +): + """ + Update online softmax attention accumulation with a new chunk of data. + + This function implements a numerically stable online softmax computation that processes + data in chunks. It maintains running statistics (output accumulation, log-sum-exp, and + maximum values) to compute the final softmax result incrementally without storing all + intermediate values in memory. + + The algorithm is particularly useful for attention mechanisms where the sequence length + is too large to fit in memory at once, allowing for tiled/chunked processing while + maintaining mathematical equivalence to full softmax computation. + + Args: + o_chunk (torch.Tensor): Output accumulation for the current chunk. + Shape: (..., D) where D is the feature dimension. + lse_m_chunk (torch.Tensor): Log-sum-exp minus max for the current chunk. + Shape: (..., 1) - must have last dimension of size 1. + Note: When amax_chunk is None, this is actually lse (not lse_m = lse - amax). + amax_chunk (torch.Tensor | None): Maximum value for the current chunk. + Shape: (..., 1) - must match lse_m_chunk shape. + If None, the function operates without amax tracking (has_amax == False mode). + See Note below for implications. + o (torch.Tensor | None, optional): Accumulated output from previous chunks. + If None, this is treated as the first chunk. Must have same shape as o_chunk + if provided. Defaults to None. + lse_m (torch.Tensor | None, optional): Accumulated log-sum-exp minus max from + previous chunks. Must have same shape as lse_m_chunk if provided. + Note: When amax is None, this is actually lse (not lse_m = lse - amax). + Defaults to None. + amax (torch.Tensor | None, optional): Maximum value across all previous chunks. + Must have same shape as amax_chunk if provided. If amax_chunk is None, + this must also be None, and the returned amax will also be None. Defaults to None. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - o (torch.Tensor): Updated accumulated output with current chunk incorporated. + - lse_m (torch.Tensor): Updated log-sum-exp minus max value. + - amax (torch.Tensor): Updated maximum value across all processed chunks. + + Raises: + ValueError: If input tensors have incompatible shapes or if the None/not-None + consistency of o, lse_m, and amax is violated. + + Mathematical Derivation: + The online softmax update rules are derived as follows: + + Given: + - amax: cumulative max until current block + - amax_chunk: max of current block + - amax_next: resulting amax post-update = max(amax, amax_chunk) + - lse_chunk = log(sum(exp(a_chunk))) + - lse_m_chunk = log(sum(exp(a_chunk - amax_chunk))) = lse_chunk - amax_chunk + - lse = log(sum(exp(a))) + - lse_m = log(sum(exp(a - amax))) = lse - amax + - out = unormalized_out / exp(lse) = unormalized_out / exp(lse_m + amax) + + Update rule for lse_next: + lse_next = lse + torch.log(1 + torch.exp(lse_chunk - lse)) + = lse - F.logsigmoid(lse - lse_chunk) + + Update rule for lse_m_next: + lse_m_next = lse_next - amax_next + = lse - F.logsigmoid(lse - lse_chunk) - amax_next + = lse - (lse - lse_chunk + F.logsigmoid(lse_chunk - lse)) - amax_next + = lse_chunk - amax_next - F.logsigmoid(lse_chunk - lse) + = lse_m_chunk + amax_chunk - amax_next - F.logsigmoid(lse_m_chunk + amax_chunk - lse_m - amax) + = lse_m_chunk + amax_chunk - amax_next - F.logsigmoid(lse_m_chunk - lse_m + amax_chunk - amax) + = lse_m_chunk + amax_chunk - amax_next - (amax_chunk - amax - logsumexp([-lse_m_chunk + lse_m, amax_chunk - amax])) + = lse_m_chunk + amax - amax_next + logsumexp([-lse_m_chunk + lse_m, amax_chunk - amax]) + = lse_m_chunk + logsumexp([amax - amax_next - lse_m_chunk + lse_m, amax_chunk - amax_next]) + + Update rule for o_next (numerically stable form): + o_next = torch.exp(lse - lse_next) * o + torch.exp(lse_chunk - lse_next) * o_chunk + + The following computation is more numerically stable: + o_next = o / (1 + torch.exp(lse_chunk - lse)) + torch.exp(lse_chunk - lse + F.logsigmoid(lse - lse_chunk)) * o_chunk + = o * F.sigmoid(lse - lse_chunk) + torch.exp(lse_chunk - lse) * F.sigmoid(lse - lse_chunk) * o_chunk + = o * F.sigmoid(lse - lse_chunk) + F.sigmoid(lse_chunk - lse) * o_chunk + = o * (1 - F.sigmoid(lse_chunk - lse)) + F.sigmoid(lse_chunk - lse) * o_chunk + = o - F.sigmoid(lse_chunk - lse) * (o - o_chunk) + = o - F.sigmoid(lse_m_chunk + amax_chunk - lse_m - amax) * (o - o_chunk) + + Note: + All three optional parameters (o, lse_m, amax) must be either all None (indicating + this is the first chunk) or all provided with compatible shapes. The function uses + numerically stable computations with log-space arithmetic to prevent overflow/underflow + issues common in softmax computations. + + The mathematical operations maintain the invariant that the final result is equivalent + to computing softmax over the concatenation of all processed chunks. + + **Behavior when amax_chunk is None (has_amax == False mode):** + When amax_chunk (and correspondingly amax) is None, the function operates in a mode + where: + + a) lse_m is actually lse (i.e., we assume lse_m = lse and don't track amax separately) + b) The returned amax will always be None throughout all chunk updates + c) The update equations have **limited dynamic range** due to catastrophic cancellation + in the computation of d_lse_m = lse_m - lse_m_chunk. Without amax to handle extreme + values, this subtraction can lose precision when values differ significantly. + + Despite the limited dynamic range, this mode still works correctly in common use cases + where "-inf" padding in attention scores appears only after normal values (not before). + In such cases, the lse pattern is [...normal values..., -inf, -inf, ..., -inf], and + the trailing -inf values are effectively discarded because sigmoid(-inf) ≈ 0 and + logsigmoid(inf) ≈ 0, even though -inf + normal_value = -inf arithmetically. + + For better numerical stability across all scenarios, it is recommended to use + has_amax == True mode by providing amax_chunk. + """ + if not ((o is None) == (lse_m is None)): + raise ValueError("o and lse_m must both be None or both be not None") + + # has_amax == False will ignored amax terms entirely but assumes lse_m = lse_m + amax + has_amax = amax_chunk is not None + + is_initial_chunk = o is None + + if has_amax and lse_m_chunk.shape != amax_chunk.shape: + raise ValueError("lse_m_chunk and amax_chunk must have the same shape") + + shape_o = o_chunk.shape + + if lse_m_chunk.shape[-1] != 1: + raise ValueError("lse_m_chunk must have shape (..., 1)") + + if o_chunk.ndim != lse_m_chunk.ndim: + raise ValueError("o_chunk and lse_m_chunk must have the same number of dimensions") + + if lse_m_chunk.shape[:-1] != shape_o[:-1]: + raise ValueError("o_chunk and lse_m_chunk must have the same shape except for the last dimension") + + if not is_initial_chunk: + if o_chunk.shape != o.shape: + raise ValueError("o_chunk and o must have the same shape") + if lse_m_chunk.shape != lse_m.shape: + raise ValueError("lse_m_chunk and lse_m must have the same shape") + if (amax is None) != (amax_chunk is None): + raise ValueError("amax and amax_chunk must both be None or both be not None for non-initial chunks") + if has_amax and amax_chunk.shape != amax.shape: + raise ValueError("amax_chunk and amax must have the same shape") + + if is_initial_chunk: + o = o_chunk + lse_m = lse_m_chunk + amax = amax_chunk + else: + if has_amax: + d_lse_m = lse_m - lse_m_chunk + amax_next = torch.maximum(amax_chunk, amax) + delta_lse = amax_chunk - amax - d_lse_m + o = o - torch.sigmoid(delta_lse) * (o - o_chunk) + # torch.logsumexp unconditionally promotes BF16/FP16 → FP32. + # Cast back to the accumulator dtype to prevent FP32 leaking into + # lse_m → delta_lse → sigmoid → o on subsequent ring steps. + lse_m = lse_m_chunk + torch.logsumexp( + torch.cat([(amax - amax_next) + d_lse_m, amax_chunk - amax_next], dim=-1), + dim=-1, + keepdim=True, + ).to(dtype=lse_m_chunk.dtype) + # # TODO: the following double-buffer approach can be used equivalently + # # but save some memory + # # First create the double buffer to store: + # # d_lse_m[..., 0] = amax_chunk - amax + lse_m - lse_m_chunk + # # d_lse_m[..., 1] = (o - o_chunk) + # d_lse_m = amax_chunk.repeat_interleave(2, dim=-1) + # amax = amax.squeeze(-1) + # lse_m_chunk = lse_m_chunk.squeeze(-1) + # lse_m = lse_m.squeeze(-1) + # d_lse_m[..., 0] -= amax + # d_lse_m[..., 0] += lse_m_chunk + # d_lse_m[..., 0] -= lse_m + # d_lse_m[..., 1] = o.squeeze(-1) + # d_lse_m[..., 1] -= o_chunk.squeeze(-1) + # # then do the sigmoid and update o + # d_lse_m[..., 0].sigmoid_() + # d_lse_m[..., 0] *= d_lse_m[..., 1] + # o = o - d_lse_m[..., 0].reshape_as(o) + # # reuse the double buffer to update lse_m + # # amax_next = torch.maximum(amax_chunk, amax) + # # d_lse_m[..., 0] = -amax_next + # # d_lse_m[..., 1] = -amax_next + # d_lse_m[..., 0] = amax_next.squeeze(-1) + # d_lse_m[..., 0].neg_() + # d_lse_m[..., 1] = d_lse_m[..., 0] + # # d_lse_m[..., 0] = amax - amax_next + lse_m - lse_m_chunk + # d_lse_m[..., 0] += amax + # d_lse_m[..., 0] += lse_m + # d_lse_m[..., 0] -= lse_m_chunk + # # d_lse_m[..., 1] = amax_chunk - amax_next + # d_lse_m[..., 1] += amax_chunk.squeeze(-1) + # lse_m = d_lse_m.logsumexp(dim=-1, keepdim=True) + # lse_m += lse_m_chunk.unsqueeze(-1) + amax = amax_next + else: + # NOTE: without amax taking away contribution from extreme values, + # d_lse_m can result in catastrophic cancellation. This whole branch + # of update therefore has much smaller dynamic range as compared to + # has_max being True. Nonetheless, in commonly encountered usage cases + # where the "-inf" padding in the attention score only shows up after + # those normal values but not preceding them, we would have a lse pattern of: + # [... , -inf, -inf, ..., -inf], which works + # still despite -inf + normal_value = -inf because sigmoid(-inf) ~= 0 + # and logsigmoid(inf) ~= 0 so the trailing -inf would be virtually + # discarded + d_lse_m = lse_m - lse_m_chunk + delta_lse = -d_lse_m + o = o - torch.sigmoid(delta_lse) * (o - o_chunk) + # when amax is None, lse_m is lse, i.e., we assume lse_m = lse_m + amax + lse_m = lse_m - torch.nn.functional.logsigmoid(d_lse_m) + amax = None + + return o, lse_m, amax + + +def create_and_broadcast_tensor_into_placements( + shape: tuple[int, ...], + create_local_fn: Callable[[Iterable[int], torch.dtype, torch.device], torch.Tensor], + device_mesh: DeviceMesh, + placements: tuple[Shard | Replicate, ...], + dtype: torch.dtype = torch.float32, +) -> Tensor: + """Create a local tensor and broadcast it independently along each replicate axis. + + With multiple replicate axes, we first create a local tensor at the source rank and broadcast + it to all other ranks. Source rank is identified as the intersection of group rank zeros along + all replicate axes. As such, create_local_fn is only called on a subset of ranks, i.e. the + source rank(s). + + Parameters + ---------- + shape : tuple[int, ...] + Shape of the random tensor. + create_local_fn : Callable[[Iterable[int], torch.dtype, torch.device], torch.Tensor] + Function to create a local tensor from the shape of local tensor and dtype. + device_mesh : DeviceMesh + The device mesh for DTensor operations. + placements : tuple[Shard, ...] + The placements of the random tensor. + dtype : torch.dtype, optional + The dtype of the target tensor. Defaults to torch.float32. + + Returns + ------- + Tensor + The local tensor after broadcasting. + """ + axis_groups = ( + device_mesh.get_all_groups() + ) # reference: https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/device_mesh.py#L793 + + shape_local = list(shape) + is_mesh_dim_replicated = [placement.is_replicate() for placement in placements] + + # check placements + for i_dim_mesh, placement in enumerate(placements): + if placement.is_shard(): + # check if sharding is even on cp axes + if shape[placement.dim] % device_mesh.shape[i_dim_mesh] != 0: + raise ValueError( + f"Uneven sharding tensor dimension {placement.dim} of size {shape[placement.dim]} " + f"along device mesh dimension {i_dim_mesh} of size {device_mesh.shape[i_dim_mesh]} is not supported" + ) + shape_local[placement.dim] = shape[placement.dim] // device_mesh.shape[i_dim_mesh] + elif placement.is_partial(): + raise ValueError(f"Partial placements are not supported yet but got {placements}") + + is_source_rank = all( + not is_replicated or group_rank == 0 + for is_replicated, group_rank in zip(is_mesh_dim_replicated, device_mesh.get_coordinate()) + ) + + if is_source_rank: + tensor_local = create_local_fn(shape_local=shape_local, dtype=dtype, device=device_mesh.device_type) + else: + tensor_local = torch.empty(shape_local, device=device_mesh.device_type, dtype=dtype) + + for axis_group, is_replicated in zip(axis_groups, is_mesh_dim_replicated): + if not is_replicated: + continue + torch.distributed.broadcast( + tensor_local, + torch.distributed.get_global_rank(axis_group, 0), + group=axis_group, + ) + + return tensor_local + + +def create_distributed_randn( + shape: tuple[int, ...], + device_mesh: DeviceMesh, + placements: tuple[Shard | Replicate, ...], + dtype: torch.dtype = torch.float32, + scale: float = 1.0, +): + """Create a distributed random normal distributed tensor. + + Parameters + ---------- + shape : tuple[int, ...] + Shape of the random tensor. + device_mesh : DeviceMesh + The device mesh for DTensor operations. + placements : tuple[Shard, ...] + The placements of the random tensor. + dtype : torch.dtype, optional + The dtype of the random tensor. Defaults to torch.float32. + scale : float, optional + Scale of the normal distribution. Defaults to 1.0. + + Returns + ------- + DTensor + The randn DTensor with the corresponding placements. + """ + + def create_randn_fn(shape_local, dtype, device): + return torch.randn(shape_local, dtype=dtype, device=device) * scale + + tensor_local = create_and_broadcast_tensor_into_placements( + shape=shape, + create_local_fn=create_randn_fn, + device_mesh=device_mesh, + placements=placements, + dtype=dtype, + ) + # by broadcasting inside the create_and_broadcast_tensor_into_placements + # tensor_local is guaranteed to be of same shape across ranks + # FIXME: create_and_broadcast_tensor_into_placements should be responsible + # for creating the DTensor output + shape_output = list(tensor_local.shape) + for i_dim_mesh, p in enumerate(placements): + if isinstance(p, Shard): + shape_output[p.dim] *= device_mesh.shape[i_dim_mesh] + shape_output = tuple(shape_output) + stride_output = update_exhaustive_strides(tensor_local.shape, tensor_local.stride(), shape_output) + + return DTensor.from_local(tensor_local, device_mesh, placements, shape=shape_output, stride=stride_output) + + +def update_exhaustive_strides( + shape_original: Sequence[int], strides_original: Sequence[int], shape_new: Sequence[int] +) -> Sequence[int]: + """ + Update strides to maintain the same memory layout pattern when shape changes. + + This function computes new strides that preserve the same axis ordering and memory + layout pattern as the original exhaustive layout, but with a new shape. The resulting + strides will create an exhaustive layout with the same dimension ordering as the + original layout. + + An exhaustive layout is one where the mapping from multidimensional indices to flat + indices is surjective, meaning every valid flat index corresponds to at least one + multidimensional index. Meanwhile, a non-unique layout is not practically useful + for our application so we further require the input shape and strides to form an + unique layout, which implies the output layout is also unique. Overall, both the + input and output layouts are bijective + + Parameters + ---------- + shape_original : Sequence[int] + The original shape of the tensor layout. + strides_original : Sequence[int] + The original strides of the tensor layout. Must form an exhaustive layout + with shape_original. + shape_new : Sequence[int] + The new shape for which to compute corresponding strides. Must have the + same number of dimensions as shape_original. + + Returns + ------- + Sequence[int] + New strides that maintain the same memory layout pattern as the original + but are compatible with the new shape. The resulting strides will form + an exhaustive layout with shape_new. + + Raises + ------ + ValueError + If the original layout (shape_original, strides_original) is not exhaustive. + + Examples + -------- + >>> # Original layout: right-aligned (row-major) for shape (2, 3, 4) + >>> shape_orig = (2, 3, 4) + >>> strides_orig = (12, 4, 1) # exhaustive right-aligned strides + >>> shape_new = (3, 5, 2) + >>> new_strides = update_exhaustive_strides(shape_orig, strides_orig, shape_new) + >>> # Result: (10, 2, 1) - maintains right-aligned pattern + + Notes + ----- + The algorithm works by: + 1. Creating a LayoutMap from the original shape and strides + 2. Verifying the original layout is exhaustive + 3. Reordering the new shape according to the original layout's stride ordering + 4. Computing exhaustive strides for the reordered new shape + 5. Reordering the computed strides back to match the original dimension order + + This is useful when reshaping tensors while preserving their memory access patterns, + particularly in distributed computing scenarios where maintaining consistent + memory layouts across different tensor shapes is important. + """ + layout_original = LayoutMap(tuple(strides_original), tuple(shape_original)) + if not layout_original.is_exhaustive: + raise ValueError(f"Input layout with shape {shape_original} and strides {strides_original} is not exhaustive") + shape_new_ascending = np.array(shape_new)[layout_original._argsort_ascend_strides] + argsort_output = np.argsort(layout_original._argsort_ascend_strides) + strides_new_ascending = np.concatenate(([1], shape_new_ascending[:-1])).cumprod() + strides_new = strides_new_ascending[argsort_output] + return tuple(strides_new.tolist()) diff --git a/src/boltz/main.py b/src/boltz/main.py index ba28220fe..70045c8fa 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import multiprocessing import os import pickle @@ -1035,6 +1058,13 @@ def cli() -> None: is_flag=True, help=" to dump the s and z embeddings into a npz file. Default is False.", ) +@click.option( + "--input_format", + type=click.Choice(["preprocessed", "config_files"], case_sensitive=False), + help="Data format for input. If 'config_files', expects a yaml, fasta, or directory containing multiple of those." + "If preprocessed, expects a folder with manifest.json, msa/ folder and structures/ folder with preprocessed data.", + default="config_files", +) def predict( # noqa: C901, PLR0915, PLR0912 data: str, out_dir: str, @@ -1073,6 +1103,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 num_subsampled_msa: int = 1024, no_kernels: bool = False, write_embeddings: bool = False, + input_format: str = "config_files", ) -> None: """Run predictions with Boltz.""" # If cpu, write a friendly warning @@ -1139,69 +1170,77 @@ def predict( # noqa: C901, PLR0915, PLR0912 msg = f"Model {model} not supported. Supported: boltz1, boltz2." raise ValueError(f"Model {model} not supported.") - # Validate inputs - data = check_inputs(data) - - # Check method - if method is not None: - if model == "boltz1": - msg = "Method conditioning is not supported for Boltz-1." - raise ValueError(msg) - if method.lower() not in const.method_types_ids: - method_names = list(const.method_types_ids.keys()) - msg = f"Method {method} not supported. Supported: {method_names}" - raise ValueError(msg) - - # Process inputs - ccd_path = cache / "ccd.pkl" mol_dir = cache / "mols" - process_inputs( - data=data, - out_dir=out_dir, - ccd_path=ccd_path, - mol_dir=mol_dir, - use_msa_server=use_msa_server, - msa_server_url=msa_server_url, - msa_pairing_strategy=msa_pairing_strategy, - msa_server_username=msa_server_username, - msa_server_password=msa_server_password, - api_key_header=api_key_header, - api_key_value=api_key_value, - boltz2=model == "boltz2", - preprocessing_threads=preprocessing_threads, - max_msa_seqs=max_msa_seqs, - ) - # Load manifest - manifest = Manifest.load(out_dir / "processed" / "manifest.json") + if input_format == "preprocessed": + processed_dir = data + manifest = Manifest.load(processed_dir / "manifest.json") + filtered_manifest = filter_inputs_structure( + manifest=manifest, + outdir=out_dir, + override=override, + ) + processed = BoltzProcessedInput( + manifest=filtered_manifest, + targets_dir=processed_dir / "structures", + msa_dir=processed_dir / "msa", + constraints_dir=((processed_dir / "constraints") if (processed_dir / "constraints").exists() else None), + template_dir=((processed_dir / "templates") if (processed_dir / "templates").exists() else None), + extra_mols_dir=((processed_dir / "extra_mols") if (processed_dir / "extra_mols").exists() else None), + ) + else: + # Validate inputs + data = check_inputs(data) + + # Check method + if method is not None: + if model == "boltz1": + msg = "Method conditioning is not supported for Boltz-1." + raise ValueError(msg) + if method.lower() not in const.method_types_ids: + method_names = list(const.method_types_ids.keys()) + msg = f"Method {method} not supported. Supported: {method_names}" + raise ValueError(msg) + + # Process inputs + ccd_path = cache / "ccd.pkl" + process_inputs( + data=data, + out_dir=out_dir, + ccd_path=ccd_path, + mol_dir=mol_dir, + use_msa_server=use_msa_server, + msa_server_url=msa_server_url, + msa_pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, + boltz2=model == "boltz2", + preprocessing_threads=preprocessing_threads, + max_msa_seqs=max_msa_seqs, + ) - # Filter out existing predictions - filtered_manifest = filter_inputs_structure( - manifest=manifest, - outdir=out_dir, - override=override, - ) + # Load manifest + manifest = Manifest.load(out_dir / "processed" / "manifest.json") - # Load processed data - processed_dir = out_dir / "processed" - processed = BoltzProcessedInput( - manifest=filtered_manifest, - targets_dir=processed_dir / "structures", - msa_dir=processed_dir / "msa", - constraints_dir=( - (processed_dir / "constraints") - if (processed_dir / "constraints").exists() - else None - ), - template_dir=( - (processed_dir / "templates") - if (processed_dir / "templates").exists() - else None - ), - extra_mols_dir=( - (processed_dir / "mols") if (processed_dir / "mols").exists() else None - ), - ) + # Filter out existing predictions + filtered_manifest = filter_inputs_structure( + manifest=manifest, + outdir=out_dir, + override=override, + ) + + # Load processed data + processed_dir = out_dir / "processed" + processed = BoltzProcessedInput( + manifest=filtered_manifest, + targets_dir=processed_dir / "structures", + msa_dir=processed_dir / "msa", + constraints_dir=((processed_dir / "constraints") if (processed_dir / "constraints").exists() else None), + template_dir=((processed_dir / "templates") if (processed_dir / "templates").exists() else None), + extra_mols_dir=((processed_dir / "mols") if (processed_dir / "mols").exists() else None), + ) # Set up trainer strategy = "auto" diff --git a/src/boltz/model/layers/attentionv2.py b/src/boltz/model/layers/attentionv2.py index 6381f69fd..ad9a9c820 100644 --- a/src/boltz/model/layers/attentionv2.py +++ b/src/boltz/model/layers/attentionv2.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + from typing import Optional import torch @@ -98,13 +121,14 @@ def forward( with torch.autocast("cuda", enabled=False): # Compute attention weights - attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float()) - attn = attn / (self.head_dim**0.5) + bias.float() - attn = attn + (1 - mask[:, None, None].float()) * -self.inf + compute_dtype = torch.promote_types(q.dtype, torch.float32) + attn = torch.einsum("bihd,bjhd->bhij", q.to(compute_dtype), k.to(compute_dtype)) + attn = attn / (self.head_dim**0.5) + bias.to(compute_dtype) + attn = attn + (1 - mask[:, None, None].to(compute_dtype)) * -self.inf attn = attn.softmax(dim=-1) # Compute output - o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype) + o = torch.einsum("bhij,bjhd->bihd", attn, v.to(compute_dtype)).to(v.dtype) o = o.reshape(B, -1, self.c_s) o = self.proj_o(g * o) diff --git a/src/boltz/model/layers/confidence_utils.py b/src/boltz/model/layers/confidence_utils.py index de9eb50e5..01a9e9c5d 100644 --- a/src/boltz/model/layers/confidence_utils.py +++ b/src/boltz/model/layers/confidence_utils.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off import torch from torch import nn @@ -30,13 +52,34 @@ def compute_frame_pred( ).squeeze(-1) B, N, _ = pred_atom_coords.shape - pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3) + if B % multiplicity != 0: + raise ValueError( + f"pred_atom_coords batch dim ({B}) not divisible by multiplicity ({multiplicity})" + ) + if resolved_mask is not None and resolved_mask.shape != pred_atom_coords.shape[:2]: + raise ValueError( + f"resolved_mask shape {tuple(resolved_mask.shape)} must match " + f"pred_atom_coords[:2] {tuple(pred_atom_coords.shape[:2])}" + ) + B_batch = B // multiplicity + if frames_idx_true.shape[0] != B_batch: + raise ValueError( + f"frames_idx_true batch dim ({frames_idx_true.shape[0]}) must equal " + f"B // multiplicity ({B_batch})" + ) + pred_atom_coords = pred_atom_coords.reshape(B_batch, multiplicity, -1, 3) frames_idx_pred = ( frames_idx_true.clone() .repeat_interleave(multiplicity, 0) - .reshape(B // multiplicity, multiplicity, -1, 3) + .reshape(B_batch, multiplicity, -1, 3) ) + # resolved_mask arrives as (B*mult, N_atom). Reshape to (B_batch, mult, + # N_atom) so that each diffusion sample's per-sample resolved mask is + # preserved (symmetry_correction can produce different masks per sample). + if resolved_mask is not None: + resolved_mask = resolved_mask.reshape(B_batch, multiplicity, -1) + # Iterate through the batch and modify the frames for nonpolymers for i, pred_atom_coord in enumerate(pred_atom_coords): token_idx = 0 @@ -69,10 +112,15 @@ def compute_frame_pred( indices = torch.sort(dist_mat + resolved_pair, axis=2).indices else: if resolved_mask is None: - resolved_mask = feats["atom_resolved_mask"] + # atom_resolved_mask is (B_batch, N_atom); expand to + # (B_batch, mult, N_atom) so indexing is uniform. + resolved_mask = feats["atom_resolved_mask"][:, None, :].expand( + -1, multiplicity, -1 + ) + # resolved_mask[i]: (mult, N_atom) + rm_chain = resolved_mask[i][:, mask_chain_atom.bool()] # (mult, N_chain) resolved_pair = 1 - ( - resolved_mask[i][mask_chain_atom.bool()][None, :] - * resolved_mask[i][mask_chain_atom.bool()][:, None] + rm_chain[:, None, :] * rm_chain[:, :, None] ).to(torch.float32) resolved_pair[resolved_pair == 1] = torch.inf indices = torch.sort(dist_mat + resolved_pair, axis=2).indices diff --git a/src/boltz/model/layers/outer_product_mean.py b/src/boltz/model/layers/outer_product_mean.py index 9a4a607d0..c46f1fceb 100644 --- a/src/boltz/model/layers/outer_product_mean.py +++ b/src/boltz/model/layers/outer_product_mean.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + import torch from torch import Tensor, nn @@ -58,21 +80,15 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor: # Compute pairwise mask for i in range(0, mask.shape[1], 64): if i == 0: - num_mask = ( - mask[:, i : i + 64, None, :] * mask[:, i : i + 64, :, None] - ).sum(1) + num_mask = (mask[:, i : i + 64, None, :] * mask[:, i : i + 64, :, None]).sum(1) else: - num_mask += ( - mask[:, i : i + 64, None, :] * mask[:, i : i + 64, :, None] - ).sum(1) + num_mask += (mask[:, i : i + 64, None, :] * mask[:, i : i + 64, :, None]).sum(1) num_mask = num_mask.clamp(min=1) # Compute squentially in chunks for i in range(0, self.c_hidden, chunk_size): a_chunk = a[:, :, :, i : i + chunk_size] - sliced_weight_proj_o = self.proj_o.weight[ - :, i * self.c_hidden : (i + chunk_size) * self.c_hidden - ] + sliced_weight_proj_o = self.proj_o.weight[:, i * self.c_hidden : (i + chunk_size) * self.c_hidden] z = torch.einsum("bsic,bsjd->bijcd", a_chunk, b) z = z.reshape(*z.shape[:3], -1) @@ -89,7 +105,10 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor: else: mask = mask[:, :, None, :] * mask[:, :, :, None] num_mask = mask.sum(1).clamp(min=1) - z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float()) + # Cast to at least float32 for numerical stability, using + # promote_types to preserve higher-precision dtypes (e.g. float64). + compute_dtype = torch.promote_types(a.dtype, torch.float32) + z = torch.einsum("bsic,bsjd->bijcd", a.to(compute_dtype), b.to(compute_dtype)) z = z.reshape(*z.shape[:3], -1) z = z / num_mask diff --git a/src/boltz/model/layers/pairformer.py b/src/boltz/model/layers/pairformer.py index 7edadbfe9..020c39952 100644 --- a/src/boltz/model/layers/pairformer.py +++ b/src/boltz/model/layers/pairformer.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + from typing import Optional import torch @@ -47,19 +69,13 @@ def __init__( self.tri_mul_out = TriangleMultiplicationOutgoing(token_z) self.tri_mul_in = TriangleMultiplicationIncoming(token_z) - self.tri_att_start = TriangleAttentionStartingNode( - token_z, pairwise_head_width, pairwise_num_heads, inf=1e9 - ) - self.tri_att_end = TriangleAttentionEndingNode( - token_z, pairwise_head_width, pairwise_num_heads, inf=1e9 - ) + self.tri_att_start = TriangleAttentionStartingNode(token_z, pairwise_head_width, pairwise_num_heads, inf=1e9) + self.tri_att_end = TriangleAttentionEndingNode(token_z, pairwise_head_width, pairwise_num_heads, inf=1e9) self.transition_s = Transition(token_s, token_s * 4) self.transition_z = Transition(token_z, token_z * 4) - self.s_post_norm = ( - nn.LayerNorm(token_s) if self.post_layer_norm else nn.Identity() - ) + self.s_post_norm = nn.LayerNorm(token_s) if self.post_layer_norm else nn.Identity() def forward( self, @@ -74,14 +90,10 @@ def forward( ) -> tuple[Tensor, Tensor]: # Compute pairwise stack dropout = get_dropout_mask(self.dropout, z, self.training) - z = z + dropout * self.tri_mul_out( - z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels - ) + z = z + dropout * self.tri_mul_out(z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels) dropout = get_dropout_mask(self.dropout, z, self.training) - z = z + dropout * self.tri_mul_in( - z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels - ) + z = z + dropout * self.tri_mul_in(z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels) dropout = get_dropout_mask(self.dropout, z, self.training) z = z + dropout * self.tri_att_start( @@ -103,11 +115,12 @@ def forward( # Compute sequence stack with torch.autocast("cuda", enabled=False): - s_normed = self.pre_norm_s(s.float()) - s = s.float() + self.attention( - s=s_normed, z=z.float(), mask=mask.float(), k_in=s_normed + safe_dtype = torch.promote_types(s.dtype, torch.float32) + s_normed = self.pre_norm_s(s.to(dtype=safe_dtype)) + s = s.to(dtype=safe_dtype) + self.attention( + s=s_normed, z=z.to(dtype=safe_dtype), mask=mask.to(dtype=safe_dtype), k_in=s_normed ) - s = s + self.transition_s(s) + s = s + self.transition_s(s.to(dtype=safe_dtype)) s = self.s_post_norm(s) return s, z @@ -220,12 +233,8 @@ def __init__( self.tri_mul_out = TriangleMultiplicationOutgoing(token_z) self.tri_mul_in = TriangleMultiplicationIncoming(token_z) - self.tri_att_start = TriangleAttentionStartingNode( - token_z, pairwise_head_width, pairwise_num_heads, inf=1e9 - ) - self.tri_att_end = TriangleAttentionEndingNode( - token_z, pairwise_head_width, pairwise_num_heads, inf=1e9 - ) + self.tri_att_start = TriangleAttentionStartingNode(token_z, pairwise_head_width, pairwise_num_heads, inf=1e9) + self.tri_att_end = TriangleAttentionEndingNode(token_z, pairwise_head_width, pairwise_num_heads, inf=1e9) self.transition_z = Transition(token_z, token_z * 4) @@ -240,14 +249,10 @@ def forward( ) -> Tensor: # Compute pairwise stack dropout = get_dropout_mask(self.dropout, z, self.training) - z = z + dropout * self.tri_mul_out( - z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels - ) + z = z + dropout * self.tri_mul_out(z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels) dropout = get_dropout_mask(self.dropout, z, self.training) - z = z + dropout * self.tri_mul_in( - z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels - ) + z = z + dropout * self.tri_mul_in(z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels) dropout = get_dropout_mask(self.dropout, z, self.training) z = z + dropout * self.tri_att_start( diff --git a/src/boltz/model/layers/triangular_attention/primitives.py b/src/boltz/model/layers/triangular_attention/primitives.py index 26aabc7b8..7366af634 100644 --- a/src/boltz/model/layers/triangular_attention/primitives.py +++ b/src/boltz/model/layers/triangular_attention/primitives.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + # Copyright 2021 AlQuraishi Laboratory # Copyright 2021 DeepMind Technologies Limited # @@ -13,6 +35,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import math from typing import Callable, List, Optional, Tuple @@ -26,6 +49,10 @@ permute_final_dims, ) +trifast_is_installed = importlib.util.find_spec("trifast") is not None + +cueq_is_installed = importlib.util.find_spec("cuequivariance_torch.primitives.triangle") is not None + class Linear(nn.Linear): """ @@ -104,11 +131,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: d = input.dtype if self.precision is not None: with torch.autocast("cuda", enabled=False): - bias = ( - self.bias.to(dtype=self.precision) - if self.bias is not None - else None - ) + bias = self.bias.to(dtype=self.precision) if self.bias is not None else None return nn.functional.linear( input.to(dtype=self.precision), self.weight.to(dtype=self.precision), @@ -178,6 +201,7 @@ def _attention( key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor], + return_lse: bool = False, ) -> torch.Tensor: # [*, H, C_hidden, K] key = permute_final_dims(key, (1, 0)) @@ -188,17 +212,32 @@ def _attention( for b in biases: a += b - a = softmax_no_cast(a, -1) + if return_lse: + # [B, I, H, Q, 1] + amax = a.amax(dim=-1, keepdim=True) + # [B, I, H, Q, 1] + lse = torch.logsumexp(a - amax, dim=-1, keepdim=True) + # [B, I, H, Q, K] + a = torch.exp(a - amax - lse) + else: + amax = None + lse = None + # [B, I, H, Q, K] + a = softmax_no_cast(a, -1) # [*, H, Q, C_hidden] a = torch.matmul(a, value) + if return_lse: + return a, lse, amax + return a @torch.compiler.disable def kernel_triangular_attn(q, k, v, tri_bias, mask, scale): from cuequivariance_torch.primitives.triangle import triangle_attention + return triangle_attention(q, k, v, tri_bias, mask=mask, scale=scale) @@ -247,24 +286,14 @@ def __init__( # DISCREPANCY: c_hidden is not the per-head channel dimension, as # stated in the supplement, but the overall channel dimension. - self.linear_q = Linear( - self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot" - ) - self.linear_k = Linear( - self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot" - ) - self.linear_v = Linear( - self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot" - ) - self.linear_o = Linear( - self.c_hidden * self.no_heads, self.c_q, bias=False, init="final" - ) + self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, bias=False, init="final") self.linear_g = None if self.gating: - self.linear_g = Linear( - self.c_q, self.c_hidden * self.no_heads, bias=False, init="gating" - ) + self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="gating") self.sigmoid = nn.Sigmoid() diff --git a/src/boltz/model/layers/triangular_mult.py b/src/boltz/model/layers/triangular_mult.py index 5c52958af..6ed988cf3 100644 --- a/src/boltz/model/layers/triangular_mult.py +++ b/src/boltz/model/layers/triangular_mult.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + import torch from torch import Tensor, nn @@ -20,6 +42,7 @@ def kernel_triangular_mult( eps, ): from cuequivariance_torch.primitives.triangle import triangle_multiplicative_update + return triangle_multiplicative_update( x, direction=direction, @@ -113,7 +136,8 @@ def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor: x = x * mask.unsqueeze(-1) # Split input and cast to float - a, b = torch.chunk(x.float(), 2, dim=-1) + safe_dtype = torch.promote_types(x.dtype, torch.float32) + a, b = torch.chunk(x.to(dtype=safe_dtype), 2, dim=-1) # Triangular projection x = torch.einsum("bikd,bjkd->bijd", a, b) @@ -201,7 +225,8 @@ def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor: x = x * mask.unsqueeze(-1) # Split input and cast to float - a, b = torch.chunk(x.float(), 2, dim=-1) + safe_dtype = torch.promote_types(x.dtype, torch.float32) + a, b = torch.chunk(x.to(dtype=safe_dtype), 2, dim=-1) # Triangular projection x = torch.einsum("bkid,bkjd->bijd", a, b) diff --git a/src/boltz/model/loss/bfactor.py b/src/boltz/model/loss/bfactor.py index c650f332e..4762a368c 100644 --- a/src/boltz/model/loss/bfactor.py +++ b/src/boltz/model/loss/bfactor.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import torch from torch import Tensor @@ -22,14 +45,14 @@ def bfactor_loss_fn( """ with torch.autocast("cuda", enabled=False): - # Get predicted distograms - pred = output["pbfactor"].float() # (B, L, bins) + compute_dtype = torch.promote_types(output["pbfactor"].dtype, torch.float32) + pred = output["pbfactor"].to(compute_dtype) # (B, L, bins) bins = pred.shape[2] # num_bins token_to_rep_atom = feats["token_to_rep_atom"] # Compute target histogram bfactor_atom = feats["bfactor"].unsqueeze(-1) # (B, L) - bfactor_token = torch.bmm(token_to_rep_atom.float(), bfactor_atom) + bfactor_token = torch.bmm(token_to_rep_atom.to(compute_dtype), bfactor_atom) boundaries = torch.linspace(0, 100, bins - 1, device=bfactor_token.device) bfactor_token_bin = (bfactor_token > boundaries).sum(dim=-1).long() @@ -38,7 +61,7 @@ def bfactor_loss_fn( ) # Combine target mask and padding mask - token_mask = (bfactor_token > 1e-5).squeeze(-1).float() + token_mask = (bfactor_token > 1e-5).squeeze(-1).to(compute_dtype) # Compute the bfactor loss errors = -1 * torch.sum( diff --git a/src/boltz/model/loss/confidence.py b/src/boltz/model/loss/confidence.py index 7080c9d68..0db35dd10 100644 --- a/src/boltz/model/loss/confidence.py +++ b/src/boltz/model/loss/confidence.py @@ -113,18 +113,13 @@ def resolved_loss( # extract necessary features token_to_rep_atom = feats["token_to_rep_atom"] token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0).float() - ref_mask = torch.bmm( - token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float() - ).squeeze(-1) + ref_mask = torch.bmm(token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()).squeeze(-1) pad_mask = feats["token_pad_mask"] pad_mask = pad_mask.repeat_interleave(multiplicity, 0).float() # compute loss log_softmax_resolved = torch.nn.functional.log_softmax(pred_resolved, dim=-1) - errors = ( - -ref_mask * log_softmax_resolved[:, :, 0] - - (1 - ref_mask) * log_softmax_resolved[:, :, 1] - ) + errors = -ref_mask * log_softmax_resolved[:, :, 0] - (1 - ref_mask) * log_softmax_resolved[:, :, 1] loss = torch.sum(errors * pad_mask, dim=-1) / (1e-7 + torch.sum(pad_mask, dim=-1)) # Average over the batch dimension @@ -200,24 +195,17 @@ def plddt_loss( # compute mask pair_mask = atom_mask.unsqueeze(-1) * atom_mask.unsqueeze(-2) - pair_mask = ( - pair_mask - * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :] - ) + pair_mask = pair_mask * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :] pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, R_set_to_rep_atom) pair_mask = torch.bmm(token_to_rep_atom, pair_mask) atom_mask = torch.bmm(token_to_rep_atom, atom_mask.unsqueeze(-1).float()) is_nucleotide_R_element = torch.bmm( R_set_to_rep_atom, torch.bmm(atom_to_token, is_nucleotide_token.unsqueeze(-1)) ).squeeze(-1) - cutoff = 15 + 15 * is_nucleotide_R_element.reshape(B, 1, -1).repeat( - 1, true_d.shape[1], 1 - ) + cutoff = 15 + 15 * is_nucleotide_R_element.reshape(B, 1, -1).repeat(1, true_d.shape[1], 1) # compute lddt - target_lddt, mask_no_match = lddt_dist( - pred_d, true_d, pair_mask, cutoff, per_atom=True - ) + target_lddt, mask_no_match = lddt_dist(pred_d, true_d, pair_mask, cutoff, per_atom=True) # compute loss num_bins = pred_lddt.shape[-1] @@ -229,9 +217,7 @@ def plddt_loss( dim=-1, ) atom_mask = atom_mask.squeeze(-1) - loss = torch.sum(errors * atom_mask * mask_no_match, dim=-1) / ( - 1e-7 + torch.sum(atom_mask * mask_no_match, dim=-1) - ) + loss = torch.sum(errors * atom_mask * mask_no_match, dim=-1) / (1e-7 + torch.sum(atom_mask * mask_no_match, dim=-1)) # Average over the batch dimension loss = torch.mean(loss) @@ -275,9 +261,7 @@ def pde_loss( # extract necessary features token_to_rep_atom = feats["token_to_rep_atom"] token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0).float() - token_mask = torch.bmm( - token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float() - ).squeeze(-1) + token_mask = torch.bmm(token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()).squeeze(-1) mask = token_mask.unsqueeze(-1) * token_mask.unsqueeze(-2) # compute true pde @@ -297,9 +281,7 @@ def pde_loss( pde_one_hot * torch.nn.functional.log_softmax(pred_pde, dim=-1), dim=-1, ) - loss = torch.sum(errors * mask, dim=(-2, -1)) / ( - 1e-7 + torch.sum(mask, dim=(-2, -1)) - ) + loss = torch.sum(errors * mask, dim=(-2, -1)) / (1e-7 + torch.sum(mask, dim=(-2, -1))) # Average over the batch dimension loss = torch.mean(loss) @@ -381,15 +363,11 @@ def pae_loss( pred_atom_coords, frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c ) - target_pae = torch.sqrt( - ((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8 - ) + target_pae = torch.sqrt(((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8) # Compute mask for the pae loss b_true_resolved_mask = true_coords_resolved_mask[ - torch.arange(B // multiplicity)[:, None, None].to( - pred_coords_transformed.device - ), + torch.arange(0, B, multiplicity)[:, None, None].to(pred_coords_transformed.device), frame_true_atom_b, ] @@ -408,13 +386,10 @@ def pae_loss( bin_index = torch.clamp(bin_index, max=(num_bins - 1)) pae_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins) errors = -1 * torch.sum( - pae_one_hot - * torch.nn.functional.log_softmax(pred_pae.reshape(pae_one_hot.shape), dim=-1), + pae_one_hot * torch.nn.functional.log_softmax(pred_pae.reshape(pae_one_hot.shape), dim=-1), dim=-1, ) - loss = torch.sum(errors * pair_mask, dim=(-2, -1)) / ( - 1e-7 + torch.sum(pair_mask, dim=(-2, -1)) - ) + loss = torch.sum(errors * pair_mask, dim=(-2, -1)) / (1e-7 + torch.sum(pair_mask, dim=(-2, -1))) # Average over the batch dimension loss = torch.mean(loss) @@ -428,10 +403,7 @@ def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False): dist_l1 = torch.abs(dmat_true - dmat_predicted) score = 0.25 * ( - (dist_l1 < 0.5).float() - + (dist_l1 < 1.0).float() - + (dist_l1 < 2.0).float() - + (dist_l1 < 4.0).float() + (dist_l1 < 0.5).float() + (dist_l1 < 1.0).float() + (dist_l1 < 2.0).float() + (dist_l1 < 4.0).float() ) # Normalize over the appropriate axes. @@ -501,17 +473,18 @@ def compute_frame_pred( ): # extract necessary features asym_id_token = feats["asym_id"] - asym_id_atom = torch.bmm( - feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float() - ).squeeze(-1) + asym_id_atom = torch.bmm(feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()).squeeze(-1) B, N, _ = pred_atom_coords.shape pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3) frames_idx_pred = ( - frames_idx_true.clone() - .repeat_interleave(multiplicity, 0) - .reshape(B // multiplicity, multiplicity, -1, 3) + frames_idx_true.clone().repeat_interleave(multiplicity, 0).reshape(B // multiplicity, multiplicity, -1, 3) ) + # resolved_mask is (B*mult, N_atom); reduce to (B_batch, N_atom) so indexing + # by batch element i is correct (frames are shared across multiplicity copies). + if resolved_mask is not None: + resolved_mask = resolved_mask[::multiplicity] + # Iterate through the batch and update the frames for nonpolymers for i, pred_atom_coord in enumerate(pred_atom_coords): token_idx = 0 @@ -521,10 +494,7 @@ def compute_frame_pred( mask_chain_atom = (asym_id_atom[i] == id) * feats["atom_pad_mask"][i] num_tokens = int(mask_chain_token.sum().item()) num_atoms = int(mask_chain_atom.sum().item()) - if ( - feats["mol_type"][i, token_idx] != const.chain_type_ids["NONPOLYMER"] - or num_atoms < 3 - ): + if feats["mol_type"][i, token_idx] != const.chain_type_ids["NONPOLYMER"] or num_atoms < 3: token_idx += num_tokens atom_idx += num_atoms continue @@ -543,7 +513,7 @@ def compute_frame_pred( * feats["atom_pad_mask"][i][mask_chain_atom.bool()][:, None] ).to(torch.float32) resolved_pair[resolved_pair == 1] = torch.inf - indices = torch.sort(dist_mat + resolved_pair, axis=2).indices + indices = torch.sort(dist_mat + resolved_pair, axis=2, stable=True).indices else: if resolved_mask is None: resolved_mask = feats["atom_resolved_mask"] @@ -552,7 +522,7 @@ def compute_frame_pred( * resolved_mask[i][mask_chain_atom.bool()][:, None] ).to(torch.float32) resolved_pair[resolved_pair == 1] = torch.inf - indices = torch.sort(dist_mat + resolved_pair, axis=2).indices + indices = torch.sort(dist_mat + resolved_pair, axis=2, stable=True).indices # Compute the frames frames = ( @@ -572,12 +542,8 @@ def compute_frame_pred( # Expand the frames with the multiplicity frames_expanded = pred_atom_coords[ - torch.arange(0, B // multiplicity, 1)[:, None, None, None].to( - frames_idx_pred.device - ), - torch.arange(0, multiplicity, 1)[None, :, None, None].to( - frames_idx_pred.device - ), + torch.arange(0, B // multiplicity, 1)[:, None, None, None].to(frames_idx_pred.device), + torch.arange(0, multiplicity, 1)[None, :, None, None].to(frames_idx_pred.device), frames_idx_pred, ].reshape(-1, 3, 3) diff --git a/src/boltz/model/loss/confidencev2.py b/src/boltz/model/loss/confidencev2.py index 9f641ba48..bbf6ce5ba 100644 --- a/src/boltz/model/loss/confidencev2.py +++ b/src/boltz/model/loss/confidencev2.py @@ -1,8 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off import torch from torch import nn from boltz.data import const -from boltz.model.layers.confidence_utils import compute_frame_pred, tm_function +from boltz.model.layers.confidence_utils import compute_frame_pred def confidence_loss( @@ -96,24 +118,28 @@ def resolved_loss( mask_loss=None, ): with torch.autocast("cuda", enabled=False): + # Bugfix: use promote_types instead of hardcoded .float() so that + # float64 test inputs are not silently truncated to float32, which + # causes dtype mismatches between serial and distributed code paths. + compute_dtype = torch.promote_types(pred_resolved.dtype, torch.float32) if token_level_confidence: token_to_rep_atom = feats["token_to_rep_atom"] token_to_rep_atom = token_to_rep_atom.repeat_interleave( multiplicity, 0 - ).float() + ).to(dtype=compute_dtype) ref_mask = torch.bmm( - token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float() + token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).to(dtype=compute_dtype) ).squeeze(-1) pad_mask = feats["token_pad_mask"] - pad_mask = pad_mask.repeat_interleave(multiplicity, 0).float() + pad_mask = pad_mask.repeat_interleave(multiplicity, 0).to(dtype=compute_dtype) else: - ref_mask = true_coords_resolved_mask.float() + ref_mask = true_coords_resolved_mask.to(dtype=compute_dtype) pad_mask = feats["atom_pad_mask"] - pad_mask = pad_mask.repeat_interleave(multiplicity, 0).float() + pad_mask = pad_mask.repeat_interleave(multiplicity, 0).to(dtype=compute_dtype) # compute loss log_softmax_resolved = torch.nn.functional.log_softmax( - pred_resolved.float(), dim=-1 + pred_resolved.to(dtype=compute_dtype), dim=-1 ) errors = ( -ref_mask * log_softmax_resolved[:, :, 0] @@ -128,7 +154,7 @@ def resolved_loss( mask_loss = ( mask_loss.repeat_interleave(multiplicity, 0) .reshape(-1, multiplicity) - .float() + .to(dtype=compute_dtype) ) loss = torch.sum(loss.reshape(-1, multiplicity) * mask_loss) / ( torch.sum(mask_loss) + 1e-7 @@ -410,11 +436,22 @@ def get_target_pae( ((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8 ) - # Compute mask for the pae loss - b_true_resolved_mask = true_coords_resolved_mask[ - torch.arange(B // multiplicity)[:, None, None].to( - pred_coords_transformed.device - ), + # Reshape to (B_batch, mult, N_atom) so each diffusion sample uses + # its own resolved mask (symmetry_correction can differ per sample). + B = true_coords_resolved_mask.shape[0] + if B % multiplicity != 0: + raise ValueError( + f"true_coords_resolved_mask batch dim ({B}) not divisible by multiplicity ({multiplicity})" + ) + if true_coords_resolved_mask.ndim != 2: + raise ValueError( + f"true_coords_resolved_mask must be 2D, got ndim={true_coords_resolved_mask.ndim}" + ) + B_batch = B // multiplicity + resolved_mask_3d = true_coords_resolved_mask.reshape(B_batch, multiplicity, -1) + b_true_resolved_mask = resolved_mask_3d[ + torch.arange(B_batch)[:, None, None].to(pred_coords_transformed.device), + torch.arange(multiplicity)[None, :, None].to(pred_coords_transformed.device), frame_true_atom_b, ] diff --git a/src/boltz/model/loss/diffusion.py b/src/boltz/model/loss/diffusion.py index 3433e4299..1924199e0 100644 --- a/src/boltz/model/loss/diffusion.py +++ b/src/boltz/model/loss/diffusion.py @@ -1,8 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang -from einops import einsum import torch import torch.nn.functional as F +from einops import einsum def weighted_rigid_align( @@ -57,11 +80,12 @@ def weighted_rigid_align( weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j" ) - # Compute the SVD of the covariance matrix, required float32 for svd and determinant + # SVD requires at least float32; preserve float64 when present. original_dtype = cov_matrix.dtype - cov_matrix_32 = cov_matrix.to(dtype=torch.float32) + svd_dtype = torch.promote_types(cov_matrix.dtype, torch.float32) + cov_matrix_svd = cov_matrix.to(dtype=svd_dtype) U, S, V = torch.linalg.svd( - cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None + cov_matrix_svd, driver="gesvd" if cov_matrix_svd.is_cuda else None ) V = V.mH @@ -74,10 +98,10 @@ def weighted_rigid_align( ) # Compute the rotation matrix - rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32) + rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=svd_dtype) # Ensure proper rotation matrix with determinant 1 - F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[ + F = torch.eye(dim, dtype=svd_dtype, device=cov_matrix.device)[ None ].repeat(batch_size, 1, 1) F[:, -1, -1] = torch.det(rot_matrix) @@ -148,22 +172,17 @@ def smooth_lddt_loss( dist_diff = torch.abs(true_dists - pred_dists) # Compute epsilon values + # Fixed the bug in v1, as it should be ".view(B // multiplicity, multiplicity, N, N).mean(dim=1).repeat_interleave(multiplicity, 0)" + # instead of ".view(multiplicity, B // multiplicity, N, N).mean(dim=0).repeat_interleave(multiplicity, 0)" + # Here we use the same but simplified version as in diffusionv2. eps = ( - ( - ( - F.sigmoid(0.5 - dist_diff) - + F.sigmoid(1.0 - dist_diff) - + F.sigmoid(2.0 - dist_diff) - + F.sigmoid(4.0 - dist_diff) - ) - / 4.0 - ) - .view(multiplicity, B // multiplicity, N, N) - .mean(dim=0) - ) + F.sigmoid(0.5 - dist_diff) + + F.sigmoid(1.0 - dist_diff) + + F.sigmoid(2.0 - dist_diff) + + F.sigmoid(4.0 - dist_diff) + ) / 4.0 # Calculate masked averaging - eps = eps.repeat_interleave(multiplicity, 0) num = (eps * mask).sum(dim=(-1, -2)) den = mask.sum(dim=(-1, -2)).clamp(min=1) lddt = num / den diff --git a/src/boltz/model/loss/diffusionv2.py b/src/boltz/model/loss/diffusionv2.py index 457ab838c..cd486fd62 100644 --- a/src/boltz/model/loss/diffusionv2.py +++ b/src/boltz/model/loss/diffusionv2.py @@ -1,9 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang -import einx +import einx # noqa: F401 import torch import torch.nn.functional as F -from einops import einsum, rearrange +from einops import einsum, rearrange # noqa: F401 def weighted_rigid_align( @@ -44,12 +67,13 @@ def weighted_rigid_align( "... n i, ... n j -> ... i j", ) - # Compute the SVD of the covariance matrix, required float32 for svd and determinant + # SVD requires at least float32; preserve float64 when present. original_dtype = cov_matrix.dtype - cov_matrix_32 = cov_matrix.to(dtype=torch.float32) + svd_dtype = torch.promote_types(cov_matrix.dtype, torch.float32) + cov_matrix_svd = cov_matrix.to(dtype=svd_dtype) U, S, V = torch.linalg.svd( - cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None + cov_matrix_svd, driver="gesvd" if cov_matrix_svd.is_cuda else None ) V = V.mH @@ -63,11 +87,11 @@ def weighted_rigid_align( # Compute the rotation matrix rot_matrix = torch.einsum("... i j, ... k j -> ... i k", U, V).to( - dtype=torch.float32 + dtype=svd_dtype ) # Ensure proper rotation matrix with determinant 1 - F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[ + F = torch.eye(dim, dtype=svd_dtype, device=cov_matrix.device)[ None ].repeat(*batch_size, 1, 1) F[..., -1, -1] = torch.det(rot_matrix) @@ -110,8 +134,9 @@ def smooth_lddt_loss( -1, is_nucleotide_i.shape[-1] ) - mask = is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float() - mask += (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float() + compute_dtype = torch.promote_types(pred_coords.dtype, torch.float32) + mask = is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).to(compute_dtype) + mask += (1 - is_nucleotide_pair) * (true_dists < other_cutoff).to(compute_dtype) mask *= 1 - torch.eye(pred_coords.shape[1], device=pred_coords.device) mask *= coords_mask_i.unsqueeze(-1) mask *= coords_mask_i.unsqueeze(-2) diff --git a/src/boltz/model/loss/distogramv2.py b/src/boltz/model/loss/distogramv2.py index 1a72a28ac..ac309c13e 100644 --- a/src/boltz/model/loss/distogramv2.py +++ b/src/boltz/model/loss/distogramv2.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import torch from torch import Tensor @@ -26,7 +49,8 @@ def distogram_loss( """ with torch.autocast("cuda", enabled=False): # Get predicted distograms - pred = output["pdistogram"].float() # (B, L, L, num_distograms, disto_bins) + compute_dtype = torch.promote_types(output["pdistogram"].dtype, torch.float32) + pred = output["pdistogram"].to(compute_dtype) # (B, L, L, num_distograms, disto_bins) D = pred.shape[3] # num_distograms # noqa: N806 assert len(pred.shape) == 5 # noqa: PLR2004 diff --git a/src/boltz/model/loss/inference.py b/src/boltz/model/loss/inference.py index 252c7587a..e97669c80 100644 --- a/src/boltz/model/loss/inference.py +++ b/src/boltz/model/loss/inference.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off import torch from boltz.data import const @@ -112,7 +134,7 @@ def compute_pb_geometry_metrics( .squeeze(-1) .long() ) - is_ligand_mask = ( + _is_ligand_mask = ( # noqa: F841 torch.bmm( feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float() ) @@ -135,7 +157,7 @@ def compute_pb_geometry_metrics( multiplicity, dtype=torch.float32, device=pred_atom_coords.device ) - for index_batch in range(len(feats["ligand_edge_index"])): + for index_batch in range(len(feats.get("ligand_edge_index", []))): if feats["ligand_edge_index"][index_batch].shape[1] == 0: continue dists = torch.linalg.norm( @@ -275,7 +297,7 @@ def compute_torsion_angles(coords, torsion_index): n_ijk = torch.cross(r_ij, r_kj, dim=-1) n_jkl = torch.cross(r_kj, r_kl, dim=-1) - r_kj_norm = torch.linalg.norm(r_kj, dim=-1) + _r_kj_norm = torch.linalg.norm(r_kj, dim=-1) # noqa: F841 n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1) n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1) @@ -308,7 +330,7 @@ def compute_stereo_metrics(pred_atom_coords, feats): multiplicity, dtype=torch.float32, device=pred_atom_coords.device ) - for index_batch in range(len(feats["ligand_edge_index"])): + for index_batch in range(len(feats.get("ligand_edge_index", []))): if feats["ligand_chiral_atom_index"][index_batch].shape[1] > 0: pred_chiral_torsion_angles = compute_torsion_angles( pred_atom_coords, @@ -378,7 +400,7 @@ def compute_pb_flatness_metrics(pred_atom_coords, feats, buffer=0.25): multiplicity, dtype=torch.float32, device=pred_atom_coords.device ) - for index_batch in range(len(feats["ligand_aromatic_5_ring_index"])): + for index_batch in range(len(feats.get("ligand_aromatic_5_ring_index", []))): ring_5_index = feats["ligand_aromatic_5_ring_index"][index_batch].T ring_6_index = feats["ligand_aromatic_6_ring_index"][index_batch].T double_bond_index = feats["ligand_planar_double_bond_index"][index_batch].T diff --git a/src/boltz/model/loss/validation.py b/src/boltz/model/loss/validation.py index 00d1aa7c3..de1afec05 100644 --- a/src/boltz/model/loss/validation.py +++ b/src/boltz/model/loss/validation.py @@ -1,8 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off import torch from boltz.data import const -from boltz.model.loss.confidence import ( - compute_frame_pred, +from boltz.model.layers.confidence_utils import compute_frame_pred +from boltz.model.loss.confidencev2 import ( express_coordinate_in_frame, lddt_dist, ) @@ -749,11 +771,22 @@ def compute_pae_mae( + 0.25 ) - # Compute mask for the pae loss - b_true_resolved_mask = true_coords_resolved_mask[ - torch.arange(B // multiplicity)[:, None, None].to( - pred_coords_transformed.device - ), + # Reshape to (B_batch, mult, N_atom) so each diffusion sample uses + # its own resolved mask (symmetry_correction can differ per sample). + if true_coords_resolved_mask.shape[0] != B: + raise ValueError( + f"true_coords_resolved_mask batch dim ({true_coords_resolved_mask.shape[0]}) " + f"!= expected ({B})" + ) + if true_coords_resolved_mask.ndim != 2: + raise ValueError( + f"true_coords_resolved_mask must be 2D, got ndim={true_coords_resolved_mask.ndim}" + ) + B_batch = B // multiplicity + resolved_mask_3d = true_coords_resolved_mask.reshape(B_batch, multiplicity, -1) + b_true_resolved_mask = resolved_mask_3d[ + torch.arange(B_batch)[:, None, None].to(pred_coords_transformed.device), + torch.arange(multiplicity)[None, :, None].to(pred_coords_transformed.device), frame_true_atom_b, ] diff --git a/src/boltz/model/models/boltz2.py b/src/boltz/model/models/boltz2.py index 9d36bb3a3..ada2cc8b3 100644 --- a/src/boltz/model/models/boltz2.py +++ b/src/boltz/model/models/boltz2.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off import gc from typing import Any, Optional @@ -308,7 +330,7 @@ def __init__( cyclic_pos_enc=cyclic_pos_enc, conditioning_cutoff_min=conditioning_cutoff_min, conditioning_cutoff_max=conditioning_cutoff_max, - **confidence_model_args, + **(confidence_model_args or {}), ) if compile_confidence: self.confidence_module = torch.compile( @@ -420,7 +442,7 @@ def forward( ) relative_position_encoding = self.rel_pos(feats) z_init = z_init + relative_position_encoding - z_init = z_init + self.token_bonds(feats["token_bonds"].float()) + z_init = z_init + self.token_bonds(feats["token_bonds"].to(dtype=z_init.dtype)) if self.bond_type_feature: z_init = z_init + self.token_bonds_type(feats["type_bonds"].long()) z_init = z_init + self.contact_conditioning(feats) @@ -430,7 +452,10 @@ def forward( z = torch.zeros_like(z_init) # Compute pairwise mask - mask = feats["token_pad_mask"].float() + # promote_types preserves float64 for testing while promoting lower + # dtypes to at least float32 (no-op at production float32). + compute_dtype = torch.promote_types(s_init.dtype, torch.float32) + mask = feats["token_pad_mask"].to(dtype=compute_dtype) pair_mask = mask[:, :, None] * mask[:, None, :] if self.run_trunk_and_structure: for i in range(recycling_steps + 1): @@ -525,11 +550,11 @@ def forward( if (not self.training) or self.confidence_prediction: with torch.autocast("cuda", enabled=False): struct_out = self.structure_module.sample( - s_trunk=s.float(), - s_inputs=s_inputs.float(), + s_trunk=s.to(compute_dtype), + s_inputs=s_inputs.to(compute_dtype), feats=feats, num_sampling_steps=num_sampling_steps, - atom_mask=feats["atom_pad_mask"].float(), + atom_mask=feats["atom_pad_mask"].to(compute_dtype), multiplicity=diffusion_samples, max_parallel_samples=max_parallel_samples, steering_args=self.steering_args, @@ -564,8 +589,8 @@ def forward( with torch.autocast("cuda", enabled=False): struct_out = self.structure_module( - s_trunk=s.float(), - s_inputs=s_inputs.float(), + s_trunk=s.to(compute_dtype), + s_inputs=s_inputs.to(compute_dtype), feats=feats, multiplicity=multiplicity_diffusion_train, diffusion_conditioning=diffusion_conditioning, @@ -923,7 +948,6 @@ def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: return loss def training_log(self): - self.log("train/grad_norm", self.gradient_norm(self), prog_bar=False) self.log("train/param_norm", self.parameter_norm(self), prog_bar=False) lr = self.trainer.optimizers[0].param_groups[0]["lr"] @@ -949,15 +973,35 @@ def training_log(self): if self.confidence_prediction: self.log( - "train/grad_norm_confidence_module", - self.gradient_norm(self.confidence_module), + "train/param_norm_confidence_module", + self.parameter_norm(self.confidence_module), prog_bar=False, ) + + def on_after_backward(self): + if not (self.global_step % self.log_loss_every_steps): + self.log("train/grad_norm", self.gradient_norm(self), prog_bar=False) self.log( - "train/param_norm_confidence_module", - self.parameter_norm(self.confidence_module), + "train/grad_norm_msa_module", + self.gradient_norm(self.msa_module), + prog_bar=False, + ) + self.log( + "train/grad_norm_pairformer_module", + self.gradient_norm(self.pairformer_module), prog_bar=False, ) + self.log( + "train/grad_norm_structure_module", + self.gradient_norm(self.structure_module), + prog_bar=False, + ) + if self.confidence_prediction: + self.log( + "train/grad_norm_confidence_module", + self.gradient_norm(self.confidence_module), + prog_bar=False, + ) def on_train_epoch_end(self): if self.confidence_prediction: diff --git a/src/boltz/model/modules/confidencev2.py b/src/boltz/model/modules/confidencev2.py index 3dedc5f56..98b41c032 100644 --- a/src/boltz/model/modules/confidencev2.py +++ b/src/boltz/model/modules/confidencev2.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off +# ruff: noqa import torch from torch import nn from torch.nn.functional import pad @@ -15,6 +38,10 @@ ) from boltz.model.modules.utils import LinearNoBias +IPLDDT_LIGAND_WEIGHT = 20 +IPLDDT_INTERFACE_WEIGHT = 10 +IPLDDT_NON_INTERFACE_WEIGHT = 1 + class ConfidenceModule(nn.Module): """Algorithm 31""" @@ -165,7 +192,9 @@ def forward( if self.add_z_input_to_z: relative_position_encoding = self.rel_pos(feats) z = z + relative_position_encoding - z = z + self.token_bonds(feats["token_bonds"].float()) + z = z + self.token_bonds( + feats["token_bonds"].to(dtype=torch.promote_types(feats["token_bonds"].dtype, torch.float32)) + ) if self.bond_type_feature: z = z + self.token_bonds_type(feats["type_bonds"].long()) z = z + self.contact_conditioning(feats) @@ -193,7 +222,11 @@ def forward( x_pred = x_pred.reshape(B * mult, N, -1) else: BM, N, _ = x_pred.shape - x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred) + compute_dtype = torch.promote_types(token_to_rep_atom.dtype, torch.float32) + x_pred_repr = torch.bmm( + token_to_rep_atom.to(dtype=compute_dtype), + x_pred.to(dtype=compute_dtype), + ) d = torch.cdist(x_pred_repr, x_pred_repr) distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long() distogram = self.dist_bin_pairwise_embed(distogram) @@ -283,7 +316,9 @@ def forward( ): if self.use_separate_heads: asym_id_token = feats["asym_id"] - is_same_chain = asym_id_token.unsqueeze(-1) == asym_id_token.unsqueeze(-2) + is_same_chain = asym_id_token.unsqueeze(-1) == asym_id_token.unsqueeze(-2) # (B, N, N) + if multiplicity > 1: + is_same_chain = is_same_chain.repeat_interleave(multiplicity, dim=0) is_different_chain = ~is_same_chain if self.use_separate_heads: @@ -314,9 +349,9 @@ def forward( resolved_logits = self.to_resolved_logits(s) plddt_logits = self.to_plddt_logits(s) - ligand_weight = 20 - non_interface_weight = 1 - interface_weight = 10 + ligand_weight = IPLDDT_LIGAND_WEIGHT + non_interface_weight = IPLDDT_NON_INTERFACE_WEIGHT + interface_weight = IPLDDT_INTERFACE_WEIGHT token_type = feats["mol_type"] token_type = token_type.repeat_interleave(multiplicity, 0) @@ -400,9 +435,10 @@ def forward( complex_plddt = (plddt * atom_pad_mask).sum(dim=-1) / atom_pad_mask.sum( dim=-1 ) - token_type = feats["mol_type"].float() - atom_to_token = feats["atom_to_token"].float() - chain_id_token = feats["asym_id"].float() + _promote = lambda t: t.to(dtype=torch.promote_types(t.dtype, torch.float32)) # noqa: E731 + token_type = _promote(feats["mol_type"]) + atom_to_token = _promote(feats["atom_to_token"]) + chain_id_token = _promote(feats["asym_id"]) atom_type = torch.bmm(atom_to_token, token_type.unsqueeze(-1)).squeeze(-1) is_ligand_atom = (atom_type == const.chain_type_ids["NONPOLYMER"]).float() d_atom = torch.cdist(x_pred, x_pred) diff --git a/src/boltz/model/modules/diffusion.py b/src/boltz/model/modules/diffusion.py index d209f2884..389cc4d26 100644 --- a/src/boltz/model/modules/diffusion.py +++ b/src/boltz/model/modules/diffusion.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang from __future__ import annotations @@ -16,7 +39,6 @@ smooth_lddt_loss, weighted_rigid_align, ) -from boltz.model.modules.utils import center_random_augmentation from boltz.model.modules.encoders import ( AtomAttentionDecoder, AtomAttentionEncoder, @@ -30,8 +52,8 @@ ) from boltz.model.modules.utils import ( LinearNoBias, - compute_random_augmentation, center_random_augmentation, + compute_random_augmentation, default, log, ) @@ -183,6 +205,11 @@ def forward( s_inputs=s_inputs.repeat_interleave(multiplicity, 0), ) + # Promote to at least float32 for numerical stability, but preserve + # higher precision (e.g. float64) if available. Replaces hardcoded .float() + # to support float64 DTensor testing. + compute_dtype = torch.promote_types(r_noisy.dtype, torch.float32) + if model_cache is None or len(model_cache) == 0: z = self.pairwise_conditioner( z_trunk=z_trunk, token_rel_pos_feats=relative_position_encoding @@ -206,7 +233,7 @@ def forward( mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0) a = self.token_transformer( a, - mask=mask.float(), + mask=mask.to(compute_dtype), s=s, z=z, # note z is not expanded with multiplicity until after bias is computed multiplicity=multiplicity, @@ -411,7 +438,10 @@ def preconditioned_network_forward( batch, device = noised_atom_coords.shape[0], noised_atom_coords.device if isinstance(sigma, float): - sigma = torch.full((batch,), sigma, device=device) + # Preserve dtype of noised_atom_coords (e.g. float64 for testing). + # Without this, torch.full defaults to float32, causing dtype mismatch + # in downstream FourierEmbedding when the model is in float64. + sigma = torch.full((batch,), sigma, device=device, dtype=noised_atom_coords.dtype) padded_sigma = rearrange(sigma, "b -> b 1 1") @@ -473,12 +503,19 @@ def sample( device=self.device, ) + # Default to processing all multiplicity samples at once (no chunking). + # The original code did not handle max_parallel_samples=None, causing TypeError. + if max_parallel_samples is None: + max_parallel_samples = multiplicity + num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps) atom_mask = atom_mask.repeat_interleave(multiplicity, 0) shape = (*atom_mask.shape, 3) + # atom_mask is already (B*M, N_atoms) after repeat_interleave above, + # so atom_mask.shape[0] == B*M — use that for token_a's batch dimension. token_repr_shape = ( - multiplicity, + atom_mask.shape[0], network_condition_kwargs["feats"]["token_index"].shape[1], 2 * self.token_s, ) @@ -533,25 +570,43 @@ def sample( atom_coords_denoised = torch.zeros_like(atom_coords_noisy) token_a = torch.zeros(token_repr_shape).to(atom_coords_noisy) + B_local = atom_coords_noisy.shape[0] // multiplicity sample_ids = torch.arange(multiplicity).to(atom_coords_noisy.device) + # ceiling division: ceil(multiplicity / max_parallel_samples). + # The original formula (multiplicity % max_parallel_samples + 1) was incorrect + # when multiplicity is an exact non-trivial multiple of max_parallel_samples + # (e.g., multiplicity=4, max_parallel_samples=2 gave 1 chunk instead of 2). sample_ids_chunks = sample_ids.chunk( - multiplicity % max_parallel_samples + 1 + (multiplicity + max_parallel_samples - 1) // max_parallel_samples ) for sample_ids_chunk in sample_ids_chunks: + chunk_M = sample_ids_chunk.numel() + # atom_coords_noisy is (B*M, N, 3). We must unflatten to (B, M, N, 3) + # and index the M axis — not dim 0 directly — because + # preconditioned_network_forward passes **network_condition_kwargs + # (feats, s_inputs, s_trunk, z_trunk, etc.) which have the full B + # batch dimension. A naïve dim-0 index would select B*chunk_M rows + # that mix batch and multiplicity positions, causing a shape mismatch + # with the un-indexed conditioning tensors inside the score model. + noisy_chunk = atom_coords_noisy.unflatten(0, (B_local, multiplicity))[:, sample_ids_chunk].flatten(0, 1) atom_coords_denoised_chunk, token_a_chunk = ( self.preconditioned_network_forward( - atom_coords_noisy[sample_ids_chunk], + noisy_chunk, t_hat, training=False, network_condition_kwargs=dict( - multiplicity=sample_ids_chunk.numel(), + multiplicity=chunk_M, model_cache=model_cache, **network_condition_kwargs, ), ) ) - atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk - token_a[sample_ids_chunk] = token_a_chunk + atom_coords_denoised.unflatten(0, (B_local, multiplicity))[:, sample_ids_chunk] = ( + atom_coords_denoised_chunk.unflatten(0, (B_local, chunk_M)) + ) + token_a.unflatten(0, (B_local, multiplicity))[:, sample_ids_chunk] = ( + token_a_chunk.unflatten(0, (B_local, chunk_M)) + ) if ( steering_args is not None @@ -705,7 +760,7 @@ def sample( atom_coords = atom_coords_next - return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr) + return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr) # noqa: C408 def loss_weight(self, sigma): return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2) @@ -759,17 +814,17 @@ def forward( noised_atom_coords, sigmas, training=True, - network_condition_kwargs=dict( + network_condition_kwargs=dict( # noqa: C408 s_inputs=s_inputs, s_trunk=s_trunk, z_trunk=z_trunk, relative_position_encoding=relative_position_encoding, feats=feats, multiplicity=multiplicity, - ), + ) ) - return dict( + return dict( # noqa: C408 noised_atom_coords=noised_atom_coords, denoised_atom_coords=denoised_atom_coords, sigmas=sigmas, @@ -855,9 +910,9 @@ def compute_loss( total_loss = total_loss + lddt_loss - loss_breakdown = dict( + loss_breakdown = dict( # noqa: C408 mse_loss=mse_loss, smooth_lddt_loss=lddt_loss, ) - return dict(loss=total_loss, loss_breakdown=loss_breakdown) + return dict(loss=total_loss, loss_breakdown=loss_breakdown) # noqa: C408 diff --git a/src/boltz/model/modules/diffusionv2.py b/src/boltz/model/modules/diffusionv2.py index 7876ee181..d47acfc59 100644 --- a/src/boltz/model/modules/diffusionv2.py +++ b/src/boltz/model/modules/diffusionv2.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang from __future__ import annotations @@ -136,12 +159,16 @@ def forward( s_inputs.repeat_interleave(multiplicity, 0), ) + # Promote to at least float32 for numerical stability, but preserve + # higher precision (e.g. float64) if available. + compute_dtype = torch.promote_types(r_noisy.dtype, torch.float32) + # Sequence-local Atom Attention and aggregation to coarse-grained tokens a, q_skip, c_skip, to_keys = self.atom_attention_encoder( feats=feats, - q=diffusion_conditioning["q"].float(), - c=diffusion_conditioning["c"].float(), - atom_enc_bias=diffusion_conditioning["atom_enc_bias"].float(), + q=diffusion_conditioning["q"].to(compute_dtype), + c=diffusion_conditioning["c"].to(compute_dtype), + atom_enc_bias=diffusion_conditioning["atom_enc_bias"].to(compute_dtype), to_keys=diffusion_conditioning["to_keys"], r=r_noisy, # Float['b m 3'], multiplicity=multiplicity, @@ -153,11 +180,11 @@ def forward( mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0) a = self.token_transformer( a, - mask=mask.float(), + mask=mask.to(compute_dtype), s=s, bias=diffusion_conditioning[ "token_trans_bias" - ].float(), # note z is not expanded with multiplicity until after bias is computed + ].to(compute_dtype), # note z is not expanded with multiplicity until after bias is computed multiplicity=multiplicity, ) a = self.a_norm(a) @@ -167,7 +194,7 @@ def forward( a=a, q=q_skip, c=c_skip, - atom_dec_bias=diffusion_conditioning["atom_dec_bias"].float(), + atom_dec_bias=diffusion_conditioning["atom_dec_bias"].to(compute_dtype), feats=feats, multiplicity=multiplicity, to_keys=to_keys, @@ -257,7 +284,10 @@ def preconditioned_network_forward( batch, device = noised_atom_coords.shape[0], noised_atom_coords.device if isinstance(sigma, float): - sigma = torch.full((batch,), sigma, device=device) + # Preserve dtype of noised_atom_coords (e.g. float64 for testing). + # Without this, torch.full defaults to float32, causing dtype mismatch + # in downstream FourierEmbedding when the model is in float64. + sigma = torch.full((batch,), sigma, device=device, dtype=noised_atom_coords.dtype) padded_sigma = rearrange(sigma, "b -> b 1 1") @@ -382,21 +412,37 @@ def sample( with torch.no_grad(): atom_coords_denoised = torch.zeros_like(atom_coords_noisy) + B_local = atom_coords_noisy.shape[0] // multiplicity sample_ids = torch.arange(multiplicity).to(atom_coords_noisy.device) + # ceiling division: ceil(multiplicity / max_parallel_samples). + # The original formula (multiplicity % max_parallel_samples + 1) was incorrect + # when multiplicity is an exact non-trivial multiple of max_parallel_samples + # (e.g., multiplicity=4, max_parallel_samples=2 gave 1 chunk instead of 2). sample_ids_chunks = sample_ids.chunk( - multiplicity % max_parallel_samples + 1 + (multiplicity + max_parallel_samples - 1) // max_parallel_samples ) for sample_ids_chunk in sample_ids_chunks: + chunk_M = sample_ids_chunk.numel() + # atom_coords_noisy is (B*M, N, 3). We must unflatten to (B, M, N, 3) + # and index the M axis — not dim 0 directly — because + # preconditioned_network_forward passes **network_condition_kwargs + # (feats, s_inputs, s_trunk, etc.) which have the full B batch + # dimension. A naïve dim-0 index would select B*chunk_M rows that + # mix batch and multiplicity positions, causing a shape mismatch + # with the un-indexed conditioning tensors inside the score model. + noisy_chunk = atom_coords_noisy.unflatten(0, (B_local, multiplicity))[:, sample_ids_chunk].flatten(0, 1) atom_coords_denoised_chunk = self.preconditioned_network_forward( - atom_coords_noisy[sample_ids_chunk], + noisy_chunk, t_hat, network_condition_kwargs=dict( - multiplicity=sample_ids_chunk.numel(), + multiplicity=chunk_M, **network_condition_kwargs, ), ) - atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk + atom_coords_denoised.unflatten(0, (B_local, multiplicity))[:, sample_ids_chunk] = ( + atom_coords_denoised_chunk.unflatten(0, (B_local, chunk_M)) + ) if ( steering_args is not None @@ -528,11 +574,12 @@ def sample( if self.alignment_reverse_diff: with torch.autocast("cuda", enabled=False): + align_dtype = torch.promote_types(atom_coords_noisy.dtype, torch.float32) atom_coords_noisy = weighted_rigid_align( - atom_coords_noisy.float(), - atom_coords_denoised.float(), - atom_mask.float(), - atom_mask.float(), + atom_coords_noisy.to(align_dtype), + atom_coords_denoised.to(align_dtype), + atom_mask.to(align_dtype), + atom_mask.to(align_dtype), ) atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised) @@ -544,7 +591,7 @@ def sample( atom_coords = atom_coords_next - return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr) + return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr) # noqa: C408 def loss_weight(self, sigma): return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2) @@ -592,20 +639,20 @@ def forward( denoised_atom_coords = self.preconditioned_network_forward( noised_atom_coords, sigmas, - network_condition_kwargs={ - "s_inputs": s_inputs, - "s_trunk": s_trunk, - "feats": feats, - "multiplicity": multiplicity, - "diffusion_conditioning": diffusion_conditioning, - }, + network_condition_kwargs=dict( # noqa: C408 + s_inputs=s_inputs, + s_trunk=s_trunk, + feats=feats, + multiplicity=multiplicity, + diffusion_conditioning=diffusion_conditioning, + ), ) - return { - "denoised_atom_coords": denoised_atom_coords, - "sigmas": sigmas, - "aligned_true_atom_coords": atom_coords, - } + return dict( # noqa: C408 + denoised_atom_coords=denoised_atom_coords, + sigmas=sigmas, + aligned_true_atom_coords=atom_coords, + ) def compute_loss( self, @@ -618,14 +665,15 @@ def compute_loss( filter_by_plddt=0.0, ): with torch.autocast("cuda", enabled=False): - denoised_atom_coords = out_dict["denoised_atom_coords"].float() - sigmas = out_dict["sigmas"].float() + compute_dtype = torch.promote_types(out_dict["denoised_atom_coords"].dtype, torch.float32) + denoised_atom_coords = out_dict["denoised_atom_coords"].to(compute_dtype) + sigmas = out_dict["sigmas"].to(compute_dtype) - resolved_atom_mask_uni = feats["atom_resolved_mask"].float() + resolved_atom_mask_uni = feats["atom_resolved_mask"].to(compute_dtype) if filter_by_plddt > 0: plddt_mask = feats["plddt"] > filter_by_plddt - resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float() + resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.to(compute_dtype) resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave( multiplicity, 0 @@ -636,8 +684,8 @@ def compute_loss( ) atom_type = ( torch.bmm( - feats["atom_to_token"].float(), - feats["mol_type"].unsqueeze(-1).float(), + feats["atom_to_token"].to(compute_dtype), + feats["mol_type"].unsqueeze(-1).to(compute_dtype), ) .squeeze(-1) .long() @@ -650,23 +698,23 @@ def compute_loss( 1 + nucleotide_loss_weight * ( - torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float() - + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float() + torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).to(compute_dtype) + + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).to(compute_dtype) ) + ligand_loss_weight * torch.eq( atom_type_mult, const.chain_type_ids["NONPOLYMER"] - ).float() - ).float() + ).to(compute_dtype) + ).to(compute_dtype) ) - atom_coords = out_dict["aligned_true_atom_coords"].float() + atom_coords = out_dict["aligned_true_atom_coords"].to(compute_dtype) atom_coords_aligned_ground_truth = weighted_rigid_align( atom_coords.detach(), denoised_atom_coords.detach(), align_weights.detach(), mask=feats["atom_resolved_mask"] - .float() + .to(compute_dtype) .repeat_interleave(multiplicity, 0) .detach(), ) @@ -696,8 +744,8 @@ def compute_loss( lddt_loss = smooth_lddt_loss( denoised_atom_coords, feats["coords"], - torch.eq(atom_type, const.chain_type_ids["DNA"]).float() - + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(), + torch.eq(atom_type, const.chain_type_ids["DNA"]).to(compute_dtype) + + torch.eq(atom_type, const.chain_type_ids["RNA"]).to(compute_dtype), coords_mask=resolved_atom_mask_uni, multiplicity=multiplicity, ) @@ -709,4 +757,4 @@ def compute_loss( "smooth_lddt_loss": lddt_loss, } - return {"loss": total_loss, "loss_breakdown": loss_breakdown} + return dict(loss=total_loss, loss_breakdown=loss_breakdown) # noqa: C408 diff --git a/src/boltz/model/modules/encoders.py b/src/boltz/model/modules/encoders.py index d5054de99..2b88e2edf 100644 --- a/src/boltz/model/modules/encoders.py +++ b/src/boltz/model/modules/encoders.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang from functools import partial from math import pi @@ -404,6 +427,11 @@ def forward( B, N, _ = feats["ref_pos"].shape atom_mask = feats["atom_pad_mask"].bool() + # Promote to at least float32 for numerical stability, but preserve + # higher precision (e.g. float64) if available. Replaces hardcoded .float() + # to support float64 DTensor testing. + compute_dtype = torch.promote_types(feats["ref_pos"].dtype, torch.float32) + layer_cache = None if model_cache is not None: cache_prefix = "atomencoder" @@ -433,7 +461,7 @@ def forward( W, H = self.atoms_per_window_queries, self.atoms_per_window_keys B, N = c.shape[:2] K = N // W - keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device) + keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device).to(dtype=compute_dtype) to_keys = partial( single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H ) @@ -447,11 +475,11 @@ def forward( atom_mask_queries = atom_mask.view(B, K, W, 1) atom_mask_keys = ( - to_keys(atom_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool() + to_keys(atom_mask.unsqueeze(-1).to(compute_dtype)).view(B, K, 1, H).bool() ) atom_uid_queries = atom_uid.view(B, K, W, 1) atom_uid_keys = ( - to_keys(atom_uid.unsqueeze(-1).float()).view(B, K, 1, H).long() + to_keys(atom_uid.unsqueeze(-1).to(compute_dtype)).view(B, K, 1, H).long() ) v = ( ( @@ -459,7 +487,7 @@ def forward( & atom_mask_keys & (atom_uid_queries == atom_uid_keys) ) - .float() + .to(compute_dtype) .unsqueeze(-1) ) @@ -471,7 +499,7 @@ def forward( if self.structure_prediction: # run only in structure model not in initial encoding - atom_to_token = feats["atom_to_token"].float() + atom_to_token = feats["atom_to_token"].to(compute_dtype) s_to_c = self.s_to_c_trans(s_trunk) s_to_c = torch.bmm(atom_to_token, s_to_c) @@ -530,7 +558,7 @@ def forward( ) q_to_a = self.atom_to_token_trans(q) - atom_to_token = feats["atom_to_token"].float() + atom_to_token = feats["atom_to_token"].to(compute_dtype) atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0) atom_to_token_mean = atom_to_token / ( atom_to_token.sum(dim=1, keepdim=True) + 1e-6 @@ -611,7 +639,11 @@ def forward( atom_mask = feats["atom_pad_mask"] atom_mask = atom_mask.repeat_interleave(multiplicity, 0) - atom_to_token = feats["atom_to_token"].float() + # Promote to at least float32 for numerical stability, but preserve + # higher precision (e.g. float64) if available. Replaces hardcoded .float() + # to support float64 DTensor testing. + compute_dtype = torch.promote_types(a.dtype, torch.float32) + atom_to_token = feats["atom_to_token"].to(compute_dtype) atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0) a_to_q = self.a_to_q_trans(a) diff --git a/src/boltz/model/modules/encodersv2.py b/src/boltz/model/modules/encodersv2.py index f02cc3ecf..a87700223 100644 --- a/src/boltz/model/modules/encodersv2.py +++ b/src/boltz/model/modules/encodersv2.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang from functools import partial from math import pi @@ -36,9 +59,7 @@ def forward( class RelativePositionEncoder(Module): """Algorithm 3.""" - def __init__( - self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False - ): + def __init__(self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False): super().__init__() self.r_max = r_max self.s_max = s_max @@ -47,26 +68,22 @@ def __init__( self.cyclic_pos_enc = cyclic_pos_enc def forward(self, feats): - b_same_chain = torch.eq( - feats["asym_id"][:, :, None], feats["asym_id"][:, None, :] - ) - b_same_residue = torch.eq( - feats["residue_index"][:, :, None], feats["residue_index"][:, None, :] - ) - b_same_entity = torch.eq( - feats["entity_id"][:, :, None], feats["entity_id"][:, None, :] - ) + b_same_chain = torch.eq(feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]) + b_same_residue = torch.eq(feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]) + b_same_entity = torch.eq(feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]) - d_residue = ( - feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :] - ) + d_residue = feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :] if self.cyclic_pos_enc and torch.any(feats["cyclic_period"] > 0): + # BugFix: unsqueeze(1) reshapes period from (B, N) to (B, 1, N) so it + # broadcasts correctly against d_residue (B, N, N). Without it, + # (B, N, N) / (B, N) aligns from the right and fails when B != N. + # The V1 encoder already had this unsqueeze at the same position. period = torch.where( feats["cyclic_period"] > 0, feats["cyclic_period"], torch.zeros_like(feats["cyclic_period"]) + 10000, - ) + ).unsqueeze(1) d_residue = (d_residue - period * torch.round(d_residue / period)).long() d_residue = torch.clip( @@ -74,15 +91,11 @@ def forward(self, feats): 0, 2 * self.r_max, ) - d_residue = torch.where( - b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1 - ) + d_residue = torch.where(b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1) a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2) d_token = torch.clip( - feats["token_index"][:, :, None] - - feats["token_index"][:, None, :] - + self.r_max, + feats["token_index"][:, :, None] - feats["token_index"][:, None, :] + self.r_max, 0, 2 * self.r_max, ) @@ -106,13 +119,14 @@ def forward(self, feats): # Note: added | (~b_same_entity) based on observation of ProteinX manuscript a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2) + cast_dtype = torch.promote_types(self.linear_layer.weight.dtype, torch.float32) p = self.linear_layer( torch.cat( [ - a_rel_pos.float(), - a_rel_token.float(), - b_same_entity.unsqueeze(-1).float(), - a_rel_chain.float(), + a_rel_pos.to(cast_dtype), + a_rel_token.to(cast_dtype), + b_same_entity.unsqueeze(-1).to(cast_dtype), + a_rel_chain.to(cast_dtype), ], dim=-1, ) @@ -147,9 +161,7 @@ def __init__( transitions = ModuleList([]) for _ in range(num_transitions): - transition = Transition( - dim=2 * token_s, hidden=transition_expansion_factor * 2 * token_s - ) + transition = Transition(dim=2 * token_s, hidden=transition_expansion_factor * 2 * token_s) transitions.append(transition) self.transitions = transitions @@ -163,9 +175,7 @@ def forward( s = torch.cat((s_trunk, s_inputs), dim=-1) s = self.single_embed(self.norm_single(s)) if not self.disable_times: - fourier_embed = self.fourier_embed( - times - ) # note: sigma rescaling done in diffusion module + fourier_embed = self.fourier_embed(times) # note: sigma rescaling done in diffusion module normed_fourier = self.norm_fourier(fourier_embed) fourier_to_single = self.fourier_to_single(normed_fourier) @@ -196,9 +206,7 @@ def __init__( transitions = ModuleList([]) for _ in range(num_transitions): - transition = Transition( - dim=token_z, hidden=transition_expansion_factor * token_z - ) + transition = Transition(dim=token_z, hidden=transition_expansion_factor * token_z) transitions.append(transition) self.transitions = transitions @@ -217,7 +225,7 @@ def forward( return z -def get_indexing_matrix(K, W, H, device): +def get_indexing_matrix(K, W, H, device, dtype=torch.float32): assert W % 2 == 0 assert H % (W // 2) == 0 @@ -225,12 +233,10 @@ def get_indexing_matrix(K, W, H, device): assert h % 2 == 0 arange = torch.arange(2 * K, device=device) - index = ((arange.unsqueeze(0) - arange.unsqueeze(1)) + h // 2).clamp( - min=0, max=h + 1 - ) + index = ((arange.unsqueeze(0) - arange.unsqueeze(1)) + h // 2).clamp(min=0, max=h + 1) index = index.view(K, 2, 2 * K)[:, 0, :] onehot = one_hot(index, num_classes=h + 2)[..., 1:-1].transpose(1, 0) - return onehot.reshape(2 * K, h * K).float() + return onehot.reshape(2 * K, h * K).to(dtype) def single_to_keys(single, indexing_matrix, W, H): @@ -271,14 +277,10 @@ def __init__( self.structure_prediction = structure_prediction if structure_prediction: - self.s_to_c_trans = nn.Sequential( - nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s) - ) + self.s_to_c_trans = nn.Sequential(nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)) init.final_init_(self.s_to_c_trans[1].weight) - self.z_to_p_trans = nn.Sequential( - nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z) - ) + self.z_to_p_trans = nn.Sequential(nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)) init.final_init_(self.z_to_p_trans[1].weight) self.c_to_p_trans_k = nn.Sequential( @@ -316,6 +318,10 @@ def forward( atom_ref_pos = feats["ref_pos"] # Float['b m 3'], atom_uid = feats["ref_space_uid"] # Long['b m'], + # Promote to at least float32 for numerical stability, but preserve + # higher precision (e.g. float64) if available. + compute_dtype = torch.promote_types(atom_ref_pos.dtype, torch.float32) + atom_feats = [ atom_ref_pos, feats["ref_charge"].unsqueeze(-1), @@ -330,11 +336,11 @@ def forward( [ feats["res_type"], feats["modified"].unsqueeze(-1), - one_hot(feats["mol_type"], num_classes=4).float(), + one_hot(feats["mol_type"], num_classes=4).to(compute_dtype), ], dim=-1, ) - atom_to_token = feats["atom_to_token"].float() + atom_to_token = feats["atom_to_token"].to(compute_dtype) atom_res_feats = torch.bmm(atom_to_token, res_feats) atom_feats.append(atom_res_feats) @@ -346,35 +352,23 @@ def forward( W, H = self.atoms_per_window_queries, self.atoms_per_window_keys B, N = c.shape[:2] K = N // W - keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device) - to_keys = partial( - single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H - ) + keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device, dtype=compute_dtype) + to_keys = partial(single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H) atom_ref_pos_queries = atom_ref_pos.view(B, K, W, 1, 3) atom_ref_pos_keys = to_keys(atom_ref_pos).view(B, K, 1, H, 3) d = atom_ref_pos_keys - atom_ref_pos_queries # Float['b k w h 3'] d_norm = torch.sum(d * d, dim=-1, keepdim=True) # Float['b k w h 1'] - d_norm = 1 / ( - 1 + d_norm - ) # AF3 feeds in the reciprocal of the distance norm + d_norm = 1 / (1 + d_norm) # AF3 feeds in the reciprocal of the distance norm atom_mask_queries = atom_mask.view(B, K, W, 1) - atom_mask_keys = ( - to_keys(atom_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool() - ) + atom_mask_keys = to_keys(atom_mask.unsqueeze(-1).to(compute_dtype)).view(B, K, 1, H).bool() atom_uid_queries = atom_uid.view(B, K, W, 1) - atom_uid_keys = ( - to_keys(atom_uid.unsqueeze(-1).float()).view(B, K, 1, H).long() - ) + atom_uid_keys = to_keys(atom_uid.unsqueeze(-1).to(compute_dtype)).view(B, K, 1, H).long() v = ( - ( - atom_mask_queries - & atom_mask_keys - & (atom_uid_queries == atom_uid_keys) - ) - .float() + (atom_mask_queries & atom_mask_keys & (atom_uid_queries == atom_uid_keys)) + .to(compute_dtype) .unsqueeze(-1) ) # Bool['b k w h 1'] @@ -386,24 +380,22 @@ def forward( if self.structure_prediction: # run only in structure model not in initial encoding - atom_to_token = feats["atom_to_token"].float() # Long['b m n'], + atom_to_token = feats["atom_to_token"].to(compute_dtype) # Long['b m n'], - s_to_c = self.s_to_c_trans(s_trunk.float()) + s_to_c = self.s_to_c_trans(s_trunk.to(compute_dtype)) s_to_c = torch.bmm(atom_to_token, s_to_c) - c = c + s_to_c.to(c) + c = c + s_to_c.to(c.dtype) - atom_to_token_queries = atom_to_token.view( - B, K, W, atom_to_token.shape[-1] - ) + atom_to_token_queries = atom_to_token.view(B, K, W, atom_to_token.shape[-1]) atom_to_token_keys = to_keys(atom_to_token) - z_to_p = self.z_to_p_trans(z.float()) + z_to_p = self.z_to_p_trans(z.to(compute_dtype)) z_to_p = torch.einsum( "bijd,bwki,bwlj->bwkld", z_to_p, atom_to_token_queries, atom_to_token_keys, ) - p = p + z_to_p.to(p) + p = p + z_to_p.to(p.dtype) p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, c.shape[-1])) p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, c.shape[-1])) @@ -479,15 +471,14 @@ def forward( ) with torch.autocast("cuda", enabled=False): - q_to_a = self.atom_to_token_trans(q).float() - atom_to_token = feats["atom_to_token"].float() + compute_dtype = torch.promote_types(feats["ref_pos"].dtype, torch.float32) + q_to_a = self.atom_to_token_trans(q).to(compute_dtype) + atom_to_token = feats["atom_to_token"].to(compute_dtype) atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0) - atom_to_token_mean = atom_to_token / ( - atom_to_token.sum(dim=1, keepdim=True) + 1e-6 - ) + atom_to_token_mean = atom_to_token / atom_to_token.sum(dim=1, keepdim=True).clamp(min=1) a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a) - a = a.to(q) + a = a.to(q.dtype) return a, q, c, to_keys @@ -526,9 +517,7 @@ def __init__( self.atom_feat_to_atom_pos_update = LinearNoBias(atom_s, 3) init.final_init_(self.atom_feat_to_atom_pos_update.weight) else: - self.atom_feat_to_atom_pos_update = nn.Sequential( - nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3) - ) + self.atom_feat_to_atom_pos_update = nn.Sequential(nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)) init.final_init_(self.atom_feat_to_atom_pos_update[1].weight) def forward( @@ -542,10 +531,13 @@ def forward( multiplicity=1, ): with torch.autocast("cuda", enabled=False): - atom_to_token = feats["atom_to_token"].float() + # Promote to at least float32 for numerical stability, but preserve + # higher precision (e.g. float64) if available. + compute_dtype = torch.promote_types(a.dtype, torch.float32) + atom_to_token = feats["atom_to_token"].to(compute_dtype) atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0) - a_to_q = self.a_to_q_trans(a.float()) + a_to_q = self.a_to_q_trans(a.to(compute_dtype)) a_to_q = torch.bmm(atom_to_token, a_to_q) q = q + a_to_q.to(q) diff --git a/src/boltz/model/modules/transformers.py b/src/boltz/model/modules/transformers.py index f3255d2ac..0e87775d7 100644 --- a/src/boltz/model/modules/transformers.py +++ b/src/boltz/model/modules/transformers.py @@ -1,5 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang +# Added for torch.promote_types used in AtomTransformer.forward() to support float64 testing. +import torch from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper from torch import nn, sigmoid from torch.nn import ( @@ -302,16 +327,20 @@ def forward( p = p.repeat_interleave(multiplicity, 0) p = p.view((p.shape[0] * NW, W, H, -1)) - to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1) + to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1) # noqa: E731 else: to_keys_new = None # main transformer + # Promote to at least float32 for numerical stability, but preserve + # higher precision (e.g. float64) if available. Replaces hardcoded .float() + # to support float64 DTensor testing. + compute_dtype = torch.promote_types(q.dtype, torch.float32) q = self.diffusion_transformer( a=q, s=c, z=p, - mask=mask.float(), + mask=mask.to(compute_dtype), multiplicity=1, # bias term already expanded with multiplicity to_keys=to_keys_new, model_cache=model_cache, diff --git a/src/boltz/model/modules/transformersv2.py b/src/boltz/model/modules/transformersv2.py index 9d785d80a..f51998c66 100644 --- a/src/boltz/model/modules/transformersv2.py +++ b/src/boltz/model/modules/transformersv2.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang import torch @@ -152,20 +174,14 @@ def __init__( dim_single_cond = default(dim_single_cond, dim) self.adaln = AdaLN(dim, dim_single_cond) - self.pair_bias_attn = AttentionPairBias( - c_s=dim, num_heads=heads, compute_pair_bias=False - ) + self.pair_bias_attn = AttentionPairBias(c_s=dim, num_heads=heads, compute_pair_bias=False) self.output_projection_linear = Linear(dim_single_cond, dim) nn.init.zeros_(self.output_projection_linear.weight) nn.init.constant_(self.output_projection_linear.bias, -2.0) - self.output_projection = nn.Sequential( - self.output_projection_linear, nn.Sigmoid() - ) - self.transition = ConditionedTransitionBlock( - dim_single=dim, dim_single_cond=dim_single_cond - ) + self.output_projection = nn.Sequential(self.output_projection_linear, nn.Sigmoid()) + self.transition = ConditionedTransitionBlock(dim_single=dim, dim_single_cond=dim_single_cond) if post_layer_norm: self.post_lnorm = nn.LayerNorm(dim) @@ -220,9 +236,7 @@ def __init__( super().__init__() self.attn_window_queries = attn_window_queries self.attn_window_keys = attn_window_keys - self.diffusion_transformer = DiffusionTransformer( - **diffusion_transformer_kwargs - ) + self.diffusion_transformer = DiffusionTransformer(**diffusion_transformer_kwargs) def forward( self, @@ -246,15 +260,16 @@ def forward( bias = bias.repeat_interleave(multiplicity, 0) bias = bias.view((bias.shape[0] * NW, W, H, -1)) - to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1) + to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1) # noqa: E731 # main transformer + compute_dtype = torch.promote_types(q.dtype, torch.float32) q = self.diffusion_transformer( a=q, s=c, bias=bias, - mask=mask.float(), - multiplicity=1, # bias term already expanded with multiplicity + mask=mask.to(compute_dtype), + multiplicity=1, # bias term already expanded with multiplicity to_keys=to_keys_new, ) diff --git a/src/boltz/model/modules/trunkv2.py b/src/boltz/model/modules/trunkv2.py index ebd7a2ec6..2a051cfa6 100644 --- a/src/boltz/model/modules/trunkv2.py +++ b/src/boltz/model/modules/trunkv2.py @@ -1,3 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + import torch from torch import Tensor, nn from torch.nn.functional import one_hot @@ -170,7 +193,9 @@ def forward(self, feats: dict[str, Tensor], affinity: bool = False) -> Tensor: """ # Load relevant features - res_type = feats["res_type"].float() + # res_type is integer one-hot; cast to the layer's weight dtype so it + # works with any precision (float32, float64, bfloat16, etc.). + res_type = feats["res_type"].to(self.res_type_encoding.weight.dtype) if affinity: profile = feats["profile_affinity"] deletion_mean = feats["deletion_mean_affinity"].unsqueeze(-1) @@ -200,7 +225,10 @@ def forward(self, feats: dict[str, Tensor], affinity: bool = False) -> Tensor: if self.add_modified_flag: s = s + self.modified_conditioning_init(feats["modified"]) if self.add_cyclic_flag: - cyclic = feats["cyclic_period"].clamp(max=1.0).unsqueeze(-1) + # cyclic_period is integer; cast to the layer's weight dtype so it + # works with any precision (float32, float64, bfloat16, etc.). + cyclic = feats["cyclic_period"].to(self.cyclic_conditioning_init.weight.dtype) + cyclic = cyclic.clamp(max=1.0).unsqueeze(-1) s = s + self.cyclic_conditioning_init(cyclic) if self.add_mol_type_feat: s = s + self.mol_type_conditioning_init(feats["mol_type"]) @@ -291,7 +319,8 @@ def forward( cb_coords = feats["template_cb"] ca_coords = feats["template_ca"] cb_mask = feats["template_mask_cb"] - template_mask = feats["template_mask"].any(dim=2).float() + compute_dtype = torch.promote_types(z.dtype, torch.float32) + template_mask = feats["template_mask"].any(dim=2).to(compute_dtype) num_templates = template_mask.sum(dim=1) num_templates = num_templates.clamp(min=1) @@ -304,7 +333,7 @@ def forward( # Compute asym mask, template features only attend within the same chain B, T = res_type.shape[:2] # noqa: N806 - asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float() + asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).to(compute_dtype) asym_mask = asym_mask[:, None].expand(-1, T, -1, -1) # Compute template features @@ -441,7 +470,8 @@ def forward( ca_coords = feats["template_ca"] cb_mask = feats["template_mask_cb"] visibility_ids = feats["visibility_ids"] - template_mask = feats["template_mask"].any(dim=2).float() + compute_dtype = torch.promote_types(z.dtype, torch.float32) + template_mask = feats["template_mask"].any(dim=2).to(compute_dtype) num_templates = template_mask.sum(dim=1) num_templates = num_templates.clamp(min=1) @@ -456,7 +486,7 @@ def forward( B, T = res_type.shape[:2] # noqa: N806 tmlp_pair_mask = ( visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :] - ).float() + ).to(compute_dtype) # Compute template features with torch.autocast(device_type="cuda", enabled=False): @@ -618,7 +648,8 @@ def forward( deletion_value = feats["deletion_value"].unsqueeze(-1) is_paired = feats["msa_paired"].unsqueeze(-1) msa_mask = feats["msa_mask"] - token_mask = feats["token_pad_mask"].float() + mask_dtype = torch.promote_types(self.msa_proj.weight.dtype, torch.float32) + token_mask = feats["token_pad_mask"].to(mask_dtype) token_mask = token_mask[:, :, None] * token_mask[:, None, :] # Compute MSA embeddings diff --git a/src/boltz/model/validation/rcsb.py b/src/boltz/model/validation/rcsb.py index 69a0a8319..3fd14d014 100644 --- a/src/boltz/model/validation/rcsb.py +++ b/src/boltz/model/validation/rcsb.py @@ -1,3 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + + from typing import Optional import torch @@ -13,11 +37,13 @@ def __init__( self, val_names: list[str], confidence_prediction: bool = False, + physicalism_metrics: bool = False, override_val_method: Optional[str] = None, ) -> None: super().__init__( val_names=val_names, confidence_prediction=confidence_prediction, + physicalism_metrics=physicalism_metrics, override_val_method=override_val_method, ) diff --git a/src/boltz/model/validation/validator.py b/src/boltz/model/validation/validator.py index 393e146ea..71b24cb38 100644 --- a/src/boltz/model/validation/validator.py +++ b/src/boltz/model/validation/validator.py @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + from collections import defaultdict from typing import Optional @@ -57,10 +79,7 @@ def __init__( "disto_loss", ] self.folding_metrics = nn.ModuleDict( - { - k: nn.ModuleList([nn.ModuleDict() for _ in range(num_val_datasets)]) - for k in folding_metric_labels - } + {k: nn.ModuleList([nn.ModuleDict() for _ in range(num_val_datasets)]) for k in folding_metric_labels} ) self.physicalism_metrics = physicalism_metrics @@ -97,19 +116,22 @@ def __init__( "avg", ] mae_metric_labels = ["plddt_mae", "pde_mae", "pae_mae"] - lddt_confidence_metric_labels = [ - prefix + "_lddt" for prefix in confidence_metric_prefixes - ] + lddt_confidence_metric_labels = [prefix + "_lddt" for prefix in confidence_metric_prefixes] if physicalism_metrics: - clash_confidence_metric_labels = [ - prefix + "_clash" for prefix in confidence_metric_prefixes - ] - pb_confidence_metric_labels = [ - prefix + "_pb" for prefix in confidence_metric_prefixes - ] + clash_confidence_metric_labels = [prefix + "_clash" for prefix in confidence_metric_prefixes] + pb_confidence_metric_labels = [prefix + "_pb" for prefix in confidence_metric_prefixes] else: clash_confidence_metric_labels, pb_confidence_metric_labels = [], [] + # All multi-sample metrics (both confidence-dependent and purely structural + # like avg_lddt, avg_clash, avg_pb) are stored in self.confidence_metrics. + # This ModuleDict is only created when confidence_prediction=True because + # most of its entries (top1_lddt, plddt_mae, pde_mae, etc.) require + # confidence module outputs (out["plddt"], out["complex_plddt"], etc.) for + # ranking or MAE computation. The avg_* entries are purely structural + # (mean across diffusion samples) and don't need the confidence module, + # but are co-located here for convenience. If avg_* metrics are needed + # without confidence_prediction, they would need a separate ModuleDict. if confidence_prediction: self.confidence_metrics = nn.ModuleDict( { @@ -128,12 +150,23 @@ def __init__( "pocket_ligand_protein", "contact_protein_protein", ]: + self.folding_metrics["lddt"][val_idx][m_] = MeanMetric() self.folding_metrics["disto_lddt"][val_idx][m_] = MeanMetric() + self.folding_metrics["complex_lddt"][val_idx][m_] = MeanMetric() for m in const.out_single_types: if confidence_prediction: self.confidence_metrics["plddt_mae"][val_idx][m] = MeanMetric() + if confidence_prediction: + for m_ in const.out_types: + if m_ == "modified": + continue + for k in lddt_confidence_metric_labels: + self.confidence_metrics[k][val_idx][m_] = MeanMetric() + self.confidence_metrics["pde_mae"][val_idx][m_] = MeanMetric() + self.confidence_metrics["pae_mae"][val_idx][m_] = MeanMetric() + for m in ["disto_loss"]: self.folding_metrics["disto_loss"][val_idx][m] = MeanMetric() @@ -171,9 +204,7 @@ def run_model( recycling_steps=model.validation_args.recycling_steps, num_sampling_steps=model.validation_args.sampling_steps, diffusion_samples=model.validation_args.diffusion_samples, - run_confidence_sequentially=model.validation_args.get( - "run_confidence_sequentially", False - ), + run_confidence_sequentially=model.validation_args.get("run_confidence_sequentially", False), ) return out @@ -226,9 +257,7 @@ def compute_disto_loss( ) -> None: """Compute distogram loss.""" # Compute validation disto loss - val_disto_loss, _ = distogram_loss( - out, batch, aggregate_distogram=model.aggregate_distogram - ) + val_disto_loss, _ = distogram_loss(out, batch, aggregate_distogram=model.aggregate_distogram) return val_disto_loss @@ -238,9 +267,7 @@ def compute_disto_lddt(self, model, batch, out, idx_dataset) -> tuple[dict, dict lower = torch.tensor([1.0]) upper = torch.tensor([model.max_dist + 5.0]) exp_boundaries = torch.cat((lower, boundaries, upper)) - mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to( - out["pdistogram"] - ) + mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to(out["pdistogram"]) # Compute true distogram K = batch["coords"].shape[1] @@ -249,20 +276,14 @@ def compute_disto_lddt(self, model, batch, out, idx_dataset) -> tuple[dict, dict batch["token_disto_mask"] = batch["token_disto_mask"] # Compute distogram lddt by looping over predicted distograms - disto_lddt_dict = defaultdict( - lambda: torch.zeros(K, model.num_distograms).to(model.device) - ) - disto_total_dict = defaultdict( - lambda: torch.zeros(K, model.num_distograms).to(model.device) - ) + disto_lddt_dict = defaultdict(lambda: torch.zeros(K, model.num_distograms).to(model.device)) + disto_total_dict = defaultdict(lambda: torch.zeros(K, model.num_distograms).to(model.device)) for i in range(model.num_distograms): # Compute predicted dists preds = out["pdistogram"][:, :, :, i] pred_softmax = torch.softmax(preds, dim=-1) pred_softmax = pred_softmax.argmax(dim=-1) - pred_softmax = torch.nn.functional.one_hot( - pred_softmax, num_classes=preds.shape[-1] - ) + pred_softmax = torch.nn.functional.one_hot(pred_softmax, num_classes=preds.shape[-1]) pred_dist_i = (pred_softmax * mid_points).sum(dim=-1) # (B, L, L) del pred_softmax @@ -271,9 +292,7 @@ def compute_disto_lddt(self, model, batch, out, idx_dataset) -> tuple[dict, dict # conformers. Batched version over K factored_token_lddt_dist_loss_ensemble # more efficient for small K. for k in range(K): - true_dists_k = torch.cdist(true_center[k], true_center[k])[ - None - ] # (1, L * L) + true_dists_k = torch.cdist(true_center[k], true_center[k])[None] # (1, L * L) # Compute lddt disto_lddt_dict_, disto_total_dict_ = factored_token_lddt_dist_loss( @@ -288,12 +307,8 @@ def compute_disto_lddt(self, model, batch, out, idx_dataset) -> tuple[dict, dict for key in disto_lddt_dict: # Take min over distograms and average over conformers. Add batch dimension. - disto_lddt_dict[key] = ( - disto_lddt_dict[key].min(dim=1).values.mean(dim=0)[None] - ) - disto_total_dict[key] = ( - disto_total_dict[key].min(dim=1).values.mean(dim=0)[None] - ) + disto_lddt_dict[key] = disto_lddt_dict[key].min(dim=1).values.mean(dim=0)[None] + disto_total_dict[key] = disto_total_dict[key].min(dim=1).values.mean(dim=0)[None] del true_center del preds @@ -332,9 +347,7 @@ def get_lddt_metrics( K = batch["coords"].shape[1] if not expand_to_diffusion_samples: - true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat( - (n_samples, 1) - ) + true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat((n_samples, 1)) ### Compute lddt ### # Implemented in a loop to avoid memory issues with large number @@ -347,9 +360,7 @@ def get_lddt_metrics( if expand_to_diffusion_samples: true_coords_k = true_coords[:, ensemble_idx] else: - true_coords_k = ( - true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) - ) + true_coords_k = true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) all_lddt_dict_s, all_total_dict_s = factored_lddt_loss( feats=batch, @@ -363,9 +374,7 @@ def get_lddt_metrics( all_total_dict[key].append(all_total_dict_s[key]) for key in all_lddt_dict: - all_lddt_dict[key] = torch.stack( - all_lddt_dict[key], dim=1 - ) # (multiplicity, K) + all_lddt_dict[key] = torch.stack(all_lddt_dict[key], dim=1) # (multiplicity, K) all_total_dict[key] = torch.stack(all_total_dict[key], dim=1) return all_lddt_dict, all_total_dict @@ -400,9 +409,7 @@ def get_pb_metrics( num_chiral_atoms, num_stereo_bond_violations, num_stereo_bonds, - ) = compute_stereo_metrics( - pred_atom_coords=out["sample_atom_coords"], feats=batch - ) + ) = compute_stereo_metrics(pred_atom_coords=out["sample_atom_coords"], feats=batch) ( num_aromatic_5_violations, @@ -411,9 +418,7 @@ def get_pb_metrics( num_aromatic_6_rings, num_double_bond_violations, num_double_bonds, - ) = compute_pb_flatness_metrics( - pred_atom_coords=out["sample_atom_coords"], feats=batch - ) + ) = compute_pb_flatness_metrics(pred_atom_coords=out["sample_atom_coords"], feats=batch) pb_failure_dict = { "bond_length": num_bond_length_failures, @@ -465,17 +470,13 @@ def get_confidence_metrics( # All ensembles have same mask if not expand_to_diffusion_samples: - true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat( - (n_samples, 1) - ) + true_coords_resolved_mask = true_coords_resolved_mask.unsqueeze(0).repeat((n_samples, 1)) for ensemble_idx in range(K): if expand_to_diffusion_samples: true_coords_k = true_coords[:, ensemble_idx] else: - true_coords_k = ( - true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) - ) + true_coords_k = true_coords[ensemble_idx].unsqueeze(0).repeat((n_samples, 1, 1)) mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae( pred_atom_coords=out["sample_atom_coords"], @@ -483,7 +484,6 @@ def get_confidence_metrics( true_atom_coords=true_coords_k, pred_lddt=out["plddt"], true_coords_resolved_mask=true_coords_resolved_mask, - token_level_confidence=model.token_level_confidence, multiplicity=n_samples, ) for key in mae_plddt_dict: @@ -519,21 +519,15 @@ def get_confidence_metrics( # Take mean over ensembles for key in mae_plddt_dicts: mae_plddt_dicts[key] = torch.stack(mae_plddt_dicts[key], dim=0).mean(dim=0) - total_mae_plddt_dicts[key] = torch.stack( - total_mae_plddt_dicts[key], dim=0 - ).mean(dim=0) + total_mae_plddt_dicts[key] = torch.stack(total_mae_plddt_dicts[key], dim=0).mean(dim=0) for key in mae_pde_dicts: mae_pde_dicts[key] = torch.stack(mae_pde_dicts[key], dim=0).mean(dim=0) - total_mae_pde_dicts[key] = torch.stack( - total_mae_pde_dicts[key], dim=0 - ).mean(dim=0) + total_mae_pde_dicts[key] = torch.stack(total_mae_pde_dicts[key], dim=0).mean(dim=0) for key in mae_pae_dicts: mae_pae_dicts[key] = torch.stack(mae_pae_dicts[key], dim=0).mean(dim=0) - total_mae_pae_dicts[key] = torch.stack( - total_mae_pae_dicts[key], dim=0 - ).mean(dim=0) + total_mae_pae_dicts[key] = torch.stack(total_mae_pae_dicts[key], dim=0).mean(dim=0) return ( mae_plddt_dicts, @@ -581,54 +575,38 @@ def update_confidence_metrics( if confidence_metric_name == "complex_plddt": confidence_metric_prefix = "top1" elif "complex" in confidence_metric_name: - confidence_metric_prefix = ( - confidence_metric_name.split("_")[1] + "_top1" - ) + confidence_metric_prefix = confidence_metric_name.split("_")[1] + "_top1" else: confidence_metric_prefix = confidence_metric_name + "_top1" for key in all_lddt_dict: if key == "modified": continue - top1_val = ( - all_lddt_dict[key] - .reshape(n_samples, K)[top1_idx, torch.arange(K)] - .mean(dim=0) - ) - top1_total = ( - all_total_dict[key] - .reshape(n_samples, K)[top1_idx, torch.arange(K)] - .mean(dim=0) + top1_val = all_lddt_dict[key].reshape(n_samples, K)[top1_idx, torch.arange(K)].mean(dim=0) + top1_total = all_total_dict[key].reshape(n_samples, K)[top1_idx, torch.arange(K)].mean(dim=0) + self.confidence_metrics[confidence_metric_prefix + "_lddt"][idx_dataset][key].update( + top1_val, top1_total ) - self.confidence_metrics[confidence_metric_prefix + "_lddt"][ - idx_dataset - ][key].update(top1_val, top1_total) if physicalism_metrics: for key in pair_clash_dict: top1_val = pair_clash_dict[key][top1_idx] top1_total = pair_total_dict[key][top1_idx] - self.confidence_metrics[confidence_metric_prefix + "_clash"][ - idx_dataset - ][key].update(top1_val, top1_total) + self.confidence_metrics[confidence_metric_prefix + "_clash"][idx_dataset][key].update( + top1_val, top1_total + ) for key in pb_failure_dict: top1_val = pb_failure_dict[key][top1_idx] top1_total = pb_total_dict[key][top1_idx] - self.confidence_metrics[confidence_metric_prefix + "_pb"][ - idx_dataset - ][key].update(top1_val, top1_total) + self.confidence_metrics[confidence_metric_prefix + "_pb"][idx_dataset][key].update( + top1_val, top1_total + ) for key in all_lddt_dict: if key == "modified": continue - self.confidence_metrics["avg_lddt"][idx_dataset][key].update( - all_lddt_dict[key], all_total_dict[key] - ) - self.confidence_metrics["pde_mae"][idx_dataset][key].update( - mae_pde_dicts[key], total_mae_pde_dicts[key] - ) - self.confidence_metrics["pae_mae"][idx_dataset][key].update( - mae_pae_dicts[key], total_mae_pae_dicts[key] - ) + self.confidence_metrics["avg_lddt"][idx_dataset][key].update(all_lddt_dict[key], all_total_dict[key]) + self.confidence_metrics["pde_mae"][idx_dataset][key].update(mae_pde_dicts[key], total_mae_pde_dicts[key]) + self.confidence_metrics["pae_mae"][idx_dataset][key].update(mae_pae_dicts[key], total_mae_pae_dicts[key]) for key in mae_plddt_dicts: self.confidence_metrics["plddt_mae"][idx_dataset][key].update( mae_plddt_dicts[key], total_mae_plddt_dicts[key] @@ -640,52 +618,84 @@ def update_confidence_metrics( pair_clash_dict[key], pair_total_dict[key] ) for key in pb_failure_dict: - self.confidence_metrics["avg_pb"][idx_dataset][key].update( - pb_failure_dict[key], pb_total_dict[key] - ) + self.confidence_metrics["avg_pb"][idx_dataset][key].update(pb_failure_dict[key], pb_total_dict[key]) def update_lddt_rmsd_metrics( self, batch, + all_lddt_dict, + all_total_dict, disto_lddt_dict, disto_total_dict, idx_dataset, - return_dict, ): + any_key = next(iter(all_lddt_dict)) + n_samples = all_lddt_dict[any_key].shape[0] + K = batch["coords"].shape[1] + + if n_samples > 1: + complex_total = 0 + complex_lddt = 0 + for key in all_lddt_dict: + if key == "modified": + continue + complex_lddt += all_lddt_dict[key] * all_total_dict[key] + complex_total += all_total_dict[key] + complex_lddt /= complex_total + 1e-7 + best_complex_idx = complex_lddt.reshape(n_samples, K).argmax(dim=0) + + best_lddt_dict = {} + best_total_dict = {} + best_complex_lddt_dict = {} + best_complex_total_dict = {} + conformer_idx = torch.arange(K, device=batch["coords"].device) + for key in all_lddt_dict: + lddt_values = all_lddt_dict[key].reshape(n_samples, K) + total_values = all_total_dict[key].reshape(n_samples, K) + best_idx = lddt_values.argmax(dim=0) + best_lddt_dict[key] = lddt_values[best_idx, conformer_idx] + best_total_dict[key] = total_values[best_idx, conformer_idx] + best_complex_lddt_dict[key] = lddt_values[best_complex_idx, conformer_idx] + best_complex_total_dict[key] = total_values[best_complex_idx, conformer_idx] + else: + best_lddt_dict = all_lddt_dict + best_total_dict = all_total_dict + best_complex_lddt_dict = all_lddt_dict + best_complex_total_dict = all_total_dict + # Folding metrics for m_ in const.out_types: + target_m = m_ if m_ == "ligand_protein": if torch.any( - batch["contact_conditioning"][ - :, :, :, const.contact_conditioning_info["BINDER>POCKET"] - ].bool() + batch["contact_conditioning"][:, :, :, const.contact_conditioning_info["BINDER>POCKET"]].bool() ): - self.folding_metrics["disto_lddt"][idx_dataset][ - "pocket_ligand_protein" - ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + target_m = "pocket_ligand_protein" else: - self.folding_metrics["disto_lddt"][idx_dataset][ - "ligand_protein" - ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + target_m = "ligand_protein" elif m_ == "protein_protein": - if torch.any( - batch["contact_conditioning"][ - :, :, :, const.contact_conditioning_info["CONTACT"] - ].bool() - ): - self.folding_metrics["disto_lddt"][idx_dataset][ - "contact_protein_protein" - ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + if torch.any(batch["contact_conditioning"][:, :, :, const.contact_conditioning_info["CONTACT"]].bool()): + target_m = "contact_protein_protein" else: - self.folding_metrics["disto_lddt"][idx_dataset][ - "protein_protein" - ].update(disto_lddt_dict[m_], disto_total_dict[m_]) + target_m = "protein_protein" + + if ( + m_ not in best_lddt_dict + or m_ not in best_total_dict + or m_ not in disto_lddt_dict + or m_ not in disto_total_dict + or m_ not in best_complex_lddt_dict + or m_ not in best_complex_total_dict + ): + # Some classes (e.g. modified) may be absent for a batch. + continue - else: - self.folding_metrics["disto_lddt"][idx_dataset][m_].update( - disto_lddt_dict[m_], disto_total_dict[m_] - ) + self.folding_metrics["lddt"][idx_dataset][target_m].update(best_lddt_dict[m_], best_total_dict[m_]) + self.folding_metrics["disto_lddt"][idx_dataset][target_m].update(disto_lddt_dict[m_], disto_total_dict[m_]) + self.folding_metrics["complex_lddt"][idx_dataset][target_m].update( + best_complex_lddt_dict[m_], best_complex_total_dict[m_] + ) def update_physcialism_metrics( self, @@ -696,14 +706,10 @@ def update_physcialism_metrics( idx_dataset, ): for key in pair_clash_dict: - self.physicalism_metrics["clash"][idx_dataset][key].update( - pair_clash_dict[key], pair_total_dict[key] - ) + self.physicalism_metrics["clash"][idx_dataset][key].update(pair_clash_dict[key], pair_total_dict[key]) for key in pb_failure_dict: - self.physicalism_metrics["pb"][idx_dataset][key].update( - pb_failure_dict[key], pb_total_dict[key] - ) + self.physicalism_metrics["pb"][idx_dataset][key].update(pb_failure_dict[key], pb_total_dict[key]) def common_val_step( self, @@ -724,9 +730,7 @@ def common_val_step( out : dict[str, torch.Tensor] The output of the model. """ - symmetry_correction = model.val_group_mapper[idx_dataset][ - "symmetry_correction" - ] # global val index + symmetry_correction = model.val_group_mapper[idx_dataset]["symmetry_correction"] # global val index # Get the local validation index from the global index idx_dataset = self.get_local_val_index(model, idx_dataset) @@ -737,9 +741,7 @@ def common_val_step( val_disto_loss = self.compute_disto_loss(model, out, batch, idx_dataset) # Compute distogram lddt and update metrics - disto_lddt_dict, disto_total_dict = self.compute_disto_lddt( - model, batch, out, idx_dataset - ) + disto_lddt_dict, disto_total_dict = self.compute_disto_lddt(model, batch, out, idx_dataset) # Get true coords return_dict = self.get_true_coords( @@ -786,7 +788,14 @@ def common_val_step( pair_clash_dict, pair_total_dict = None, None pb_failure_dict, pb_total_dict = None, None - # Filtering based on confidence + # Gated on confidence_prediction because get_confidence_metrics reads + # out["plddt"], out["pde"], out["pae"] which only exist when the + # confidence module runs. Also gated on n_samples > 1 because ranking + # and averaging across samples is trivial with a single sample. + # Note: avg_lddt/avg_clash/avg_pb within update_confidence_metrics are + # purely structural and don't need the confidence module, but are stored + # in self.confidence_metrics which only exists when confidence_prediction + # is True (see __init__). if model.confidence_prediction and n_samples > 1: ( mae_plddt_dicts, @@ -807,17 +816,16 @@ def common_val_step( ) # Update distogram loss - self.folding_metrics["disto_loss"][idx_dataset]["disto_loss"].update( - val_disto_loss - ) + self.folding_metrics["disto_loss"][idx_dataset]["disto_loss"].update(val_disto_loss) # Update folding metrics self.update_lddt_rmsd_metrics( batch, + all_lddt_dict, + all_total_dict, disto_lddt_dict, disto_total_dict, idx_dataset, - return_dict, ) # Update physcial realism metrics @@ -886,9 +894,7 @@ def common_on_epoch_end(self, model: LightningModule): avg_protein_iptm_top1_pb = [{} for _ in range(self.num_val_datasets)] for idx_dataset in range(self.num_val_datasets): # local idx_dataset - dataset_name_ori = self.val_names[ - idx_dataset - ] # self.val_group_mapper[idx_dataset]["label"] + dataset_name_ori = self.val_names[idx_dataset] # self.val_group_mapper[idx_dataset]["label"] # TODO this is harcodeded for now to compare with Boltz-1 metrics dataset_name = "" if dataset_name_ori == "RCSB" else f"__{dataset_name_ori}" @@ -898,14 +904,20 @@ def common_on_epoch_end(self, model: LightningModule): "pocket_ligand_protein", "contact_protein_protein", ]: - avg_disto_lddt[idx_dataset][m_] = self.folding_metrics["disto_lddt"][ - idx_dataset - ][m_].compute() + avg_lddt[idx_dataset][m_] = self.folding_metrics["lddt"][idx_dataset][m_].compute() + avg_lddt[idx_dataset][m_] = ( + 0.0 if torch.isnan(avg_lddt[idx_dataset][m_]) else avg_lddt[idx_dataset][m_].item() + ) + self.folding_metrics["lddt"][idx_dataset][m_].reset() + model.log( + f"val/lddt_{m_}{dataset_name}", + avg_lddt[idx_dataset][m_], + ) + + avg_disto_lddt[idx_dataset][m_] = self.folding_metrics["disto_lddt"][idx_dataset][m_].compute() avg_disto_lddt[idx_dataset][m_] = ( - 0.0 - if torch.isnan(avg_disto_lddt[idx_dataset][m_]) - else avg_disto_lddt[idx_dataset][m_].item() + 0.0 if torch.isnan(avg_disto_lddt[idx_dataset][m_]) else avg_disto_lddt[idx_dataset][m_].item() ) self.folding_metrics["disto_lddt"][idx_dataset][m_].reset() model.log( @@ -913,28 +925,76 @@ def common_on_epoch_end(self, model: LightningModule): avg_disto_lddt[idx_dataset][m_], ) + avg_complex_lddt[idx_dataset][m_] = self.folding_metrics["complex_lddt"][idx_dataset][m_].compute() + avg_complex_lddt[idx_dataset][m_] = ( + 0.0 if torch.isnan(avg_complex_lddt[idx_dataset][m_]) else avg_complex_lddt[idx_dataset][m_].item() + ) + self.folding_metrics["complex_lddt"][idx_dataset][m_].reset() + model.log( + f"val/complex_lddt_{m_}{dataset_name}", + avg_complex_lddt[idx_dataset][m_], + ) + for m in const.out_single_types: if model.confidence_prediction: - avg_mae_plddt[idx_dataset][m] = ( - self.confidence_metrics["plddt_mae"][idx_dataset][m] - .compute() - .item() - ) + val = self.confidence_metrics["plddt_mae"][idx_dataset][m].compute() + avg_mae_plddt[idx_dataset][m] = 0.0 if torch.isnan(val) else val.item() self.confidence_metrics["plddt_mae"][idx_dataset][m].reset() model.log( f"val/MAE_plddt_{m}{dataset_name}", avg_mae_plddt[idx_dataset][m], ) + if model.confidence_prediction: + confidence_pair_keys = [m_ for m_ in const.out_types if m_ != "modified"] + for m_ in confidence_pair_keys: + for prefix in [ + "top1", + "iplddt_top1", + "ipde_top1", + "pde_top1", + "ptm_top1", + "iptm_top1", + "ligand_iptm_top1", + "protein_iptm_top1", + "avg", + ]: + label = f"{prefix}_lddt" + val = self.confidence_metrics[label][idx_dataset][m_].compute() + val = 0.0 if torch.isnan(val) else val.item() + self.confidence_metrics[label][idx_dataset][m_].reset() + model.log(f"val/{label}_{m_}{dataset_name}", val) + + for mae_label, log_prefix in [("pde_mae", "MAE_pde"), ("pae_mae", "MAE_pae")]: + val = self.confidence_metrics[mae_label][idx_dataset][m_].compute() + val = 0.0 if torch.isnan(val) else val.item() + self.confidence_metrics[mae_label][idx_dataset][m_].reset() + model.log(f"val/{log_prefix}_{m_}{dataset_name}", val) + overall_disto_lddt = sum( - avg_disto_lddt[idx_dataset][m] * w - for (m, w) in const.out_types_weights.items() + avg_disto_lddt[idx_dataset][m] * w for (m, w) in const.out_types_weights.items() ) / sum(const.out_types_weights.values()) model.log( f"val/disto_lddt{dataset_name}", overall_disto_lddt, ) + overall_lddt = sum(avg_lddt[idx_dataset][m] * w for (m, w) in const.out_types_weights.items()) / sum( + const.out_types_weights.values() + ) + model.log( + f"val/lddt{dataset_name}", + overall_lddt, + ) + + overall_complex_lddt = sum( + avg_complex_lddt[idx_dataset][m] * w for (m, w) in const.out_types_weights.items() + ) / sum(const.out_types_weights.values()) + model.log( + f"val/complex_lddt{dataset_name}", + overall_complex_lddt, + ) + # Distogram loss r = self.folding_metrics["disto_loss"][idx_dataset]["disto_loss"].compute() model.log(f"val/disto_loss{dataset_name}", r) @@ -942,16 +1002,10 @@ def common_on_epoch_end(self, model: LightningModule): # Physical realism metrics if self.physicalism_metrics: - for m in ["asym_" + m_ for m_ in const.clash_types] + [ - "sym_" + m_ for m_ in const.out_single_types - ]: - avg_clash[idx_dataset][m] = self.physicalism_metrics["clash"][ - idx_dataset - ][m].compute() + for m in ["asym_" + m_ for m_ in const.clash_types] + ["sym_" + m_ for m_ in const.out_single_types]: + avg_clash[idx_dataset][m] = self.physicalism_metrics["clash"][idx_dataset][m].compute() avg_clash[idx_dataset][m] = ( - 0.0 - if torch.isnan(avg_clash[idx_dataset][m]) - else avg_clash[idx_dataset][m].item() + 0.0 if torch.isnan(avg_clash[idx_dataset][m]) else avg_clash[idx_dataset][m].item() ) self.physicalism_metrics["clash"][idx_dataset][m].reset() model.log( @@ -960,9 +1014,7 @@ def common_on_epoch_end(self, model: LightningModule): ) if model.confidence_prediction: - avg_top1_clash[idx_dataset][m] = self.confidence_metrics[ - "top1_clash" - ][idx_dataset][m].compute() + avg_top1_clash[idx_dataset][m] = self.confidence_metrics["top1_clash"][idx_dataset][m].compute() avg_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_top1_clash[idx_dataset][m]) @@ -974,129 +1026,107 @@ def common_on_epoch_end(self, model: LightningModule): avg_top1_clash[idx_dataset][m], ) - avg_iplddt_top1_clash[idx_dataset][m] = self.confidence_metrics[ - "iplddt_top1_clash" - ][idx_dataset][m].compute() + avg_iplddt_top1_clash[idx_dataset][m] = self.confidence_metrics["iplddt_top1_clash"][ + idx_dataset + ][m].compute() avg_iplddt_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_iplddt_top1_clash[idx_dataset][m]) else avg_iplddt_top1_clash[idx_dataset][m].item() ) - self.confidence_metrics["iplddt_top1_clash"][idx_dataset][ - m - ].reset() + self.confidence_metrics["iplddt_top1_clash"][idx_dataset][m].reset() model.log( f"val/iplddt_top1_clash_{m}{dataset_name}", avg_iplddt_top1_clash[idx_dataset][m], ) - avg_pde_top1_clash[idx_dataset][m] = self.confidence_metrics[ - "pde_top1_clash" - ][idx_dataset][m].compute() + avg_pde_top1_clash[idx_dataset][m] = self.confidence_metrics["pde_top1_clash"][idx_dataset][ + m + ].compute() avg_pde_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_pde_top1_clash[idx_dataset][m]) else avg_pde_top1_clash[idx_dataset][m].item() ) - self.confidence_metrics["pde_top1_clash"][idx_dataset][ - m - ].reset() + self.confidence_metrics["pde_top1_clash"][idx_dataset][m].reset() model.log( f"val/pde_top1_clash_{m}{dataset_name}", avg_pde_top1_clash[idx_dataset][m], ) - avg_ipde_top1_clash[idx_dataset][m] = self.confidence_metrics[ - "ipde_top1_clash" - ][idx_dataset][m].compute() + avg_ipde_top1_clash[idx_dataset][m] = self.confidence_metrics["ipde_top1_clash"][idx_dataset][ + m + ].compute() avg_ipde_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_ipde_top1_clash[idx_dataset][m]) else avg_ipde_top1_clash[idx_dataset][m].item() ) - self.confidence_metrics["ipde_top1_clash"][idx_dataset][ - m - ].reset() + self.confidence_metrics["ipde_top1_clash"][idx_dataset][m].reset() model.log( f"val/ipde_top1_clash_{m}{dataset_name}", avg_ipde_top1_clash[idx_dataset][m], ) - avg_ptm_top1_clash[idx_dataset][m] = self.confidence_metrics[ - "ptm_top1_clash" - ][idx_dataset][m].compute() + avg_ptm_top1_clash[idx_dataset][m] = self.confidence_metrics["ptm_top1_clash"][idx_dataset][ + m + ].compute() avg_ptm_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_ptm_top1_clash[idx_dataset][m]) else avg_ptm_top1_clash[idx_dataset][m].item() ) - self.confidence_metrics["ptm_top1_clash"][idx_dataset][ - m - ].reset() + self.confidence_metrics["ptm_top1_clash"][idx_dataset][m].reset() model.log( f"val/ptm_top1_clash_{m}{dataset_name}", avg_ptm_top1_clash[idx_dataset][m], ) - avg_iptm_top1_clash[idx_dataset][m] = self.confidence_metrics[ - "iptm_top1_clash" - ][idx_dataset][m].compute() + avg_iptm_top1_clash[idx_dataset][m] = self.confidence_metrics["iptm_top1_clash"][idx_dataset][ + m + ].compute() avg_iptm_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_iptm_top1_clash[idx_dataset][m]) else avg_iptm_top1_clash[idx_dataset][m].item() ) - self.confidence_metrics["iptm_top1_clash"][idx_dataset][ - m - ].reset() + self.confidence_metrics["iptm_top1_clash"][idx_dataset][m].reset() model.log( f"val/iptm_top1_clash_{m}{dataset_name}", avg_iptm_top1_clash[idx_dataset][m], ) - avg_ligand_iptm_top1_clash[idx_dataset][m] = ( - self.confidence_metrics["ligand_iptm_top1_clash"][ - idx_dataset - ][m].compute() - ) + avg_ligand_iptm_top1_clash[idx_dataset][m] = self.confidence_metrics["ligand_iptm_top1_clash"][ + idx_dataset + ][m].compute() avg_ligand_iptm_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_ligand_iptm_top1_clash[idx_dataset][m]) else avg_ligand_iptm_top1_clash[idx_dataset][m].item() ) - self.confidence_metrics["ligand_iptm_top1_clash"][idx_dataset][ - m - ].reset() + self.confidence_metrics["ligand_iptm_top1_clash"][idx_dataset][m].reset() model.log( f"val/ligand_iptm_top1_clash_{m}{dataset_name}", avg_ligand_iptm_top1_clash[idx_dataset][m], ) - avg_protein_iptm_top1_clash[idx_dataset][m] = ( - self.confidence_metrics["protein_iptm_top1_clash"][ - idx_dataset - ][m].compute() - ) + avg_protein_iptm_top1_clash[idx_dataset][m] = self.confidence_metrics[ + "protein_iptm_top1_clash" + ][idx_dataset][m].compute() avg_protein_iptm_top1_clash[idx_dataset][m] = ( 0.0 if torch.isnan(avg_protein_iptm_top1_clash[idx_dataset][m]) else avg_protein_iptm_top1_clash[idx_dataset][m].item() ) - self.confidence_metrics["protein_iptm_top1_clash"][idx_dataset][ - m - ].reset() + self.confidence_metrics["protein_iptm_top1_clash"][idx_dataset][m].reset() model.log( f"val/protein_iptm_top1_clash_{m}{dataset_name}", avg_protein_iptm_top1_clash[idx_dataset][m], ) - avg_avg_clash[idx_dataset][m] = self.confidence_metrics[ - "avg_clash" - ][idx_dataset][m].compute() + avg_avg_clash[idx_dataset][m] = self.confidence_metrics["avg_clash"][idx_dataset][m].compute() avg_avg_clash[idx_dataset][m] = ( - 0.0 - if torch.isnan(avg_avg_clash[idx_dataset][m]) - else avg_avg_clash[idx_dataset][m].item() + 0.0 if torch.isnan(avg_avg_clash[idx_dataset][m]) else avg_avg_clash[idx_dataset][m].item() ) self.confidence_metrics["avg_clash"][idx_dataset][m].reset() model.log( @@ -1114,13 +1144,9 @@ def common_on_epoch_end(self, model: LightningModule): "ring_6_flatness", "double_bond_flatness", ]: - avg_pb[idx_dataset][m] = self.physicalism_metrics["pb"][ - idx_dataset - ][m].compute() + avg_pb[idx_dataset][m] = self.physicalism_metrics["pb"][idx_dataset][m].compute() avg_pb[idx_dataset][m] = ( - 0.0 - if torch.isnan(avg_pb[idx_dataset][m]) - else avg_pb[idx_dataset][m].item() + 0.0 if torch.isnan(avg_pb[idx_dataset][m]) else avg_pb[idx_dataset][m].item() ) self.physicalism_metrics["pb"][idx_dataset][m].reset() model.log( @@ -1129,13 +1155,9 @@ def common_on_epoch_end(self, model: LightningModule): ) if model.confidence_prediction: - avg_top1_pb[idx_dataset][m] = self.confidence_metrics[ - "top1_pb" - ][idx_dataset][m].compute() + avg_top1_pb[idx_dataset][m] = self.confidence_metrics["top1_pb"][idx_dataset][m].compute() avg_top1_pb[idx_dataset][m] = ( - 0.0 - if torch.isnan(avg_top1_pb[idx_dataset][m]) - else avg_top1_pb[idx_dataset][m].item() + 0.0 if torch.isnan(avg_top1_pb[idx_dataset][m]) else avg_top1_pb[idx_dataset][m].item() ) self.confidence_metrics["top1_pb"][idx_dataset][m].reset() model.log( @@ -1143,25 +1165,23 @@ def common_on_epoch_end(self, model: LightningModule): avg_top1_pb[idx_dataset][m], ) - avg_iplddt_top1_pb[idx_dataset][m] = self.confidence_metrics[ - "iplddt_top1_pb" - ][idx_dataset][m].compute() + avg_iplddt_top1_pb[idx_dataset][m] = self.confidence_metrics["iplddt_top1_pb"][idx_dataset][ + m + ].compute() avg_iplddt_top1_pb[idx_dataset][m] = ( 0.0 if torch.isnan(avg_iplddt_top1_pb[idx_dataset][m]) else avg_iplddt_top1_pb[idx_dataset][m].item() ) - self.confidence_metrics["iplddt_top1_pb"][idx_dataset][ - m - ].reset() + self.confidence_metrics["iplddt_top1_pb"][idx_dataset][m].reset() model.log( f"val/iplddt_top1_pb_{m}{dataset_name}", avg_iplddt_top1_pb[idx_dataset][m], ) - avg_pde_top1_pb[idx_dataset][m] = self.confidence_metrics[ - "pde_top1_pb" - ][idx_dataset][m].compute() + avg_pde_top1_pb[idx_dataset][m] = self.confidence_metrics["pde_top1_pb"][idx_dataset][ + m + ].compute() avg_pde_top1_pb[idx_dataset][m] = ( 0.0 if torch.isnan(avg_pde_top1_pb[idx_dataset][m]) @@ -1173,9 +1193,9 @@ def common_on_epoch_end(self, model: LightningModule): avg_pde_top1_pb[idx_dataset][m], ) - avg_ipde_top1_pb[idx_dataset][m] = self.confidence_metrics[ - "ipde_top1_pb" - ][idx_dataset][m].compute() + avg_ipde_top1_pb[idx_dataset][m] = self.confidence_metrics["ipde_top1_pb"][idx_dataset][ + m + ].compute() avg_ipde_top1_pb[idx_dataset][m] = ( 0.0 if torch.isnan(avg_ipde_top1_pb[idx_dataset][m]) @@ -1187,9 +1207,9 @@ def common_on_epoch_end(self, model: LightningModule): avg_ipde_top1_pb[idx_dataset][m], ) - avg_ptm_top1_pb[idx_dataset][m] = self.confidence_metrics[ - "ptm_top1_pb" - ][idx_dataset][m].compute() + avg_ptm_top1_pb[idx_dataset][m] = self.confidence_metrics["ptm_top1_pb"][idx_dataset][ + m + ].compute() avg_ptm_top1_pb[idx_dataset][m] = ( 0.0 if torch.isnan(avg_ptm_top1_pb[idx_dataset][m]) @@ -1201,9 +1221,9 @@ def common_on_epoch_end(self, model: LightningModule): avg_ptm_top1_pb[idx_dataset][m], ) - avg_iptm_top1_pb[idx_dataset][m] = self.confidence_metrics[ - "iptm_top1_pb" - ][idx_dataset][m].compute() + avg_iptm_top1_pb[idx_dataset][m] = self.confidence_metrics["iptm_top1_pb"][idx_dataset][ + m + ].compute() avg_iptm_top1_pb[idx_dataset][m] = ( 0.0 if torch.isnan(avg_iptm_top1_pb[idx_dataset][m]) @@ -1215,49 +1235,37 @@ def common_on_epoch_end(self, model: LightningModule): avg_iptm_top1_pb[idx_dataset][m], ) - avg_ligand_iptm_top1_pb[idx_dataset][m] = ( - self.confidence_metrics["ligand_iptm_top1_pb"][idx_dataset][ - m - ].compute() - ) + avg_ligand_iptm_top1_pb[idx_dataset][m] = self.confidence_metrics["ligand_iptm_top1_pb"][ + idx_dataset + ][m].compute() avg_ligand_iptm_top1_pb[idx_dataset][m] = ( 0.0 if torch.isnan(avg_ligand_iptm_top1_pb[idx_dataset][m]) else avg_ligand_iptm_top1_pb[idx_dataset][m].item() ) - self.confidence_metrics["ligand_iptm_top1_pb"][idx_dataset][ - m - ].reset() + self.confidence_metrics["ligand_iptm_top1_pb"][idx_dataset][m].reset() model.log( f"val/ligand_iptm_top1_pb_{m}{dataset_name}", avg_ligand_iptm_top1_pb[idx_dataset][m], ) - avg_protein_iptm_top1_pb[idx_dataset][m] = ( - self.confidence_metrics["protein_iptm_top1_pb"][ - idx_dataset - ][m].compute() - ) + avg_protein_iptm_top1_pb[idx_dataset][m] = self.confidence_metrics["protein_iptm_top1_pb"][ + idx_dataset + ][m].compute() avg_protein_iptm_top1_pb[idx_dataset][m] = ( 0.0 if torch.isnan(avg_protein_iptm_top1_pb[idx_dataset][m]) else avg_protein_iptm_top1_pb[idx_dataset][m].item() ) - self.confidence_metrics["protein_iptm_top1_pb"][idx_dataset][ - m - ].reset() + self.confidence_metrics["protein_iptm_top1_pb"][idx_dataset][m].reset() model.log( f"val/protein_iptm_top1_pb_{m}{dataset_name}", avg_protein_iptm_top1_pb[idx_dataset][m], ) - avg_avg_pb[idx_dataset][m] = self.confidence_metrics["avg_pb"][ - idx_dataset - ][m].compute() + avg_avg_pb[idx_dataset][m] = self.confidence_metrics["avg_pb"][idx_dataset][m].compute() avg_avg_pb[idx_dataset][m] = ( - 0.0 - if torch.isnan(avg_avg_pb[idx_dataset][m]) - else avg_avg_pb[idx_dataset][m].item() + 0.0 if torch.isnan(avg_avg_pb[idx_dataset][m]) else avg_avg_pb[idx_dataset][m].item() ) self.confidence_metrics["avg_pb"][idx_dataset][m].reset() model.log( diff --git a/src/boltz/testing/__init__.py b/src/boltz/testing/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/testing/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/testing/utils.py b/src/boltz/testing/utils.py new file mode 100644 index 000000000..c1973df5c --- /dev/null +++ b/src/boltz/testing/utils.py @@ -0,0 +1,3730 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# fmt: off + +import json +import math +import multiprocessing +import os +import random +import shutil +import time +import warnings +from collections import OrderedDict +from contextlib import contextmanager +from dataclasses import asdict, dataclass +from functools import partial, reduce +from pathlib import Path +from typing import Any, Callable, Optional + +import numpy as np +import pytest +import requests +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard +from torch.distributed.tensor._utils import compute_global_tensor_info + +from boltz.data import const +from boltz.data.feature.featurizer import ( + BoltzFeaturizer, + Tokenized, +) +from boltz.data.load import CACHE_DIR +from boltz.distributed.data.feature.featurizer_utils import ( + get_num_atoms_tokens, + get_pair_mask, +) +from boltz.distributed.data.types import PairMaskMode + +# Try to import from main, fall back to defaults if not available +try: + from boltz.main import MODEL_URL, BoltzDiffusionParams +except ImportError: + MODEL_URL = "" + BoltzDiffusionParams = None + +# Import from v2 modules for Boltz-2 +from boltz.distributed.model.modules.utils import Precision +from boltz.distributed.utils import LayoutRightMap +from boltz.model.layers.attentionv2 import AttentionPairBias as SerialAttentionPairBias +from boltz.model.layers.pair_averaging import PairWeightedAveraging as SerialPairWeightedAveraging +from boltz.model.layers.triangular_attention.attention import TriangleAttention as SerialTriangleAttention +from boltz.model.layers.triangular_attention.attention import ( + TriangleAttentionEndingNode as SerialTriangleAttentionEndingNode, +) + +# Import v2 encoder functions for window batching +from boltz.model.modules.encodersv2 import get_indexing_matrix, single_to_keys +from boltz.model.modules.transformersv2 import AtomTransformer as SerialAtomTransformer + +PRECISION_TO_INF = { + Precision.FP16: 6e4, + Precision.BF16: 1e9, + Precision.TF32: 1e9, + Precision.FP32: 1e9, + Precision.FP64: 1e18, +} + + +def is_a6000_gpu() -> bool: + # Check if any of the visible GPUs is an A6000 + for i in range(torch.cuda.device_count()): + device_name = torch.cuda.get_device_name(i) + if "A6000" in device_name: + return True + return False + + +def download_model_ckpt() -> Path: + """Download the model checkpoint for regression and e2e tests.""" + cache = CACHE_DIR / "regression" + if not cache.exists(): + cache.mkdir(parents=True, exist_ok=True) + checkpoint_url = MODEL_URL + model_name = checkpoint_url.split("/")[-1] + checkpoint = cache / model_name + if not checkpoint.exists(): + download_file(checkpoint_url, checkpoint) + return checkpoint + + +def map_to_device(*args, device: torch.device | str = "cpu") -> tuple[Tensor, ...]: + return tuple(arg.to(device) if torch.is_tensor(arg) else arg for arg in args) + + +def get_chunk_size(N_tokens: int, ring_size: int) -> int: + chunk_size = N_tokens / ring_size + assert chunk_size.is_integer(), "number of tokens must be divisible by square root of context parallel size" + return int(chunk_size) + + +def chunk_along_dim(*args, dim: int, chunks: int, chunk_i: int) -> Tensor | tuple[Tensor, ...]: + if len(args) == 1: + return args[0].chunk(chunks, dim=dim)[chunk_i] + return tuple(t.chunk(chunks, dim=dim)[chunk_i] for t in args) + + +def permute_final_dims(tensor: torch.Tensor, inds: list[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def get_weighted_lddt( + all_atom_pred_pos: torch.Tensor, + all_atom_positions: torch.Tensor, + all_atom_mask: torch.Tensor, + cutoff: float = 15.0, + eps: float = 1e-10, + per_residue: bool = True, +) -> torch.Tensor: + all_atom_mask = all_atom_mask.unsqueeze(-1) + n = all_atom_mask.shape[-2] + dmat_true = torch.sqrt( + eps + + torch.sum( + (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :]) ** 2, + dim=-1, + ) + ) + + dmat_pred = torch.sqrt( + eps + + torch.sum( + (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2, + dim=-1, + ) + ) + dists_to_score = ( + (dmat_true < cutoff) + * all_atom_mask + * permute_final_dims(all_atom_mask, (1, 0)) + * (1.0 - torch.eye(n, device=all_atom_mask.device)) + ) + + dist_l1 = torch.abs(dmat_true - dmat_pred) + + score = ( + (dist_l1 < 0.5).type(dist_l1.dtype) + + (dist_l1 < 1.0).type(dist_l1.dtype) + + (dist_l1 < 2.0).type(dist_l1.dtype) + + (dist_l1 < 4.0).type(dist_l1.dtype) + ) + score = score * 0.25 + + dims = (-1,) if per_residue else (-2, -1) + norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims)) + score = norm * (eps + torch.sum(dists_to_score * score, dim=dims)) + + return score + + +def compute_pairwise_lddt_rmsd_matrices( + coords_a: list[torch.Tensor], + coords_b: list[torch.Tensor], +) -> tuple[np.ndarray, np.ndarray]: + """Compute NxM pairwise lDDT and RMSD matrices between two sets of structures. + + Each element is a coordinate tensor of shape ``(n_atoms, 3)``. + + Returns + ------- + lddt_matrix : np.ndarray, shape (N, M) + rmsd_matrix : np.ndarray, shape (N, M) + """ + from boltz.model.loss.diffusion import weighted_rigid_align + + n, m = len(coords_a), len(coords_b) + lddt_mat = np.zeros((n, m)) + rmsd_mat = np.zeros((n, m)) + for i in range(n): + for j in range(m): + a = coords_a[i].unsqueeze(0) + b = coords_b[j].unsqueeze(0) + w = torch.ones_like(a[..., 0]) + b_aligned = weighted_rigid_align(true_coords=b, pred_coords=a, weights=w, mask=w) + lddt_mat[i, j] = get_weighted_lddt(a, b_aligned, w).mean().item() + rmsd_mat[i, j] = torch.sum((a - b_aligned) ** 2, dim=-1).sqrt().mean().item() + return lddt_mat, rmsd_mat + + +def energy_distance_from_matrices( + cross: np.ndarray, + intra_a: np.ndarray, + intra_b: np.ndarray, + maximize: bool = False, +) -> float: + """Compute energy distance from pre-computed pairwise metric matrices. + + For a metric where higher is better (``maximize=True``, e.g. lDDT), the + pairwise distance is ``1 - metric``. For lower-is-better (``maximize=False``, + e.g. RMSD), the metric value is used directly as distance. + + Parameters + ---------- + cross : (N, M) array -- all pairs between the two distributions. + intra_a : (N, N) array -- all pairs within distribution A. + intra_b : (M, M) array -- all pairs within distribution B. + + Uses the upper triangle (excluding diagonal) for intra-distribution means. + """ + mean_cross = cross.mean() + triu_a = intra_a[np.triu_indices(intra_a.shape[0], k=1)] + triu_b = intra_b[np.triu_indices(intra_b.shape[0], k=1)] + mean_intra_a = triu_a.mean() + mean_intra_b = triu_b.mean() + if maximize: + return 2 * (1 - mean_cross) - (1 - mean_intra_a) - (1 - mean_intra_b) + return 2 * mean_cross - mean_intra_a - mean_intra_b + + +def matched_mean_metric(matrix: np.ndarray, maximize: bool = False) -> float: + """Optimal 1-to-1 matching (Hungarian) and return mean of matched values. + + Parameters + ---------- + matrix : (N, M) array of metric values. + maximize : if True, maximise the sum (e.g. lDDT); else minimise (e.g. RMSD). + """ + from scipy.optimize import linear_sum_assignment + + cost = (1.0 - matrix) if maximize else matrix + row_ind, col_ind = linear_sum_assignment(cost) + return float(matrix[row_ind, col_ind].mean()) + + +def intra_rowwise_best(matrix: np.ndarray, maximize: bool = False) -> float: + """Mean of row-wise best value from a square matrix, excluding the diagonal. + + For lDDT (``maximize=True``) returns mean of row-wise max. + For RMSD (``maximize=False``) returns mean of row-wise min. + """ + mat = matrix.copy() + if maximize: + np.fill_diagonal(mat, -np.inf) + return float(mat.max(axis=1).mean()) + np.fill_diagonal(mat, np.inf) + return float(mat.min(axis=1).mean()) + + +def download_file(url, filepath, verbose=True): + if verbose: + print(f"Downloading {url} to {filepath}") + response = requests.get(url) + + target_dir = os.path.dirname(filepath) + if target_dir and not os.path.exists(target_dir): + os.makedirs(target_dir) + + # Check if the request was successful + if response.status_code == 200: + with open(filepath, "wb") as file: + file.write(response.content) + else: + print(f"Failed to download file. Status code: {response.status_code}") + + return filepath + + +def detach_and_clone_tensors(tensors: list[Tensor | None], requires_grad: bool = False) -> list[Tensor | None]: + return [t.detach().clone().requires_grad_(requires_grad) if t is not None else None for t in tensors] + + +def assert_tensors_identical( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + check_stride: bool = True, + check_grad: bool = True, + check_grad_fn: bool = True, + check_storage_offset: bool = True, + check_storage_pointer: bool = False, + rtol: float = 0.0, + atol: float = 0.0, + **kwargs_torch_testing_assert_close: dict[str, Any], +) -> None: + """Verify that two PyTorch tensors are identical with configurable strictness. + + Performs a multi-phase validation to ensure tensors match across different aspects: + - Phase 1: Core tensor properties (values, device, dtype, layout, stride) + - Phase 2: Gradient requirements + - Phase 3: Gradient content comparison (if check_grad=True) + - Phase 4: Autograd computation graph (if check_grad_fn=True) + - Phase 5: Memory layout validation on storage offset (if check_storage_offset=True) + - Phase 6: Storage pointers (if check_storage_pointer=True) + + Args: + tensor1: First PyTorch tensor to compare + tensor2: Second PyTorch tensor to compare + check_stride: Whether to check that strides match + check_grad: Whether to check that gradients match (if present) + check_grad_fn: Whether to check that gradient functions match + check_storage_offset: Whether to check that storage offset matches + check_storage_pointer: Whether to check that storage pointers match + rtol: Relative tolerance for torch.testing.assert_close + atol: Absolute tolerance for torch.testing.assert_close + **kwargs_torch_testing_assert_close: Additional keyword arguments for torch.testing.assert_close + + Raises: + AssertionError: If any validation phase fails + """ + if tensor1 is tensor2: + return # Short-circuit for identical objects + + # Phase 1: Core tensor properties and values + torch.testing.assert_close( + tensor1, + tensor2, + rtol=rtol, + atol=atol, + check_device=True, + check_dtype=True, + check_layout=True, + check_stride=check_stride, + equal_nan=True, + **kwargs_torch_testing_assert_close, + ) + + # Phase 2: Gradient requirements + assert tensor1.requires_grad == tensor2.requires_grad, "Input tensors' requires_grad mismatch" + + if check_grad: + # Phase 3: Gradient content comparison + grad1 = tensor1.grad + grad2 = tensor2.grad + + assert (grad1 is None) == (grad2 is None), "Input tensors' gradient existence mismatch" + + if grad1 is not None and grad2 is not None: + torch.testing.assert_close( + grad1, + grad2, + rtol=rtol, + atol=atol, + check_device=True, + check_dtype=True, + check_layout=True, + check_stride=True, + equal_nan=True, + **kwargs_torch_testing_assert_close, + ) + + if check_grad_fn: + # Verify autograd graph compatibility + assert ( + tensor1.grad_fn == tensor2.grad_fn + ), "Autograd computation graph mismatch - Input tensors created through different operations" + + # Phase 4: Memory layout validation + if check_storage_offset: + assert tensor1.storage_offset() == tensor2.storage_offset(), "Input tensors' Storage offset mismatch" + + # Phase 5: Optional storage validation + if check_storage_pointer: + ptr1 = tensor1.storage().data_ptr() + ptr2 = tensor2.storage().data_ptr() + assert ptr1 == ptr2, "Input tensors' Storage pointers mismatch" + + +def assert_tensors_close_with_pad( + a: torch.Tensor, + b: torch.Tensor, + axis: int, + pad_val: Any = 0, + **kwargs, +) -> None: + """Assert that two tensors are close, handling padding along a specified axis. + + Compares the overlapping region of two tensors along the specified axis, + and verifies that the longer tensor's trailing elements are all equal to `pad_val`. + + This is useful for comparing tensors where one has been padded to a larger size, + e.g., after distributed_pack_and_pad operations. + + Args: + a: First tensor to compare. + b: Second tensor to compare. + axis: The axis along which to compare and check padding. + pad_val: The expected value for padding elements (default: 0). + **kwargs: Additional keyword arguments forwarded to torch.testing.assert_close + (e.g., atol, rtol, msg). + + Raises: + AssertionError: If the overlapping regions don't match or trailing elements + are not equal to pad_val. + + Example: + >>> a = torch.tensor([1, 2, 3, 0, 0]) # padded tensor + >>> b = torch.tensor([1, 2, 3]) # original tensor + >>> assert_tensors_close_with_pad(a, b, axis=0, pad_val=0) # passes + """ + len_a = a.shape[axis] + len_b = b.shape[axis] + min_len = min(len_a, len_b) + + # Create slices for the overlapping region + slices_a = [slice(None)] * a.ndim + slices_b = [slice(None)] * b.ndim + slices_a[axis] = slice(0, min_len) + slices_b[axis] = slice(0, min_len) + + # Compare overlapping region + torch.testing.assert_close(a[tuple(slices_a)], b[tuple(slices_b)], **kwargs) + + # Verify trailing padding in the longer tensor + if len_a > len_b: + slices_trailing = [slice(None)] * a.ndim + slices_trailing[axis] = slice(min_len, len_a) + trailing = a[tuple(slices_trailing)] + expected_pad = torch.full_like(trailing, pad_val) + torch.testing.assert_close( + trailing, + expected_pad, + msg=lambda m: f"Trailing elements in tensor 'a' are not equal to pad_val={pad_val}: {m}", + ) + elif len_b > len_a: + slices_trailing = [slice(None)] * b.ndim + slices_trailing[axis] = slice(min_len, len_b) + trailing = b[tuple(slices_trailing)] + expected_pad = torch.full_like(trailing, pad_val) + torch.testing.assert_close( + trailing, + expected_pad, + msg=lambda m: f"Trailing elements in tensor 'b' are not equal to pad_val={pad_val}: {m}", + ) + + +def assert_all_identical(tensor: torch.Tensor, group: torch.distributed.ProcessGroup, *args, **kwargs) -> None: + """Verify that a tensor is identical across all processes in a distributed setup. + + Gathers the tensor from all processes in the specified process group and verifies + that they are all identical to the input tensor using assert_tensors_identical. + + Args: + tensor: The PyTorch tensor to verify across processes + group: The process group to gather tensors from + *args: Additional positional arguments passed to assert_tensors_identical + **kwargs: Additional keyword arguments passed to assert_tensors_identical + (e.g. check_grad, check_grad_fn, check_storage) + + Raises: + AssertionError: If any tensor from any process differs from the input tensor + """ + world_size = torch.distributed.get_world_size(group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, tensor, group=group) + for i in range(world_size): + assert_tensors_identical(tensor_list[i], tensor, *args, **kwargs) + + +def save_gradients(model: torch.nn.Module, detach_host: Optional[bool] = False) -> dict[str, torch.Tensor | None]: + """Save gradients of a model's parameters into a dictionary: parameter_name -> gradient.""" + grad_dict = {name: param.grad if param.grad is not None else None for name, param in model.named_parameters()} + if detach_host: + for name, grad in grad_dict.items(): + if grad is not None: + grad_dict[name] = grad.detach().cpu() + return grad_dict + + +def try_assert_and_collect(assertion_func: Callable, *args, error_name: str, errors_list: list[str], **kwargs) -> None: + """ + Try to run an assertion function and collect any errors. + + Args: + assertion_func: The assertion function to call (e.g., torch.testing.assert_close) + *args: Positional arguments to pass to the assertion function + error_name: Name to use in the error message + errors_list: List to collect error messages + **kwargs: Keyword arguments to pass to the assertion function + """ + try: + assertion_func(*args, **kwargs) + except AssertionError as e: + error_lines = str(e).strip().split("\n") + last_three_lines = "\n".join(error_lines[-3:]) if len(error_lines) >= 3 else str(e) + errors_list.append(f"{error_name} assertion failed: {last_three_lines}") + + +def repad_tensor(tensor: Tensor, pad_mask: Tensor, dim: int = 1) -> Tensor: + """Reinsert padding into a tensor based on a padding mask. + + Args: + tensor: Tensor without padding, shape [..., N_items_non_padded, ...] + pad_mask: Boolean mask indicating valid (non-padding) positions, shape [N_items_total] + dim: Dimension along which to reinsert padding + + Returns: + Padded tensor with shape [..., N_items_total, ...] + """ + assert pad_mask.ndim == 1 + + # Get total length including padding positions + N_total = len(pad_mask) + + # Get shape of padded tensor + padded_shape = list(tensor.shape) + padded_shape[dim] = N_total + + # Create padded tensor filled with zeros + padded_tensor = torch.zeros(padded_shape, device=tensor.device, dtype=tensor.dtype) + + # Get indices of valid (non-padding) positions + valid_indices = torch.nonzero(pad_mask).squeeze() + + # Create indexing tuples for dynamic slicing + idx_specs = [slice(None)] * len(padded_shape) + idx_specs[dim] = valid_indices + + # Assign the non-padded values to their correct positions in the padded tensor + padded_tensor[tuple(idx_specs)] = tensor + return padded_tensor + + +def all_gather_tensors_along_dim(tensor: Tensor, group: dist.ProcessGroup, dim: int = -1) -> Tensor: + """All gather a tensor by concatenating along a specified dimension. + + Args: + tensor (torch.Tensor): Tensor to all gather with a shape of (..., size, ...). + group (dist.ProcessGroup): Process group to all gather on. + dim (int): Dimension to concatenate along. + + Returns: + torch.Tensor: All gathered tensor; shape = (..., size * world_size, ...) + + """ + if tensor is None: + return None + + if tensor.requires_grad: + raise ValueError("all_gather_tensors_along_dim breaks gradient tracking and sent tensor requires grad") + + size = dist.get_world_size(group) + recv = [torch.empty_like(tensor) for _ in range(size)] + dist.all_gather(recv, tensor, group=group) + output = torch.cat(recv, dim=dim) + return output + + +def all_gather_pair_repr_along_dims( + tensor: Tensor, + cp_group: dist.ProcessGroup, + axis_0_size: int, + dim0: int, + dim1: int, +) -> Tensor: + """All gather pair representation along two dimensions. + + Args: + tensor: Local pair representation tensor + cp_group: Process group for communication + axis_0_size: Size of the first dimension of the pair representation + dim0: First dimension of the pair representation + dim1: Second dimension of the pair representation + + Returns: + torch.Tensor: All gathered tensor; shape = (..., axis_0_size * world_size, axis_0_size * world_size, ...) + """ + if tensor is None: + return None + + if tensor.requires_grad: + raise ValueError("all_gather_pair_repr_along_dims breaks gradient tracking and sent tensor requires grad") + + tensor_list = [torch.empty_like(tensor) for _ in range(cp_group.size())] + dist.all_gather(tensor_list, tensor, group=cp_group) + + tensor_list = [ + torch.cat(tensor_list[i * axis_0_size : (i + 1) * axis_0_size], dim=dim1) for i in range(axis_0_size) + ] + return torch.cat(tensor_list, dim=dim0) + + +def seed_by_rank(rank: int, seed: int = 42) -> None: + """Set random seeds based on process rank to ensure reproducible but different randomness per rank. + + This function ensures that each process in a distributed setting has its own + deterministic random state, which is important for reproducible tests while + maintaining appropriate randomness across different ranks. + + Args: + rank: The process rank to use for seeding + seed: Base seed value to which the rank is added (default: 42) + """ + seed = rank + seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def skip_if_cuda_not_avail_or_device_count_less_than_word_size(device_type: str, world_size: int) -> None: + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + +def spawn_multiprocessing(fn: Callable[[int, ...], None], world_size: int, *args) -> None: + """Spawn multiple processes using torch.multiprocessing for distributed testing. + + This function provides a convenient wrapper around torch.multiprocessing.spawn() + with the spawn start method, which is commonly used for distributed PyTorch testing + to ensure clean process isolation. + + Args: + fn: The function to execute in each spawned process. The function must accept + the process rank as its first argument, followed by any additional arguments + passed via *args. Signature should be: fn(rank: int, *args) -> None + world_size: Number of processes to spawn (typically equal to the number of GPUs + or the desired degree of parallelism) + *args: Additional positional arguments to pass to the spawned function + + Example: + ```python + def test_distributed_function(rank: int, tensor_size: int, device_prefix: str): + device = f"{device_prefix}:{rank}" + tensor = torch.randn(tensor_size, device=device) + # ... distributed testing logic ... + + # Spawn 4 processes for testing + spawn_multiprocessing(test_distributed_function, 4, 1024, "cuda") + ``` + + Note: + - The spawn start method creates completely isolated processes, which is safer + for testing but has more overhead than fork on Unix systems + - Each spawned process will receive a unique rank (0 to world_size-1) as the + first argument to the provided function + - This function blocks until all spawned processes complete + """ + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=fn, + args=args, + nprocs=world_size, + join=True, + ) + + +def assert_close_statistics( + x: Tensor, + x_ref: Tensor, + *args, + mean_threshold: Optional[float] = None, + median_threshold: Optional[float] = None, + **kwargs, +) -> None: + # shortcircuit for empty tensors + if x.numel() == 0: + return + + diff = (x - x_ref).abs().cpu() + mean = diff.mean().item() + median = diff.median().item() + maximum = diff.max().item() + + std_diff = diff.std().item() + min_diff = diff.min().item() + + msg = f"\nMean: {mean:.2e}, Median: {median:.2e}, Max: {maximum:.2e}, Std: {std_diff:.2e}, Min: {min_diff:.2e}" + + if mean_threshold is not None: + if mean > mean_threshold: + raise AssertionError(f"Mean diff: {mean:.2e} is greater than {mean_threshold:.2e}; {msg}") + return + + if median_threshold is not None: + if median > median_threshold: + raise AssertionError(f"Median diff: {median:.2e} is greater than {median_threshold:.2e}; {msg}") + return + + # fall back to torch.testing.assert_close + try: + torch.testing.assert_close(x, x_ref, *args, **kwargs) + except AssertionError as e: + raise AssertionError(e.args[0] + msg) + + +def assert_absolute_or_relative_close( + x: Tensor, + x_ref: Tensor, + atol: float, + rtol: float, +) -> None: + if x.numel() == 0: + assert x_ref.numel() == 0, "x_ref is not empty but x is empty" + return + + abs_diff = (x - x_ref).abs() + rel_diff = abs_diff / torch.max(x_ref.abs(), x.abs()) + + abs_pass = abs_diff <= atol + rel_pass = rel_diff <= rtol + + either_pass = abs_pass | rel_pass + if not either_pass.all(): + failed_abs_diffs = abs_diff[~either_pass] + failed_rel_diffs = rel_diff[~either_pass] + raise AssertionError( + f"Absolute diff: {failed_abs_diffs.max():.2e} is greater than {atol:.2e} or relative diff: {failed_rel_diffs.max():.2e} is greater than {rtol:.2e}\n" + f"Mean abs diff: {failed_abs_diffs.mean():.2e}, Median abs diff: {failed_abs_diffs.median():.2e}, Max abs diff: {failed_abs_diffs.max():.2e}\n" + f"Mean rel diff: {failed_rel_diffs.mean():.2e}, Median rel diff: {failed_rel_diffs.median():.2e}, Max rel diff: {failed_rel_diffs.max():.2e}" + ) + + +def hist_diff_log10_bins( + actual: torch.Tensor, + expected: torch.Tensor, + max_diff: float = 100.0, + bin_edges: torch.Tensor | None = None, + **kwargs_histogram, +): + """ + Compute histogram of absolute differences between two tensors using logarithmic bins. + + This function is useful for analyzing numerical precision differences between different + implementations (e.g., distributed vs single-device) by creating histograms of error + magnitudes on a logarithmic scale. + + Args: + actual (torch.Tensor): The actual tensor values to compare. + expected (torch.Tensor): The expected tensor values to compare against. + max_diff (float): The upper bound of the histogram bins + bin_edges (torch.Tensor | None): The edges of the histogram bins. + If None, the bins are created automatically based on max_diff and the dtype resolution. + **kwargs_histogram: Additional keyword arguments passed to torch.histogram. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - hist: Histogram counts for each bin + - bin_edges: The edges of the logarithmic bins used for the histogram + + Raises: + ValueError: If actual or expected are not tensors, or if they are not floating point tensors. + + Note: + - The histogram bins are created on a logarithmic scale from the dtype resolution + divided by 10 up to max_diff + - The function automatically handles CUDA tensors by moving data to CPU for + histogram computation and then back to the original device + - The dtype used for computation is determined by the less precise of the two input tensors + """ + if not isinstance(actual, torch.Tensor): + raise ValueError("actual must be a tensor") + + if not isinstance(expected, torch.Tensor): + raise ValueError("expected must be a tensor") + + if not actual.is_floating_point(): + raise ValueError("actual must be a floating point tensor") + + if not expected.is_floating_point(): + raise ValueError("expected must be a floating point tensor") + + if torch.finfo(actual.dtype).resolution > torch.finfo(expected.dtype).resolution: + # actual is less precise than expected + dtype = actual.dtype + else: + # actual is as precise as expected + dtype = expected.dtype + + if bin_edges is None: + # max_diff is ignored if bin_edges is provided + base = 10 + # the lower bound of the histogram to be 1 decimal place lower than the resolution + min_diff = torch.finfo(dtype).resolution / base + min_bin = round(math.log10(min_diff)) + max_bin = max(round(math.log10(max_diff)), min_bin) + n_bins = max_bin - min_bin + 1 + bin_edges = torch.logspace(start=min_bin, end=max_bin, steps=n_bins, base=base, dtype=dtype) + # add 0 to the left-most bin + bin_edges = torch.cat([torch.tensor([0.0], dtype=dtype), bin_edges]) + + diff = (actual - expected).to(dtype).abs() + # torch.histogram doesn't work with CUDA tensor of bins so we need to histogram on CPU + # and move the result back to the original device. See + # https://github.com/pytorch/pytorch/issues/69519 for details + hist, _ = torch.histogram( + diff.cpu(), bins=bin_edges.to(dtype=dtype, device=torch.device("cpu")), **kwargs_histogram + ) + hist = hist.to(diff.device) + + return hist, bin_edges + + +def pretty_prints_hist( + hist: torch.Tensor, bin_edges: torch.Tensor, do_cumsum: bool = False, convert_to_percentage: bool = True +) -> str: + """ + Pretty print a histogram in a tabular format. Return the string as if printed to stdout. + + This function prints a formatted table showing histogram data with bin ranges + and corresponding counts/percentages. It's useful for visualizing error distributions + and numerical precision analysis. + + Args: + hist (torch.Tensor): Histogram counts for each bin. + bin_edges (torch.Tensor): The edges of the histogram bins. Should have one more + element than hist. + do_cumsum (bool, optional): If True, display cumulative sum of histogram values. + Defaults to True. + convert_to_percentage (bool, optional): If True, convert histogram counts to + percentages of total. Defaults to True. + + Returns: + str: The formatted string as if printed to stdout. + + Raises: + ValueError: If hist or bin_edges are not tensors. + + Example: + The output format shows bin ranges in the first two rows and histogram + values in the last row: + + | 1.0e-07 | 1.0e-06 | 1.0e-05 | ... + | - 1.0e-06| - 1.0e-05| - 1.0e-04| ... + +----------+----------+----------+... + | 2.5e+01% | 4.5e+01% | 7.8e+01% | ... + + Note: + - The function prints directly to stdout using print statements + - Scientific notation is used for both bin edges and histogram values + - Bin ranges are shown as "lower_bound - upper_bound" format + - Values are displayed as percentages when convert_to_percentage=True + """ + if not isinstance(hist, torch.Tensor): + raise ValueError("hist must be a tensor") + + if not isinstance(bin_edges, torch.Tensor): + raise ValueError("bin_edges must be a tensor") + + ans = "\t| " + " | ".join([f"{b.item():8.0e}" for b in bin_edges[:-1]]) + " |" + ans += "\n" + ans += "\t| " + " | ".join([f"- {b.item():6.0e}" for b in bin_edges[1:]]) + " |" + ans += "\n" + ans += "\t" + "+----------" * (bin_edges.numel() - 1) + "+" + ans += "\n" + if convert_to_percentage: + total = hist.sum() + hist = hist / total + if do_cumsum: + hist = hist.cumsum(dim=0) + + if convert_to_percentage: + ans += "\t| " + " | ".join([f"{(p * 100.0):7.1e}%" for p in hist]) + " |" + else: + ans += "\t| " + " | ".join([f"{p:8.1e}" for p in hist]) + " |" + + return ans + + +def get_param_by_key(module: torch.nn.Module, key_state_dict: str) -> Any: + """ + Retrieve a parameter or attribute from a PyTorch module using dot notation. + + This function traverses a module's nested structure using a string key with dot + notation to access deeply nested parameters or attributes. It's useful for + programmatically accessing specific layers, weights, or other attributes in + complex neural network architectures. + + Args: + module (torch.nn.Module): The PyTorch module to search within. + key_state_dict (str): The dot-separated path to the desired parameter or + attribute (e.g., "encoder.layer.0.attention.query.weight"). + + Returns: + Any: The parameter, tensor, or attribute found at the specified path. + + Raises: + ValueError: If module is not a torch.nn.Module, if key_state_dict is not + a string, or if key_state_dict is empty. + AttributeError: If the specified path does not exist in the module. + + Example: + >>> import torch.nn as nn + >>> model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1)) + >>> weight = get_param_by_key(model, "0.weight") + >>> bias = get_param_by_key(model, "2.bias") + + Note: + - The function uses Python's reduce() with getattr() to traverse the path + - Each part of the dot-separated key must be a valid attribute name + - The function can access any attribute, not just parameters (weights/biases) + """ + if not isinstance(module, torch.nn.Module): + raise ValueError(f"module is not a torch.nn.Module: {type(module)}") + + if not isinstance(key_state_dict, str): + raise ValueError(f"key_state_dict is not a string: {type(key_state_dict)}") + + if key_state_dict == "": + raise ValueError("key_state_dict is empty") + + names = key_state_dict.split(sep=".") + return reduce(getattr, names, module) + + +def assert_no_percentile_upshift( + result: torch.Tensor, + expected: torch.Tensor, + alternative: torch.Tensor, + perc: OrderedDict[float, tuple[float, float]] | None = None, + names_input: tuple[str, str, str] | None = None, +): + """ + Assert that result tensor accuracy doesn't significantly degrade compared to alternative. + + This function validates that a computation result (e.g., distributed computation) + maintains comparable numerical accuracy to an alternative implementation (e.g., + single-precision reference) when both are compared against a high-precision expected + value. It uses percentile-based analysis to detect accuracy degradation. + + Args: + result (torch.Tensor): The tensor result to validate (e.g., from distributed computation). + expected (torch.Tensor): The high-precision reference tensor (ground truth). + alternative (torch.Tensor): The alternative implementation result to compare against + (e.g., single-precision baseline). + perc (OrderedDict[float, tuple[float, float]] | None, optional): Dictionary mapping percentile + values to (atol, rtol) tolerance tuples to be used by torch.testing.assert_close() to + check the consistency of the input percentiles (keys of the dict) of errors between the + result and alternative. If None, defaults to tolerances for 50th, 75th, + and 95th percentiles with assert_close()'s default (atol, rtol) + names_input (tuple[str, str, str] | None, optional): Names of the input tensors to be used in the error message. + If None, no names will be included in the error message. + + Raises: + AssertionError: If the result shows significant upward shift in error percentiles + compared to the alternative, indicating accuracy degradation. + + Note: + - The function allows downward shifts (better accuracy) + - Error histograms are generated using logarithmic bins for detailed analysis + - Assertion errors include histogram visualizations for debugging + - This is particularly useful for validating distributed computing implementations + maintain numerical stability compared to single-device baselines + + Example: + >>> # Validate distributed computation accuracy + >>> assert_no_percentile_upshift( + ... result=distributed_output, + ... expected=fp64_reference, + ... alternative=fp32_reference + ... ) + """ + if perc is None: + perc = OrderedDict({0.25: (None, None), 0.5: (None, None), 0.75: (None, None), 0.95: (None, None)}) + + diff_abs_result = (result - expected).abs() + diff_abs_alternative = (alternative - expected).abs() + + # torch.quantile requires input dtype to be at least fp32 + diff_abs_result = diff_abs_result.to(dtype=torch.promote_types(torch.float32, diff_abs_result.dtype)) + diff_abs_alternative = diff_abs_alternative.to(dtype=torch.promote_types(torch.float32, diff_abs_alternative.dtype)) + + percentages = list(perc.keys()) + quantiles_diff_abs_result = torch.quantile( + diff_abs_result, torch.tensor(percentages, device=result.device, dtype=diff_abs_result.dtype) + ) + quantiles_diff_abs_alternative = torch.quantile( + diff_abs_alternative, torch.tensor(percentages, device=alternative.device, dtype=diff_abs_alternative.dtype) + ) + + max_diff = max(diff_abs_result.max().item(), diff_abs_alternative.max().item()) + if max_diff == 0: # shortcircuit for zero-diff tensors to avoid math.log domain error + return + + hist_result, bin_edges = hist_diff_log10_bins(result, expected, max_diff=max_diff) + hist_alternative, _ = hist_diff_log10_bins(alternative, expected, bin_edges=bin_edges) + + str_hist_result = pretty_prints_hist(hist_result, bin_edges) + str_hist_alternative = pretty_prints_hist(hist_alternative, bin_edges) + + error_msg = "\n" + if names_input is not None: + error_msg += f"input: {names_input[0]}\n" + error_msg += f"expected: {names_input[1]}\n" + error_msg += f"alternative: {names_input[2]}\n" + error_msg += "\n" + error_msg += f"hist(result, ref):\n\n{str_hist_result}\n" + error_msg += "\n" + error_msg += f"hist(alternative, ref):\n\n{str_hist_alternative}\n" + + for i, (p, (atol, rtol)) in enumerate(perc.items()): + if quantiles_diff_abs_result[i] < quantiles_diff_abs_alternative[i]: + # Assumption: we don't care down-shifting of because they only implies + # overall higher consistency of the result than the alternative + continue + torch.testing.assert_close( + quantiles_diff_abs_result[i], + quantiles_diff_abs_alternative[i], + atol=atol, + rtol=rtol, + msg=lambda m: ( + f""" + shift at {p * 100} percentile:\n{m}\n + {error_msg} + """ + ), + ) + + +def init_tensors_uniform(tensors: list[torch.Tensor], low: float = 0.0, high: float = 1.0) -> None: + """Initialize a list of tensors with uniform distribution in-place. + + Args: + tensors: List of tensors to initialize + low: Lower bound for uniform distribution (default: 0.0) + high: Upper bound for uniform distribution (default: 1.0) + """ + with torch.no_grad(): + for tensor in tensors: + tensor.uniform_(low, high) + + +def init_tensors_normal(tensors: list[torch.Tensor], *args, **kwargs) -> None: + """Initialize a list of tensors with values drawn from a normal distribution. + + Args: + tensors: List of tensors to initialize in-place. + *args: Positional arguments forwarded to tensor.normal_(). + **kwargs: Keyword arguments forwarded to tensor.normal_(). + Common kwargs: mean (default 0), std (default 1). + """ + with torch.no_grad(): + for tensor in tensors: + tensor.normal_(*args, **kwargs) + + +def init_module_params_uniform(module: torch.nn.Module, low: float = 0.0, high: float = 1.0) -> None: + """Initialize all named parameters of a module with uniform distribution in-place. + + Args: + module: PyTorch module whose parameters to initialize + low: Lower bound for uniform distribution (default: 0.0) + high: Upper bound for uniform distribution (default: 1.0) + """ + with torch.no_grad(): + for name, param in module.named_parameters(): + param.uniform_(low, high) + + +def init_module_params_glorot(module: torch.nn.Module, gain: float = 1.0) -> None: + """Initialize parameters with Xavier/Glorot uniform distribution in-place. + + Weight tensors (dim >= 2) use ``xavier_uniform_`` scaled by *gain*. + Bias / 1-D tensors use ``uniform_(-b, b)`` with ``b = 1/sqrt(fan_out)``. + + Args: + module: PyTorch module whose parameters to initialize. + gain: Multiplicative scaling factor for ``xavier_uniform_`` + (default: 1.0). + """ + with torch.no_grad(): + for _name, param in module.named_parameters(): + if param.dim() >= 2: + torch.nn.init.xavier_uniform_(param, gain=gain) + else: + bound = 1.0 / math.sqrt(param.shape[0]) if param.shape[0] > 0 else 0.01 + param.uniform_(-bound, bound) + + +def set_dtype_specific_inf_values(module, dtype: torch.dtype) -> None: + """Set dtype-specific inf values for attention modules named + "tri_att_start", "tri_att_end", and "pair_weighted_averaging". + + Parameters + ---------- + module : torch.nn.Module + The module containing attention modules (e.g., MSAModule, MSALayer, + PairformerLayer, PairformerModule, tri_att_start, tri_att_end, + pair_weighted_averaging, etc.) + dtype : torch.dtype + The data type to determine the appropriate inf value + """ + dtype_to_inf = {torch.float32: 1e9, torch.float64: 1e18} + inf_value = dtype_to_inf.get(dtype, 1e9) + + # Handle MSAModule (contains multiple MSALayers) + if hasattr(module, "layers"): + for layer in module.layers: + # Handle both checkpoint-wrapped and non-wrapped layers + l = layer._checkpoint_wrapped_module if hasattr(layer, "_checkpoint_wrapped_module") else layer + _set_layer_inf_values(l, inf_value) + # Handle single MSALayer + else: + _set_layer_inf_values(module, inf_value) + + +def _set_layer_inf_values(layer, inf_value: float) -> None: + """Helper function to set inf values on a single layer. + + Parameters + ---------- + layer : torch.nn.Module + The layer containing attention modules (tri_att_start, tri_att_end, + pair_weighted_averaging, etc.) + inf_value : float + The inf value to set + """ + if hasattr(layer, "tri_att_start"): + layer.tri_att_start.inf = inf_value + if hasattr(layer, "tri_att_end"): + layer.tri_att_end.inf = inf_value + if hasattr(layer, "pair_weighted_averaging"): + layer.pair_weighted_averaging.inf = inf_value + if hasattr(layer, "attention"): + layer.attention.inf = inf_value + + +class SetModuleInfValues: + """A callable class that automatically sets dtype-specific infinity values for attention modules. + + This class is designed to be used as a function that can be applied to PyTorch modules + to automatically configure their infinity values based on their underlying data types. + It's particularly useful for attention mechanisms where infinity values are used for + masking operations and need to be adjusted based on the precision of the computations. + + The class supports several attention module types and automatically detects their + data types by examining specific linear layers within each module. It then sets + appropriate infinity values that are compatible with the detected dtype to avoid + numerical overflow or underflow issues. + + Attributes: + dtype_to_inf (dict): Mapping from PyTorch data types to appropriate infinity values. + - torch.float32: 1e9 (to avoid overflow in single precision) + - torch.float64: 1e18 (higher precision allows larger values) + module_types_to_layer_for_dtype (dict): Mapping from module types to the name + of the layer used to determine the module's data type. This is used to + automatically detect the dtype by examining the weight tensor of a specific + linear layer within each module type. + + Supported Module Types: + - SerialAttentionPairBias: Uses 'proj_q' layer for dtype detection + - SerialPairWeightedAveraging: Uses 'proj_m' layer for dtype detection + - SerialTriangleAttention: Uses 'linear' layer for dtype detection + - SerialTriangleAttentionEndingNode: Uses 'linear' layer for dtype detection + + Usage: + >>> inf_setter = SetModuleInfValues() + >>> model.apply(inf_setter) # Apply to all modules in a model + >>> # Or apply to a specific module + >>> inf_setter(attention_module) + + Raises: + AttributeError: If a supported module type doesn't have the expected 'inf' attribute. + TypeError: If a module contains an unsupported data type. + + Note: + This class is typically used in testing scenarios where consistent infinity + values across different data types are required for numerical stability and + reproducible results. + """ + + def __init__(self): + self.dtype_to_inf = {torch.float32: 1e9, torch.float64: 1e18} + # find each module type's layer to determine its inherent dtype + self.module_types_to_layer_for_dtype = { + SerialAttentionPairBias: "proj_q", + SerialPairWeightedAveraging: "proj_m", + SerialTriangleAttention: "linear", + SerialTriangleAttentionEndingNode: "linear", + } + + def __call__(self, module: torch.nn.Module) -> None: + """Set appropriate infinity value for the given module based on its data type. + + This method examines the module type and, if it's a supported attention module, + determines its data type by inspecting a specific linear layer. It then sets + the module's 'inf' attribute to an appropriate value based on the detected dtype. + + Args: + module (torch.nn.Module): The PyTorch module to process. If the module + type is not in the supported list, the method returns early without + making any changes. + + Returns: + None: This method modifies the module in-place by setting its 'inf' attribute. + + Raises: + AttributeError: If a supported module type doesn't have the expected 'inf' + attribute that should be set. + TypeError: If the module's detected data type is not supported (not in + dtype_to_inf mapping). + + Note: + This method is designed to be used with PyTorch's Module.apply() method + for automatic application across an entire model hierarchy. + """ + type_module = type(module) + if type_module not in self.module_types_to_layer_for_dtype: + return + + if not hasattr(module, "inf"): + raise AttributeError(f"Module {type_module.__name__} should but does not have an 'inf' attribute") + + dtype_module = getattr(module, self.module_types_to_layer_for_dtype[type_module]).weight.dtype + if dtype_module not in self.dtype_to_inf: + raise TypeError(f"Unsupported dtype {dtype_module} found in module {type_module.__name__}") + inf_value = self.dtype_to_inf[dtype_module] + setattr(module, "inf", inf_value) + + +class FixBoltzMultiplicityBug: + """A callable class that fixes the Boltz multiplicity bug by setting reorder_pair_repr_multiplex=True. + + This class is designed to be used with PyTorch's Module.apply() method to recursively + traverse a model and set the `reorder_pair_repr_multiplex` attribute to True on + `AtomTransformer` and `AttentionPairBias` modules. + + This fixes the bug described in Boltz github commit 4fa0d0a0c3090ca09e71073fdd58e4108c517382, + where the pair representation multiplicity was applied in the wrong order, causing + incorrect behavior when multiplicity > 1. + + Supported Module Types: + - AtomTransformer: Sets reorder_pair_repr_multiplex = True + - AttentionPairBias: Sets reorder_pair_repr_multiplex = True + + Usage: + >>> bug_fixer = FixBoltzMultiplicityBug() + >>> model.apply(bug_fixer) # Apply to all modules in a model + + Note: + This class is typically used in testing scenarios to ensure the multiplicity + bug is fixed for numerical correctness validation. + """ + + def __init__(self): + # Module types that need the fix + self.module_types_to_fix = (SerialAtomTransformer, SerialAttentionPairBias) + + def __call__(self, module: torch.nn.Module) -> None: + """Set reorder_pair_repr_multiplex=True for supported module types. + + Args: + module (torch.nn.Module): The PyTorch module to process. If the module + type is not in the supported list, the method returns early without + making any changes. + + Returns: + None: This method modifies the module in-place. + """ + if isinstance(module, self.module_types_to_fix): + if hasattr(module, "reorder_pair_repr_multiplex"): + if isinstance(module, SerialAttentionPairBias): + # The token variant of AttentionPairBias can't apply the fix + # because otherwise the pairbias z will have mismatching shape + # since DiffusionTransformer (token level) never apply the multiplicity + # to "z" unlike AtomTransformer (atom level) does. + module.reorder_pair_repr_multiplex = module.use_window_batching + else: + module.reorder_pair_repr_multiplex = True + + +def get_window_batch_key_indices(n_atoms_no_pad: int, W: int, H: int) -> torch.Tensor: + """Generate key indices for window-based attention batching. + + Creates an indexing matrix that maps global atom positions to key positions + within each window for window-based attention. The function pads the sequence + to the next multiple of W and generates indices for H key positions per window. + + Args: + n_atoms_no_pad (int): Number of atoms without padding. + W (int): Window size (number of queries per window). + H (int): Number of key positions per window. + + Returns: + torch.Tensor: Index tensor of shape (K, H) where K = max_atoms // W. + Each row contains the key indices for one window. Padded positions + are represented by 0. + """ + if n_atoms_no_pad % W == 0: + max_atoms = n_atoms_no_pad + else: + # pad to the next multiple of W + max_atoms = ((n_atoms_no_pad // W) + 1) * W + # construct pair mask through indexing matrices + # TODO construct pair mask directly from AF3 appendix + index = torch.arange(1, max_atoms + 1) + index[n_atoms_no_pad:] = 0 + index = index.unsqueeze(0) + + K = max_atoms // W + keys_indexing_matrix = get_indexing_matrix(K, W, H, index.device) + to_keys = partial(single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H) + index_keys = to_keys(index.unsqueeze(-1).float()).view(K, H).long() + return index_keys + + +def _pair_masked_global_to_window_batch( + pair_masked_global: torch.Tensor, n_atoms_no_pad: int, W: int = 32, H: int = 128 +) -> torch.Tensor: + """Convert a global pair representation to window-batched format. + + Transforms a global pairwise interaction matrix into a window-batched format + suitable for efficient window-based attention computation. The function uses + sparse matrix operations to efficiently handle the transformation while + accounting for padding and window alignment. + + The transformation involves: + 1. Computing key indices for each window + 2. Converting to sparse CSR format for efficient manipulation + 3. Adjusting column indices to account for window-specific padding + 4. Reconstructing as a dense tensor in window-batched format + + Args: + pair_masked_global (torch.Tensor): Input global pair representation. + Shape can be (N, N) for 2D or (N, N, D) for 3D with embedding dimension. + Must be square in the first two dimensions. Already masked with zeros + for invalid padding positions. + n_atoms_no_pad (int): Number of atoms without padding. + W (int, optional): Window size. Defaults to 32. + H (int, optional): Number of keys per window. Defaults to 128. + + Returns: + torch.Tensor: Window-batched representation of shape (K, W, H) for 2D input + or (K, W, H, D) for 3D input, where: + - K = number of windows + - W = window size + - H = number of keys per window + - D = embedding dimension (if present) + + Raises: + ValueError: If pair_masked_global is not square in the first two dimensions. + AssertionError: If pair_masked_global is not 2D or 3D. + + Note: + For batch processing, call this function separately for each batch element. + """ + # Assumption 0: input pair_masked_global is already masked with invalid padding represented by zeros + # valid values of zeros are safe to use because this function doesn't rely on the input element values. + ids_keys_per_window = get_window_batch_key_indices(n_atoms_no_pad, W, H) + + # The limitation of using CSR matrix is that each element along the batch + # dimension must have the same nnz. This won't be useful for converting mask + # in a batch but we must loop over the batch dimension and call each entry + if pair_masked_global.shape[0] != pair_masked_global.shape[1]: + raise ValueError("pair_masked_global must be square") + assert pair_masked_global.ndim == 2 or pair_masked_global.ndim == 3, ( + "pair_masked_global must be 2D with potential one trailing embedding dimension. " + "For batch dimension, loop over the batch dimension and call this function for each entry" + ) + # count number of leading zeros per window, which can only be >= 0. + # Left padding > 0 means that we need to increase the column index + # while == 0 means we need to reset the column offset so that + # the corresponding rows are left-aligned in the resulting global matrix. + # This is to match the window-batching behavior of the attention. + n_left_padding_per_window = (ids_keys_per_window.cumsum(dim=1) == 0).sum(dim=1) + n_windows_no_pad = ids_keys_per_window.shape[0] + # sparse map generation + masked_csr = pair_masked_global.to_sparse_csr(dense_dim=1 if pair_masked_global.ndim == 3 else None) + crow_ids = masked_csr.crow_indices() + col_ids_new = masked_csr.col_indices().detach().clone() + for i_window in range(n_windows_no_pad): + i_rows_begin = i_window * W + i_rows_end = min(i_rows_begin + W, crow_ids.shape[0] - 1) + inz_begin = crow_ids[i_rows_begin] + inz_end = crow_ids[i_rows_end] + n_left_padding_this_window = n_left_padding_per_window[i_window] + if n_left_padding_this_window > 0: + col_ids_new[inz_begin:inz_end] += n_left_padding_this_window + else: + # equivalent to: + # col_id_min = ids_keys_per_window[i_window].min().item() + col_id_min = col_ids_new[inz_begin:inz_end].min().item() + col_ids_new[inz_begin:inz_end] -= col_id_min + # dim 1 is always the number of atoms regardless of ndim == 2 or 3 + n_atoms_padded = pair_masked_global.shape[1] + # pad to the next multiple of W towards the end of both dimensions + n_windows_padded = (n_atoms_padded + W - 1) // W + target_length = n_windows_padded * W + 1 + current_length = crow_ids.shape[0] + if target_length > current_length: + padding_length = target_length - current_length + last_value = crow_ids[-1] + padding = last_value.repeat(padding_length) + crow_ids_new = torch.cat([crow_ids, padding]) + else: + crow_ids_new = crow_ids + masked_csr_new = torch.sparse_csr_tensor( + crow_ids_new, + col_ids_new, + masked_csr.values(), + size=(n_windows_padded * W, H, pair_masked_global.shape[-1]) + if pair_masked_global.ndim == 3 + else (n_windows_padded * W, H), + dtype=masked_csr.dtype, + device=masked_csr.device, + ) + masked_new = masked_csr_new.to_dense().unflatten(0, (n_windows_padded, W)) + return masked_new + + +@torch.no_grad() +def pair_global_to_window_batch( + pair_repr_global: torch.Tensor, + n_atoms_no_pads: torch.Tensor, + pair_mask_global: torch.Tensor | None = None, + W: int = 32, + H: int = 128, +) -> torch.Tensor: + """Convert batched global pair representations to window-batched format. + + Applies masking to global pair representations and converts each batch element + to window-batched format. This is a convenience wrapper around + pair_masked_global_to_window_batch that handles batching and masking. + + Args: + pair_repr_global (torch.Tensor): Global pair representations with shape + (B, N, N, D) where B is batch size, N is sequence length, and D is + embedding dimension. + n_atoms_no_pads (torch.Tensor): Number of atoms without padding for each batch element. + pair_mask_global (torch.Tensor | None, optional): Global pair mask with shape (B, N, N) or + (B, N, N, 1). Used to mask invalid positions. If None, no masking is applied. + Defaults to None. + W (int, optional): Window size. Defaults to 32. + H (int, optional): Number of keys per window. Defaults to 128. + + Returns: + torch.Tensor: Window-batched pair representations of shape (B, K, W, H, D) + where K = number of windows, W = window size, H = keys per window, + and D = embedding dimension. + + Note: + When pair_mask_global is provided, pair_repr_global and pair_mask_global + must be broadcastable for element-wise multiplication. + """ + size_batch = pair_repr_global.shape[0] + if n_atoms_no_pads.shape != (size_batch,): + raise ValueError(f"n_atoms_no_pads must be a 1D tensor of size {size_batch}") + if not (n_atoms_no_pads.dtype == torch.int32 or n_atoms_no_pads.dtype == torch.int64): + raise ValueError("n_atoms_no_pads must be an int32 or int64 tensor") + # the product must be broadcastable + if pair_mask_global is None: + ans_per_batch = [ + _pair_masked_global_to_window_batch(pair_repr_global[i], n_atoms_no_pads[i], W, H) + for i in range(size_batch) + ] + else: + pair_repr_global_masked = pair_repr_global * pair_mask_global + ans_per_batch = [ + _pair_masked_global_to_window_batch(pair_repr_global_masked[i], n_atoms_no_pads[i], W, H) + for i in range(size_batch) + ] + # the assumption here is that _pair_masked_global_to_window_batch will + # pad to the window-batched result to have the same number of windows for each batch element + # so we can always stack the results along the batch dimension + ans = torch.stack(ans_per_batch, dim=0) + return ans + + +def get_features( + tokenized: Tokenized, + window_batching: bool, + shard_dims: Optional[tuple[int, int]] = None, + selected_keys: Optional[list[str]] = None, +) -> dict[str, Tensor] | list[dict[str, Tensor]]: + """Get features from a tokenized object with key filtering. + + Args: + tokenized: Tokenized object + window_batching: Whether to use window batching + shard_dims: Shard dimensions to enable window batching. + selected_keys: Selected keys to return. If None, the following keys are returned: + ["atom_pad_mask", "atom_to_token", "pair_mask", "token_pad_mask"] + + Returns: + dict[str, Tensor] | list[dict[str, Tensor]]: Features of list of features if shard_dims is not None. + """ + max_atoms, max_tokens = get_num_atoms_tokens(tokenized) + if shard_dims is None: + max_atoms = None + max_tokens = None + max_seqs = 1 + else: + ring_size = shard_dims[0] + if window_batching: + pad_to_multiple_atoms = math.lcm(ring_size, 32) + else: + pad_to_multiple_atoms = ring_size + max_atoms = max_atoms + pad_to_multiple_atoms - max_atoms % pad_to_multiple_atoms + max_tokens = max_tokens + ring_size - max_tokens % ring_size + max_seqs = ring_size + + featurizer = BoltzFeaturizer() + feats = featurizer.process( + tokenized, + training=False, + augmentation=False, + pair_mask_mode=PairMaskMode.NONE if window_batching else PairMaskMode.SEQUENCE_LOCAL_ATTENTION, + max_atoms=max_atoms, + max_tokens=max_tokens, + max_seqs=max_seqs, + pad_to_max_seqs=True, + shard_dims=shard_dims, + ) + + if selected_keys is None: + selected_keys = ["atom_pad_mask", "atom_to_token", "pair_mask", "token_pad_mask"] + + if shard_dims is None: + return {k: v for k, v in feats.items() if k in selected_keys} + else: + return [{k: v for k, v in feats_shard.items() if k in selected_keys} for feats_shard in feats] + + +def get_to_keys(s_global: torch.Tensor, W: int = 32, H: int = 128) -> Callable[[torch.Tensor], torch.Tensor]: + """Get to keys function for window-based attention.""" + B, N, D = s_global.shape + # assume s_global.shape[1] has been padded to be a multiple of W + assert N % W == 0, "s_global.shape[1] must be a multiple of W" + K = N // W + indexing_matrix = get_indexing_matrix(K, W, H, s_global.device).to(s_global.dtype) + to_keys = partial(single_to_keys, indexing_matrix=indexing_matrix, W=W, H=H) + return to_keys + + +def create_msa_module_init_params(use_large_model: bool = False) -> dict[str, Any]: + """Create initialization parameters for MSAModule. + + Parameters + ---------- + use_large_model : bool + Whether to use large model parameters + + Returns + ------- + dict[str, Any] + MSAModule initialization parameters + """ + + # Get parameters from the whole set + boltz1_params = create_boltz1_model_init_params(use_large_model=use_large_model, use_window_batching=True) + + # Extract MSA-specific parameters + msa_args = boltz1_params["msa_args"].copy() + + # Calculate s_input_dim based on token_s + token_s = boltz1_params["token_s"] + msa_args["s_input_dim"] = ( + token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info) + ) # Input sequence dimension + + # Add token_z for compatibility + msa_args["token_z"] = boltz1_params["token_z"] + + return msa_args + + +def create_msa_module_init_params_v2(use_large_model: bool = False) -> dict[str, Any]: + """Create initialization parameters for Boltz-2 MSAModule (model.modules.trunkv2.MSAModule). + + Parameters + ---------- + use_large_model : bool + Whether to use large model parameters + + Returns + ------- + dict[str, Any] + MSAModule initialization parameters for Boltz-2 (token_s, msa_s, token_z, etc.) + """ + boltz1_params = create_boltz1_model_init_params(use_large_model=use_large_model, use_window_batching=True) + msa_args = boltz1_params["msa_args"].copy() + # Boltz-2 MSAModule uses token_s (single representation dim), not s_input_dim + msa_args["token_s"] = boltz1_params["token_s"] + msa_args["token_z"] = boltz1_params["token_z"] + msa_args["subsample_msa"] = False # CP does not support MSA subsampling + msa_args["num_subsampled_msa"] = 1024 + return msa_args + + +def create_pairformer_module_init_params(use_large_model: bool = False) -> dict[str, Any]: + """Create initialization parameters for PairformerModule. + + Parameters + ---------- + use_large_model : bool + Whether to use large model parameters + + Returns + ------- + dict[str, Any] + PairformerModule initialization parameters + """ + # Get parameters from the whole set + boltz1_params = create_boltz1_model_init_params(use_large_model=use_large_model, use_window_batching=True) + + # Extract Pairformer-specific parameters + pairformer_args = boltz1_params["pairformer_args"].copy() + + # Add required shared parameters + pairformer_args["token_s"] = boltz1_params["token_s"] + pairformer_args["token_z"] = boltz1_params["token_z"] + + return pairformer_args + + +def create_diffusion_module_init_params( + use_large_model: bool = False, use_window_batching: bool = True +) -> dict[str, Any]: + """Create initialization parameters for DiffusionModule. + + Parameters + ---------- + use_large_model : bool + Whether to use large model parameters + use_window_batching : bool + Whether to enable window batching + + Returns + ------- + dict[str, Any] + DiffusionModule initialization parameters + """ + # Get parameters from the whole set + boltz1_params = create_boltz1_model_init_params( + use_large_model=use_large_model, use_window_batching=use_window_batching + ) + + # Extract Diffusion-specific parameters + score_model_args = boltz1_params["score_model_args"].copy() + + # Add required shared parameters + params = { + "token_s": boltz1_params["token_s"], + "token_z": boltz1_params["token_z"], + "atom_s": boltz1_params["atom_s"], + "atom_z": boltz1_params["atom_z"], + "atoms_per_window_queries": boltz1_params["atoms_per_window_queries"], + "atoms_per_window_keys": boltz1_params["atoms_per_window_keys"], + "atom_feature_dim": boltz1_params["atom_feature_dim"], + **score_model_args, + } + + return params + + +def create_atom_diffusion_init_params( + use_large_model: bool = False, use_window_batching: bool = True +) -> dict[str, Any]: + """Create initialization parameters for AtomDiffusion. + + Parameters + ---------- + use_large_model : bool + Whether to use large model parameters + use_window_batching : bool + Whether to enable window batching + + Returns + ------- + dict[str, Any] + AtomDiffusion initialization parameters + """ + # Get parameters from the whole set + boltz1_params = create_boltz1_model_init_params( + use_large_model=use_large_model, use_window_batching=use_window_batching + ) + + # Extract AtomDiffusion-specific parameters + params = { + "score_model_args": boltz1_params["score_model_args"], + **boltz1_params["diffusion_process_args"], + "compile_score": False, # couldn't be set in the whole-model parameter but enforced here + "accumulate_token_repr": False, # couldn't be set in the whole-model parameter but enforced here + } + + return params + + +@dataclass(frozen=False) +class TrainingArgs: + recycling_steps: int + sampling_steps: int | None + diffusion_multiplicity: int + diffusion_samples: int + confidence_loss_weight: float + diffusion_loss_weight: float + distogram_loss_weight: float + + +def create_boltz1_model_init_params( + use_large_model: bool = False, use_window_batching: bool = True, activation_checkpointing: bool = True +) -> dict[str, Any]: + """Create initialization parameters for Boltz1 model. + + Parameters + ---------- + use_large_model : bool + Whether to use large model parameters + use_window_batching : bool + Whether to enable window batching + activation_checkpointing : bool + Whether to use activation checkpointing + + Returns + ------- + dict[str, Any] + Boltz1 model initialization parameters + """ + if BoltzDiffusionParams is not None: + diffusion_params = asdict(BoltzDiffusionParams()) + else: + diffusion_params = { + "gamma_0": 0.605, + "gamma_min": 1.107, + "noise_scale": 0.901, + "rho": 8, + "step_scale": 1.638, + "sigma_min": 0.0004, + "sigma_max": 160.0, + "sigma_data": 16.0, + "P_mean": -1.2, + "P_std": 1.5, + "coordinate_augmentation": True, + "alignment_reverse_diff": True, + "synchronize_sigmas": True, + "use_inference_model_cache": True, + } + + # can't set the following: + # "compile_score": False, + # "accumulate_token_repr": False, + # because model.py will set them explicitly + # then we would have duplicated keys in the input kwargs + # but the default settings of these two should work for testing + diffusion_params.update( + { + "coordinate_augmentation": False, # Turn off for deterministic testing + "alignment_reverse_diff": True, + "synchronize_sigmas": False, + "use_inference_model_cache": False, + "num_sampling_steps": None, # only relevant to sample() test but disabled otherwise to prevent accidentally enabling irrelevant code + } + ) + + if use_large_model: + atom_s = 128 + atom_z = 64 + token_s = 384 + token_z = 128 + # Distogram module parameters + num_bins = 64 + # InputEmbedder parameters + atom_encoder_depth = 2 + atom_encoder_heads = 4 + # MSAModule parameters + msa_s = 64 + msa_blocks = 4 + pairwise_head_width = 32 + pairwise_num_heads = 4 + # PairformerModule parameters + num_pairformer_blocks = 4 + num_pairformer_heads = 16 + # AtomDiffusion parameters + sigma_data = 16 + dim_fourier = 256 + atom_encoder_depth_diffusion = 3 + atom_encoder_heads_diffusion = 4 + token_transformer_depth_diffusion = 24 + token_transformer_heads_diffusion = 16 + atom_decoder_depth_diffusion = 3 + atom_decoder_heads_diffusion = 4 + conditioning_transition_layers_diffusion = 2 + activation_checkpointing_diffusion = activation_checkpointing + offload_to_cpu_diffusion = False + # Training parameters + recycling_steps = 3 + diffusion_multiplicity = 16 + diffusion_samples = 1 # not used in training step + confidence_loss_weight = 3e-3 + diffusion_loss_weight = 4.0 + distogram_loss_weight = 3e-2 + else: + atom_s = 4 + atom_z = 2 + token_s = 4 + token_z = 12 + # Distogram module parameters + num_bins = 4 + # InputEmbedder parameters + atom_encoder_depth = 1 + atom_encoder_heads = 2 + # MSAModule parameters + msa_s = 4 + msa_blocks = 1 + pairwise_head_width = 4 + pairwise_num_heads = 2 + # PairformerModule parameters + num_pairformer_blocks = 1 + num_pairformer_heads = 4 + # AtomDiffusion parameters + sigma_data = 16 + dim_fourier = 16 + atom_encoder_depth_diffusion = 1 + atom_encoder_heads_diffusion = 2 + token_transformer_depth_diffusion = 2 + token_transformer_heads_diffusion = 2 + atom_decoder_depth_diffusion = 1 + atom_decoder_heads_diffusion = 2 + conditioning_transition_layers_diffusion = 1 + activation_checkpointing_diffusion = activation_checkpointing + offload_to_cpu_diffusion = False + # Training parameters + recycling_steps = 0 + diffusion_multiplicity = 1 + diffusion_samples = 1 # not used in training step + confidence_loss_weight = 3e-3 + diffusion_loss_weight = 4.0 + distogram_loss_weight = 3e-2 + + # Diffusion loss parameters + diffusion_loss_args = { + "add_smooth_lddt_loss": True, + "nucleotide_loss_weight": 5.0, + "ligand_loss_weight": 10.0, + } + + params = { + "atom_s": atom_s, + "atom_z": atom_z, + "token_s": token_s, + "token_z": token_z, + "num_bins": num_bins, + "use_window_batching": use_window_batching, + "embedder_args": { + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "activation_checkpointing": activation_checkpointing, + "activation_checkpointing_pair_repr": activation_checkpointing, + }, + "msa_args": { + "msa_s": msa_s, + "msa_blocks": msa_blocks, + "msa_dropout": 0.0, + "z_dropout": 0.0, + "pairwise_head_width": pairwise_head_width, + "pairwise_num_heads": pairwise_num_heads, + "use_paired_feature": True, + "activation_checkpointing": activation_checkpointing, + "activation_checkpointing_pair_repr": activation_checkpointing, + "offload_to_cpu": False, + }, + "pairformer_args": { + "num_blocks": num_pairformer_blocks, + "num_heads": num_pairformer_heads, + "pairwise_head_width": pairwise_head_width, + "pairwise_num_heads": pairwise_num_heads, + "dropout": 0.0, + "no_update_s": False, + "no_update_z": False, + "activation_checkpointing": activation_checkpointing, + "activation_checkpointing_pair_repr": activation_checkpointing, + "offload_to_cpu": False, + }, + "atom_feature_dim": 389, # hardcoded hidden dim from featurizer stacking multiple features + "atoms_per_window_queries": 32 if use_window_batching else None, + "atoms_per_window_keys": 128 if use_window_batching else None, + "no_msa": False, + "no_atom_encoder": False, + "min_dist": 2.0, + "max_dist": 22.0, + "do_activation_chunking": False, + "score_model_args": { + "sigma_data": sigma_data, + "dim_fourier": dim_fourier, + "atom_encoder_depth": atom_encoder_depth_diffusion, + "atom_encoder_heads": atom_encoder_heads_diffusion, + "token_transformer_depth": token_transformer_depth_diffusion, + "token_transformer_heads": token_transformer_heads_diffusion, + "atom_decoder_depth": atom_decoder_depth_diffusion, + "atom_decoder_heads": atom_decoder_heads_diffusion, + "conditioning_transition_layers": conditioning_transition_layers_diffusion, + "activation_checkpointing": activation_checkpointing_diffusion, + "activation_checkpointing_pair_repr": activation_checkpointing_diffusion, + "offload_to_cpu": offload_to_cpu_diffusion, + }, + "diffusion_process_args": diffusion_params, + # TODO: support validation, confidence and steering args + "training_args": TrainingArgs( + recycling_steps=recycling_steps, + sampling_steps=None, # not used in training step + diffusion_multiplicity=diffusion_multiplicity, + diffusion_samples=diffusion_samples, # not used in training step + confidence_loss_weight=confidence_loss_weight, + diffusion_loss_weight=diffusion_loss_weight, + distogram_loss_weight=distogram_loss_weight, + ), + "validation_args": {}, + "diffusion_loss_args": diffusion_loss_args, + "confidence_model_args": {}, + "steering_args": {}, + } + params["score_model_args"].update( + { + "atom_s": atom_s, + "atom_z": atom_z, + "token_s": token_s, + "token_z": token_z, + "atom_feature_dim": 389, # hardcoded hidden dim from featurizer stacking multiple features + "atoms_per_window_queries": 32 if use_window_batching else None, + "atoms_per_window_keys": 128 if use_window_batching else None, + } + ) + + return params + + +class DictNamespace: + """Picklable namespace with both attribute access and dict-style ``.get()``.""" + + def __init__(self, **kwargs: Any) -> None: + self.__dict__.update(kwargs) + + def get(self, key: str, default: Any = None) -> Any: + return self.__dict__.get(key, default) + + +def create_boltz2_model_init_params( + use_large_model: bool = False, + activation_checkpointing: bool = False, +) -> dict[str, Any]: + """Create initialization parameters for Boltz2 model. + + Parameters + ---------- + use_large_model : bool + Whether to use large model parameters (closer to production config). + activation_checkpointing : bool + Whether to use activation checkpointing. + + Returns + ------- + dict[str, Any] + Boltz2 model initialization parameters. + """ + diffusion_process_args = { + "sigma_min": 0.0004, + "sigma_max": 160.0, + "sigma_data": 16.0, + "rho": 7, + "P_mean": -1.2, + "P_std": 1.5, + "gamma_0": 0.8, + "gamma_min": 1.0, + "noise_scale": 1.0, + "step_scale": 1.0, + "coordinate_augmentation": False, + "alignment_reverse_diff": True, + "synchronize_sigmas": False, + } + + if use_large_model: + atom_s = 128 + atom_z = 16 + token_s = 384 + token_z = 128 + num_bins = 64 + atom_encoder_depth = 3 + atom_encoder_heads = 4 + msa_s = 64 + msa_blocks = 4 + pairwise_head_width = 32 + pairwise_num_heads = 4 + num_pairformer_blocks = 4 + num_pairformer_heads = 16 + sigma_data = 16 + dim_fourier = 256 + atom_encoder_depth_diffusion = 3 + atom_encoder_heads_diffusion = 4 + token_transformer_depth_diffusion = 24 + token_transformer_heads_diffusion = 16 + atom_decoder_depth_diffusion = 3 + atom_decoder_heads_diffusion = 4 + conditioning_transition_layers_diffusion = 2 + else: + atom_s = 4 + atom_z = 2 + token_s = 4 + token_z = 12 + num_bins = 4 + atom_encoder_depth = 1 + atom_encoder_heads = 2 + msa_s = 4 + msa_blocks = 1 + pairwise_head_width = 4 + pairwise_num_heads = 2 + num_pairformer_blocks = 1 + num_pairformer_heads = 4 + sigma_data = 16 + dim_fourier = 16 + atom_encoder_depth_diffusion = 1 + atom_encoder_heads_diffusion = 2 + token_transformer_depth_diffusion = 2 + token_transformer_heads_diffusion = 2 + atom_decoder_depth_diffusion = 1 + atom_decoder_heads_diffusion = 2 + conditioning_transition_layers_diffusion = 1 + + training_args = DictNamespace( + recycling_steps=3 if use_large_model else 0, + sampling_steps=None, + sampling_steps_random=None, + diffusion_multiplicity=2, + diffusion_samples=1, + confidence_loss_weight=0.0, + diffusion_loss_weight=4.0, + distogram_loss_weight=3e-2, + bfactor_loss_weight=0.0, + symmetry_correction=False, + adam_beta_1=0.9, + adam_beta_2=0.95, + adam_eps=1e-8, + base_lr=1e-3, + max_lr=1e-3, + lr_scheduler="af3", + lr_warmup_no_steps=10, + lr_start_decay_after_n_steps=100, + lr_decay_every_n_steps=50000, + lr_decay_factor=0.95, + weight_decay=0.0, + ) + + validation_args = DictNamespace( + recycling_steps=0, + sampling_steps=2, + diffusion_samples=1, + symmetry_correction=False, + run_confidence_sequentially=False, + ) + + params: dict[str, Any] = { + "atom_s": atom_s, + "atom_z": atom_z, + "token_s": token_s, + "token_z": token_z, + "num_bins": num_bins, + "atom_feature_dim": 388, + "atoms_per_window_queries": 32, + "atoms_per_window_keys": 128, + "embedder_args": { + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "activation_checkpointing": activation_checkpointing, + "add_mol_type_feat": True, + "add_method_conditioning": True, + "add_modified_flag": True, + "add_cyclic_flag": True, + }, + "msa_args": { + "msa_s": msa_s, + "msa_blocks": msa_blocks, + "msa_dropout": 0.0, + "z_dropout": 0.0, + "pairwise_head_width": pairwise_head_width, + "pairwise_num_heads": pairwise_num_heads, + "use_paired_feature": True, + "activation_checkpointing": activation_checkpointing, + }, + "pairformer_args": { + "num_blocks": num_pairformer_blocks, + "num_heads": num_pairformer_heads, + "dropout": 0.0, + "v2": True, + "post_layer_norm": False, + "activation_checkpointing": activation_checkpointing, + }, + "score_model_args": { + "sigma_data": sigma_data, + "dim_fourier": dim_fourier, + "atom_encoder_depth": atom_encoder_depth_diffusion, + "atom_encoder_heads": atom_encoder_heads_diffusion, + "token_transformer_depth": token_transformer_depth_diffusion, + "token_transformer_heads": token_transformer_heads_diffusion, + "atom_decoder_depth": atom_decoder_depth_diffusion, + "atom_decoder_heads": atom_decoder_heads_diffusion, + "conditioning_transition_layers": conditioning_transition_layers_diffusion, + "activation_checkpointing": activation_checkpointing, + "transformer_post_ln": False, + }, + "diffusion_process_args": diffusion_process_args, + "diffusion_loss_args": { + "add_smooth_lddt_loss": True, + "nucleotide_loss_weight": 5.0, + "ligand_loss_weight": 10.0, + }, + "training_args": training_args, + "validation_args": validation_args, + "confidence_prediction": False, + "affinity_prediction": False, + "structure_prediction_training": True, + "validate_structure": False, + "use_templates": False, + "predict_bfactor": True, + "bond_type_feature": True, + "steering_args": None, + "confidence_model_args": { + "num_dist_bins": 64, + "max_dist": 22, + "add_s_to_z_prod": True, + "add_s_input_to_s": True, + "add_z_input_to_z": True, + "conditioning_cutoff_min": 4.0, + "conditioning_cutoff_max": 20.0, + "confidence_args": { + "num_plddt_bins": 50, + "num_pde_bins": 64, + "num_pae_bins": 64, + }, + }, + "ema": False, + } + return params + + +def random_features( + size_batch: int, + n_tokens: int, + n_atoms: int, + n_msa: int, + atom_counts_per_token_range: tuple[int, int], + device: torch.device, + float_value_range: tuple[float, float], + selected_keys: Optional[list[str]] = None, + num_disto_bins: int = 64, + rng: Optional[torch.Generator] = None, +) -> dict[str, torch.Tensor]: + """Generate random feature tensors matching the shapes and dtypes of selected_keys features. + + NOTE: This function uses all-valid masks (mask with all True) for token, atoms and MSA + NOTE: the returned tensors dtype for floating point features is torch.float64 + NOTE: constraints and symmetry features are not supported yet + + Parameters + ---------- + size_batch : int + Batch size + n_tokens : int + Number of tokens + n_atoms : int + Number of atoms + n_msa : int + Number of MSA sequences + atom_counts_per_token_range : tuple[int, int] + Range for number of atoms per token (min, max) + device : torch.device + Device to create tensors on + float_value_range : tuple[float, float] + Range for float values (min, max) + selected_keys : Optional[list[str]] + If provided, only return features for these keys. If None, return all features. + num_disto_bins : int + Number of bins for distogram target. Default is 64. + rng : Optional[torch.Generator] + Optional random generator for deterministic sampling without modifying global RNG state. + + Returns + ------- + dict[str, torch.Tensor] + Dictionary of randomly generated feature tensors + """ + features = {} + + # Generate atom_counts_per_token first to ensure proper atom_to_token mapping + min_atoms, max_atoms = atom_counts_per_token_range + if min_atoms < 1: + raise ValueError(f"min_atoms must be >= 1, got {min_atoms}") + if max_atoms < min_atoms: + raise ValueError(f"max_atoms ({max_atoms}) must be >= min_atoms ({min_atoms})") + + # For now, to avoid collating different samples of different atom counts per token, we + # generate the same atom counts per token for all samples. + atom_counts_per_token = ( + torch.randint(min_atoms, max_atoms + 1, (n_tokens,), dtype=torch.int64, device=device, generator=rng) + .unsqueeze(0) + .repeat_interleave(size_batch, 0) + ) + # Ensure rewrite-path coverage for confidence.compute_frame_pred: + # when possible, force the last 3 tokens to be a contiguous NONPOLYMER segment + # (each has one atom, so the chain has >=3 atoms in total). + force_nonpolymer_tail = min_atoms <= 1 and n_tokens >= 4 and n_atoms >= n_tokens + if force_nonpolymer_tail: + atom_counts_per_token[:, -3:] = 1 + + features["atom_counts_per_token"] = atom_counts_per_token + + # Ensure total atoms match n_atoms by adjusting one anchor token. + # If we force a nonpolymer tail, keep the last 3 tokens fixed at 1 atom each. + anchor_idx = n_tokens - 4 if force_nonpolymer_tail else n_tokens - 1 + current_total_except_anchor = atom_counts_per_token.sum(dim=1) - atom_counts_per_token[:, anchor_idx] + if (current_total_except_anchor >= n_atoms).any(): + raise ValueError( + f"Total atoms {current_total_except_anchor} excluding anchor token {anchor_idx} exceeds n_atoms {n_atoms}" + ) + atom_counts_per_token[:, anchor_idx] = n_atoms - current_total_except_anchor + + # Create atom_to_token one-hot mapping based on atom_counts_per_token + atom_to_token_ccol_ids = torch.zeros((size_batch, n_tokens + 1), dtype=torch.int64, device=device) + atom_to_token_ccol_ids[:, 1:] = atom_counts_per_token.cumsum(dim=1) + atom_to_token_row_ids = ( + torch.arange(n_atoms, dtype=torch.int64, device=device).unsqueeze(0).repeat_interleave(size_batch, 0) + ) + atom_to_token_values = torch.ones_like(atom_to_token_row_ids, dtype=torch.int64, device=device) + atom_to_token_csc = torch.sparse_csc_tensor( + atom_to_token_ccol_ids, + atom_to_token_row_ids, + atom_to_token_values, + size=(size_batch, n_atoms, n_tokens), + dtype=torch.int64, + device=device, + ) + atom_to_token = atom_to_token_csc.to_dense() + + # TODO: support heterogeneous atom counts per token across samples in a batch + assert (atom_to_token[0] == atom_to_token).all(), "atom_to_token is not identical across samples in a batch" + + # Create token_to_rep_atom one-hot mapping (each token picks one representative atom randomly) + # token_to_rep_atom has shape (size_batch, n_tokens, n_atoms) + # For token i, we pick a random atom from the atoms it owns: [cumsum[i], cumsum[i+1]) + token_atom_start = atom_to_token_ccol_ids[:, :-1] # (size_batch, n_tokens) - start index of atoms for each token + + # For each token, pick a random offset within its atom range [0, atom_count) + # Since atom_counts_per_token is identical across batch, we can vectorize + # Note: min_atoms >= 1 is enforced above, so max_count is always > 0 + max_count = atom_counts_per_token.max().item() + random_offsets = torch.randint(0, max_count, (size_batch, n_tokens), device=device, generator=rng) + # Clamp offsets to be within [0, count-1] for each token + random_offsets = random_offsets % atom_counts_per_token + + # Representative atom index = start + offset + rep_atom_indices = token_atom_start + random_offsets # (size_batch, n_tokens) + + # Create one-hot token_to_rep_atom using one_hot + token_to_rep_atom = torch.nn.functional.one_hot(rep_atom_indices, num_classes=n_atoms) + + features["token_to_rep_atom"] = token_to_rep_atom + + # Create r_set_to_rep_atom: randomly select subset of tokens as R-set elements + # For each R-set element, assign a representative atom from within that token's atoms + # + # Technical notes: + # - N_R is typically N_resolved_polymer_tokens; we simulate with a random subset of tokens + # - r_set_to_rep_atom is stored as one-hot tensor [B, N_R, N_atoms] + # - The featurizer preserves this format and shards it as diagonal blocks aligned with + # token sharding. This enables local einsum for atom-to-R-set coordinate mapping + # without cross-shard communication in the distributed plddt_loss. + # - R-set token indices are identical across batch (intentional for consistent sharding) + # + n_r = max( + 1, + n_tokens - torch.randint(0, max(1, n_tokens // 4), (1,), device=device, generator=rng).item(), + ) # Random N_R <= n_tokens + + # Randomly select which tokens are in the R-set (sorted for consistency across batch) + r_set_token_indices = torch.randperm(n_tokens, device=device, generator=rng)[:n_r].sort().values + + # Vectorized: get atom start indices and counts for R-set tokens + # atom_to_token_ccol_ids shape: [B, n_tokens + 1] (cumulative column indices) + r_set_atom_start = atom_to_token_ccol_ids[:, r_set_token_indices] # [B, N_R] + r_set_atom_end = atom_to_token_ccol_ids[:, r_set_token_indices + 1] # [B, N_R] + r_set_atom_counts = r_set_atom_end - r_set_atom_start # [B, N_R] + + # Generate random offsets within each token's atom range + max_atoms_in_token = r_set_atom_counts.max().item() + if max_atoms_in_token > 0: + r_offsets = torch.randint( + 0, max(1, max_atoms_in_token), (size_batch, n_r), device=device, dtype=torch.int64, generator=rng + ) + # Clamp to valid range using modulo (handles varying atom counts per token) + r_offsets = r_offsets % torch.clamp(r_set_atom_counts, min=1) + else: + r_offsets = torch.zeros((size_batch, n_r), device=device, dtype=torch.int64) + + # Compute representative atom indices + r_set_rep_atom_indices = r_set_atom_start + r_offsets # [B, N_R] + + # Create one-hot tensor using F.one_hot + r_set_to_rep_atom = torch.nn.functional.one_hot(r_set_rep_atom_indices, num_classes=n_atoms).to( + dtype=torch.float64 + ) # [B, N_R, N_atoms] + + features["r_set_to_rep_atom"] = r_set_to_rep_atom + + # Extract float range values + min_val, max_val = float_value_range + + # Core features for InputEmbedder + features["atom_pad_mask"] = torch.ones((size_batch, n_atoms), dtype=torch.float64, device=device) + features["atom_to_token"] = atom_to_token + features["pair_mask"] = ( + get_pair_mask(n_atoms).unsqueeze(0).repeat(size_batch, 1, 1).to(dtype=torch.float64, device=device) + ) + features["token_pad_mask"] = torch.ones((size_batch, n_tokens), dtype=torch.float64, device=device) + features["ref_pos"] = torch.empty(size_batch, n_atoms, 3, dtype=torch.float64, device=device).uniform_( + min_val, max_val, generator=rng + ) + features["ref_charge"] = torch.randint(-3, 4, (size_batch, n_atoms), dtype=torch.int8, device=device, generator=rng) + features["ref_element"] = torch.randint( + 0, + const.num_elements, + (size_batch, n_atoms, const.num_elements), + dtype=torch.int64, + device=device, + generator=rng, + ) + features["ref_atom_name_chars"] = torch.randint( + 0, 64, (size_batch, n_atoms, 4, 64), dtype=torch.int64, device=device, generator=rng + ) + features["ref_space_uid"] = torch.randint( + 0, 100, (size_batch, n_atoms), dtype=torch.int64, device=device, generator=rng + ) + features["res_type"] = torch.randint( + 0, const.num_tokens, (size_batch, n_tokens, const.num_tokens), dtype=torch.int64, device=device, generator=rng + ) + features["profile"] = torch.empty( + size_batch, n_tokens, const.num_tokens, dtype=torch.float64, device=device + ).uniform_(min_val, max_val, generator=rng) + features["deletion_mean"] = torch.empty(size_batch, n_tokens, dtype=torch.float64, device=device).uniform_( + min_val, max_val, generator=rng + ) + features["pocket_feature"] = torch.randint( + 0, 2, (size_batch, n_tokens, 4), dtype=torch.int64, device=device, generator=rng + ) + + # Additional features for AtomDiffusion + features["atom_resolved_mask"] = torch.ones((size_batch, n_atoms), dtype=torch.float64, device=device) + + # Additional features for MSA module + features["msa"] = torch.randint( + 0, + const.num_tokens, + (size_batch, n_msa, n_tokens, const.num_tokens), + dtype=torch.int64, + device=device, + generator=rng, + ) + features["has_deletion"] = torch.randint( + 0, 2, (size_batch, n_msa, n_tokens), dtype=torch.bool, device=device, generator=rng + ) + features["deletion_value"] = torch.empty(size_batch, n_msa, n_tokens, dtype=torch.float64, device=device).uniform_( + min_val, max_val, generator=rng + ) + features["msa_paired"] = torch.empty(size_batch, n_msa, n_tokens, dtype=torch.float64, device=device).uniform_( + min_val, max_val, generator=rng + ) + features["msa_mask"] = torch.ones((size_batch, n_msa, n_tokens), dtype=torch.int64, device=device) + + # Additional features for Boltz1 + # NOTE: token_bonds typically is binary but in the model workflow will go thru linear projection so it's float + # anyway + features["token_bonds"] = torch.empty( + size_batch, n_tokens, n_tokens, 1, dtype=torch.float64, device=device + ).uniform_(min_val, max_val, generator=rng) + features["type_bonds"] = torch.randint( + 0, len(const.bond_types), (size_batch, n_tokens, n_tokens), dtype=torch.long, device=device, generator=rng + ) + features["token_pair_pad_mask"] = features["token_pad_mask"][:, :, None] * features["token_pad_mask"][:, None, :] + + # Additional features for RelativePositionEncoder + features["residue_index"] = torch.randint( + 0, 1000, (size_batch, n_tokens), dtype=torch.int64, device=device, generator=rng + ) + features["entity_id"] = torch.randint( + 0, 10, (size_batch, n_tokens), dtype=torch.int64, device=device, generator=rng + ) + features["token_index"] = ( + torch.arange(n_tokens, dtype=torch.int64, device=device).unsqueeze(0).expand(size_batch, -1).contiguous() + ) + features["cyclic_period"] = torch.randint( + 0, 100, (size_batch, n_tokens), dtype=torch.int32, device=device, generator=rng + ) + + # Additional features for AtomDiffusion + features["coords"] = torch.randn(size_batch, 1, n_atoms, 3, dtype=torch.float64, device=device, generator=rng) + disto_target = torch.randint(0, num_disto_bins, (size_batch, n_tokens, n_tokens), device=device, generator=rng) + features["disto_target"] = torch.nn.functional.one_hot(disto_target, num_classes=num_disto_bins).to( + dtype=torch.float64 + ) + features["token_disto_mask"] = torch.randint( + 0, 2, (size_batch, n_tokens), dtype=torch.float64, device=device, generator=rng + ) + + # Additional features for confidence module + features["atom_resolved_mask"] = torch.randint( + 0, 2, (size_batch, n_atoms), dtype=torch.float64, device=device, generator=rng + ) + + # Ensure chain-type consistency: + # - non-polymer tokens have exactly 1 atom + # - polymer tokens have > 3 atoms and are PROTEIN/DNA/RNA + atom_counts_per_token = features["atom_to_token"].sum(dim=1).to(torch.int64) + nonpolymer_flags = atom_counts_per_token == 1 + polymer_type_ids = torch.tensor( + [ + const.chain_type_ids["PROTEIN"], + const.chain_type_ids["DNA"], + const.chain_type_ids["RNA"], + ], + dtype=torch.int64, + device=device, + ) + polymer_type_idx = torch.randint( + 0, polymer_type_ids.numel(), (size_batch, n_tokens), dtype=torch.int64, device=device, generator=rng + ) + features["mol_type"] = polymer_type_ids[polymer_type_idx] + features["mol_type"][nonpolymer_flags] = const.chain_type_ids["NONPOLYMER"] + if force_nonpolymer_tail: + features["mol_type"][:, -3:] = const.chain_type_ids["NONPOLYMER"] + if n_tokens > 3: + features["mol_type"][:, -4] = const.chain_type_ids["PROTEIN"] + + # Assign asym_id by contiguous mol_type segments per batch. + # E.g. [0, 0, 0, 1, 1, 0, 2, 2] -> [0, 0, 0, 1, 1, 2, 3, 3] + asym_id = torch.zeros((size_batch, n_tokens), device=device, dtype=torch.int64) + if n_tokens > 1: + mol_changes = features["mol_type"][:, 1:] != features["mol_type"][:, :-1] + asym_id[:, 1:] = mol_changes.to(torch.int64).cumsum(dim=1) + features["asym_id"] = asym_id + + # Fuse consecutive asym_id values into shared sym_id buckets per batch. + # E.g. asym_id unique [0,1,2,3,4] -> sym_id [0,0,1,1,2] + sym_id = torch.empty_like(asym_id) + for batch_idx in range(size_batch): + unique_asym = torch.unique(asym_id[batch_idx]) + sym_map = torch.arange(unique_asym.numel(), device=device) // 2 + for asym_value, sym_value in zip(unique_asym.tolist(), sym_map.tolist()): + sym_id[batch_idx][asym_id[batch_idx] == asym_value] = sym_value + features["sym_id"] = sym_id + + if not (features["mol_type"] == const.chain_type_ids["NONPOLYMER"]).any(): + warnings.warn( + "No non-polymer token is created in random features generation.", + stacklevel=2, + ) + + # Build frames_idx from atom_to_token using token start offsets. + # For each token, sample 3 distinct local offsets within its atom range, + # then map to global indices. + max_atom_counts_per_token = int(atom_counts_per_token.max().item()) + offset_ids = torch.arange(max_atom_counts_per_token, device=device).view(1, 1, -1) + rand_scores = torch.rand(size_batch, n_tokens, max_atom_counts_per_token, device=device, generator=rng) + invalid_offsets = offset_ids >= atom_counts_per_token.unsqueeze(-1) + rand_scores = rand_scores.masked_fill(invalid_offsets, float("inf")) + offsets = torch.argsort(rand_scores, dim=-1)[..., :3] + frames_idx = token_atom_start.unsqueeze(-1) + offsets + + # Match featurizer behavior for tokens with fewer than 3 atoms: + # use the first atom of the token for all three frame slots. + small_token_mask = atom_counts_per_token.unsqueeze(-1) < 3 + if small_token_mask.any(): + token_start_triplet = token_atom_start.unsqueeze(-1).expand(-1, -1, 3) + frames_idx = torch.where(small_token_mask, token_start_triplet, frames_idx) + features["frames_idx"] = frames_idx + + # Derive frame_resolved_mask: True when all 3 frame atoms are resolved. + # Matches the featurizer logic in compute_frames_nonpolymer which sets + # resolved_frame_data[t] = resolved_mask[frames[t]].all(). + batch_expand = torch.arange(size_batch, device=device).view(-1, 1, 1).expand_as(frames_idx) + frame_atoms_resolved = features["atom_resolved_mask"][batch_expand, frames_idx] # (B, T, 3) + features["frame_resolved_mask"] = frame_atoms_resolved.prod(dim=-1) # (B, T) + + # is_nonpolymer_with_frame: True for non-polymer tokens in chains with >= 3 atoms + atoms_per_chain = torch.zeros_like(asym_id, dtype=torch.int64) + atoms_per_chain.scatter_add_(1, asym_id, atom_counts_per_token) + chain_total_atoms = atoms_per_chain.gather(1, asym_id) + features["is_nonpolymer_with_frame"] = nonpolymer_flags & (chain_total_atoms >= 3) + + # Boltz-2 specific features (unconditionally generated, opt-in via selected_keys) + num_cc_types = len(const.contact_conditioning_info) + rand_type = torch.rand(size_batch, n_tokens, n_tokens, device=device, generator=rng) + cc = torch.zeros(size_batch, n_tokens, n_tokens, num_cc_types, dtype=torch.float64, device=device) + cc[:, :, :, 0] = (rand_type < 0.3).to(torch.float64) + cc[:, :, :, 1] = ((rand_type >= 0.3) & (rand_type < 0.5)).to(torch.float64) + if num_cc_types > 4: # noqa: PLR2004 + cc[:, :, :, 4] = (rand_type >= 0.5).to(torch.float64) + features["contact_conditioning"] = cc + features["contact_threshold"] = ( + torch.rand(size_batch, n_tokens, n_tokens, device=device, generator=rng) * 22.0 + ).to(torch.float64) + features["method_feature"] = torch.randint( + 0, const.num_method_types, (size_batch, n_tokens), dtype=torch.int64, device=device, generator=rng + ) + features["modified"] = torch.randint(0, 2, (size_batch, n_tokens), dtype=torch.int64, device=device, generator=rng) + features["bfactor"] = torch.empty(size_batch, n_atoms, dtype=torch.float64, device=device).uniform_( + min_val, max_val, generator=rng + ) + features["plddt"] = torch.empty(size_batch, n_atoms, dtype=torch.float64, device=device).uniform_( + 0.0, 1.0, generator=rng + ) + + # Return only selected features if specified + if selected_keys is not None: + return {k: v for k, v in features.items() if k in selected_keys} + else: + return features + + +def get_features_shardable( + tokenized: Tokenized, + pair_mask_mode: PairMaskMode, + return_shards: bool, + shard_dims: tuple[int, int], + selected_keys: Optional[list[str]] = None, + **kwargs_feat_process: dict[str, Any], +) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: + """Get features from a tokenized object with sharding support. + + Args: + tokenized: Tokenized object + pair_mask_mode: Pair mask mode to use + return_shards: Whether to return shards + shard_dims: Shard dimensions to enable sharding + selected_keys: Selected keys to return. If None, the following keys are returned: + ["atom_pad_mask", "atom_to_token", "pair_mask", "token_pad_mask"] + + Returns: + dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: Features or list of features if return_shards is True. + """ + if return_shards and shard_dims is None: + raise ValueError("shard_dims must be provided if return_shards is True") + + num_atoms, num_tokens = get_num_atoms_tokens(tokenized) + # always pad to the next multiple of the ring size + ring_size = shard_dims[0] + # max_atoms would need to be padded to the least common multiple of ring_size and atoms_per_window_queries=32 + atom_counts_lcm = math.lcm(ring_size, 32) + max_atoms = ((num_atoms + atom_counts_lcm - 1) // atom_counts_lcm) * atom_counts_lcm + max_tokens = ((num_tokens + ring_size - 1) // ring_size) * ring_size + max_seqs = ring_size + + featurizer = BoltzFeaturizer() + feats = featurizer.process( + tokenized, + training=False, + augmentation=False, + pair_mask_mode=pair_mask_mode, + max_atoms=max_atoms, + max_tokens=max_tokens, + max_seqs=max_seqs, + pad_to_max_seqs=True, + shard_dims=shard_dims if return_shards else None, + **kwargs_feat_process, + ) + + if selected_keys is None: + selected_keys = ["atom_pad_mask", "atom_to_token", "pair_mask", "token_pad_mask"] + + if return_shards: + return [{k: v for k, v in feats_shard.items() if k in selected_keys} for feats_shard in feats] + else: + return {k: v for k, v in feats.items() if k in selected_keys} + + +def get_feature_placements( + token_keys: Optional[set[str]] = None, + msa_keys: Optional[set[str]] = None, + atom_keys: Optional[set[str]] = None, + model_io_keys: Optional[set[str]] = None, + model_io_fp32_keys: Optional[set[str]] = None, +): + """Get comprehensive feature placement definitions for distributed testing. + + Args: + token_keys: Subset of token feature keys to include. If None, include all token features. + msa_keys: Subset of MSA feature keys to include. If None, include all MSA features. + atom_keys: Subset of atom feature keys to include. If None, include all atom features. + model_io_keys: Subset of model I/O keys to include. If None, include all model I/O features. + model_io_fp32_keys: Subset of FP32 model I/O keys to include. If None, include all FP32 model I/O features. + + Returns: + dict: Dictionary containing all placement definitions organized by category. + Contains both 3-tuple placements (for main device mesh) and 2-tuple cp placements + (for cp submesh with preexisting batch dimension). + """ + + # Base placement patterns + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + placements_scalar = (Shard(0), Replicate(), Replicate()) + + # Base placement patterns for cp submesh with preexisting batch dimension + placements_cp_single = (Shard(0), Replicate()) + placements_cp_pair = (Shard(0), Shard(1)) + + # Helper function to convert 3-tuple placements to 2-tuple cp placements + def convert_to_cp_placement(placement): + if placement == placements_single: + return placements_cp_single + elif placement == placements_pair: + return placements_cp_pair + elif placement == (Shard(0), Shard(2), Replicate()): + # Special case for coords in atom_features + return (Shard(1), Replicate()) + else: + raise ValueError(f"Unsupported placement pattern: {placement}") + + # 3-tuple placements for main device mesh + placements_token_features_full = OrderedDict( + { + # Core features for InputEmbedder + "token_pad_mask": placements_single, + "res_type": placements_single, + "mol_type": placements_single, + "pocket_feature": placements_single, + # Additional features for Boltz1 + "token_bonds": placements_pair, + "type_bonds": placements_pair, + "token_pair_pad_mask": placements_pair, # this isn't returned with "window_batching = True" + # Additional features for RelativePositionEncoder + "asym_id": placements_single, + "residue_index": placements_single, + "entity_id": placements_single, + "token_index": placements_single, + "sym_id": placements_single, + "cyclic_period": placements_single, + # Additional features for distogram loss + "disto_target": placements_pair, + "token_disto_mask": placements_single, + "disto_coords_ensemble": placements_single, + # Boltz-2 token features + "method_feature": placements_single, + "modified": placements_single, + "contact_conditioning": placements_pair, + "contact_threshold": placements_pair, + "frames_idx": placements_single, + } + ) + + placements_msa_features_full = OrderedDict( + { + # Additional features for MSA module + "msa": placements_pair, + "has_deletion": placements_pair, + "deletion_value": placements_pair, + "msa_paired": placements_pair, + "msa_mask": placements_pair, + # Core features for InputEmbedder (MSA-derived) + "profile": placements_single, + "deletion_mean": placements_single, + } + ) + + placements_atom_features_full = OrderedDict( + { + # Core features for InputEmbedder + "atom_pad_mask": placements_single, + "atom_to_token": placements_single, + "pair_mask": placements_pair, # this isn't returned with "window_batching = True" + "ref_pos": placements_single, + "ref_charge": placements_single, + "ref_element": placements_single, + "ref_atom_name_chars": placements_single, + "ref_space_uid": placements_single, + "atom_resolved_mask": placements_single, + # Additional features for AtomDiffusion + # Original Boltz-1x code processes "coords" to shape (B, 1, n_atoms, 3) + "coords": (Shard(0), Shard(2), Replicate()), + # Boltz-2 atom features + "token_to_rep_atom": placements_single, + "bfactor": placements_single, + "plddt": placements_single, + } + ) + + placements_model_io_full = OrderedDict( + { + "noise": placements_single, + "denoised_atom_coords": placements_single, + "d_denoised_atom_coords": placements_single, + "aligned_true_atom_coords": placements_single, + # Additional model I/O for DiffusionModule + "r_noisy_expected": placements_single, + "d_r_noisy_expected": placements_single, + "r_update_expected": placements_single, + "d_r_update_expected": placements_single, + # Additional model I/O for AtomDiffusion preconditioned network + "noised_atom_coords": placements_single, + "denoised_atom_coords_expected": placements_single, + "d_denoised_atom_coords_expected": placements_single, + # Additional model I/O for AtomDiffusion sample + "sample_atom_coords_expected": placements_single, + } + ) + + placements_model_io_fp32_full = OrderedDict( + { + "denoised_atom_coords_fp32": placements_single, + # Additional FP32 model I/O for DiffusionModule + "r_update_fp32": placements_single, + "d_r_noisy_fp32": placements_single, + # Additional FP32 model I/O for AtomDiffusion sample + "sample_atom_coords_fp32": placements_single, + } + ) + + # Apply subsetting if specified + placements_token_features = ( + OrderedDict({k: v for k, v in placements_token_features_full.items() if k in token_keys}) + if token_keys is not None + else placements_token_features_full + ) + + placements_msa_features = ( + OrderedDict({k: v for k, v in placements_msa_features_full.items() if k in msa_keys}) + if msa_keys is not None + else placements_msa_features_full + ) + + placements_atom_features = ( + OrderedDict({k: v for k, v in placements_atom_features_full.items() if k in atom_keys}) + if atom_keys is not None + else placements_atom_features_full + ) + + placements_model_io = ( + OrderedDict({k: v for k, v in placements_model_io_full.items() if k in model_io_keys}) + if model_io_keys is not None + else placements_model_io_full + ) + + placements_model_io_fp32 = ( + OrderedDict({k: v for k, v in placements_model_io_fp32_full.items() if k in model_io_fp32_keys}) + if model_io_fp32_keys is not None + else placements_model_io_fp32_full + ) + + # 2-tuple placements for cp submesh with preexisting batch dimension + # Generated using dictionary comprehension from their non-cp counterparts + placements_cp_atom_features = OrderedDict( + [ + # Add the additional key specific to cp variant first + ("atom_counts_per_token", placements_cp_single), + # Convert existing atom features maintaining original order + *[(k, convert_to_cp_placement(v)) for k, v in placements_atom_features.items()], + ] + ) + + placements_cp_model_io = OrderedDict({k: convert_to_cp_placement(v) for k, v in placements_model_io.items()}) + + placements_cp_model_io_fp32 = OrderedDict( + {k: convert_to_cp_placement(v) for k, v in placements_model_io_fp32.items()} + ) + + return { + # Base patterns + "single": placements_single, + "pair": placements_pair, + "scalar": placements_scalar, + "cp_single": placements_cp_single, + "cp_pair": placements_cp_pair, + # Feature placements (3-tuple) + "token_features": placements_token_features, + "msa_features": placements_msa_features, + "atom_features": placements_atom_features, + "model_io": placements_model_io, + "model_io_fp32": placements_model_io_fp32, + # Feature placements (2-tuple, cp submesh) + "cp_atom_features": placements_cp_atom_features, + "cp_model_io": placements_cp_model_io, + "cp_model_io_fp32": placements_cp_model_io_fp32, + } + + +def pad_to_length(t: DTensor, dim: int, length: int) -> DTensor: + """Pad a DTensor's local shards along *dim* so the global size equals *length*. + + ``distribute_atom_features`` applies intersperse padding per DP rank + independently, but ``CollateDTensor`` additionally homogenizes local shard + shapes across DP ranks via an all-reduce MAX. When samples have different + atom counts the dataloader features are homogenized but DTensors produced + by ``distribute_atom_features`` are not. This helper pads a DTensor's + local shard along *dim* to match an externally-known correct global size + (typically obtained from a homogenized batch feature such as + ``feats["atom_pad_mask"].shape[-1]``). + + Unlike ``homogenize_shard_shapes`` — which derives the target from the + DTensor's own (potentially inconsistent) global shape — this function + accepts an explicit *length* that is authoritative across all ranks. + + Parameters + ---------- + t : DTensor + Input distributed tensor whose global ``t.shape[dim]`` may be smaller + than *length*. + dim : int + Tensor dimension to pad (e.g. 1 for the atom dimension in + ``[B*mult, n_atoms, 3]``). + length : int + Desired global size along *dim*. Must be >= ``t.shape[dim]``. + + Returns + ------- + DTensor + A new DTensor with ``shape[dim] == length``. + + Raises + ------ + ValueError + If ``t.shape[dim] >= length`` (no padding needed — caller bug). + If ``length`` or ``t.shape[dim]`` is not divisible by the mesh size. + If multiple mesh axes shard the same tensor dimension. + """ + if t.shape[dim] > length: + raise ValueError(f"t.shape[{dim}]={t.shape[dim]} already exceeds target length={length}") + if t.shape[dim] == length: + raise ValueError( + f"t.shape[{dim}]={t.shape[dim]} already equals target length={length}; " + f"pad_to_length should not be called when no padding is needed" + ) + + # Find the single mesh axis that shards tensor dimension *dim*. + mesh_size = 1 + for mesh_dim_idx, p in enumerate(t.placements): + if isinstance(p, Shard) and p.dim == dim: + if mesh_size != 1: + raise ValueError( + f"pad_to_length does not support multiple mesh axes sharding " + f"the same tensor dimension {dim}. " + f"Placements: {t.placements}" + ) + mesh_size = t.device_mesh.size(mesh_dim_idx) + + if t.shape[dim] % mesh_size != 0: + raise ValueError( + f"t.shape[{dim}]={t.shape[dim]} is not divisible by mesh size {mesh_size} along the axis sharding dim {dim}" + ) + if length % mesh_size != 0: + raise ValueError(f"length={length} is not divisible by mesh size {mesh_size} along the axis sharding dim {dim}") + + local = t.to_local() + local_target = length // mesh_size + pad_amount = local_target - local.shape[dim] + if pad_amount <= 0: + raise ValueError( + f"Local shard shape[{dim}]={local.shape[dim]} already >= " + f"local target {local_target} (global length={length}, " + f"mesh_size={mesh_size}); no padding possible" + ) + + pad_spec = [0] * (2 * local.ndim) + pad_spec[2 * (local.ndim - 1 - dim) + 1] = pad_amount + local = torch.nn.functional.pad(local, pad_spec) + + gshape = list(t.shape) + gshape[dim] = length + gshape = torch.Size(gshape) + + return DTensor.from_local( + local, + t.device_mesh, + t.placements, + shape=gshape, + stride=LayoutRightMap(tuple(gshape)).strides, + ) + + +def homogenize_shard_shapes(input: DTensor, value_to_pad: Any | None = None) -> DTensor: + """Homogenize shard shapes across all ranks by padding local shards to a consistent size. + + NOTE: the involved padding is always towards the end (or the last element) along each + tensor axis + + In distributed tensor operations, different ranks may have slightly different local shard + sizes due to uneven data distribution. This function ensures all ranks have consistent + local shard shapes by padding smaller shards to match the target size. + + The target shape for each tensor dimension is determined by: + - For non-sharded dimensions: Same as the global tensor shape + - For sharded dimensions: global_size // mesh_size_along_sharding_mesh_dimension + + All ranks participate in this operation, even if some don't require padding, to ensure + proper synchronization for subsequent collective operations like DTensor.from_local(). + + Args: + input (DTensor): The input distributed tensor to homogenize + value_to_pad (Any | None, optional): Value to use for padding. Defaults to 0. + + Returns: + DTensor: A new DTensor with homogenized local shard shapes across all ranks + + Raises: + ValueError: If any local shard dimension is larger than the computed target size + + Example: + >>> # Ranks have different local shard sizes: [10, 8] and [10, 9] + >>> # After homogenization: both ranks have [10, 9] + >>> homogenized_dtensor = homogenize_shard_shapes(input_dtensor) + + Note: + This function is particularly useful in testing scenarios where you need + consistent shard shapes across ranks for reliable distributed computations. + """ + # Get the local shard and its shape + local_shard = input.to_local() + local_shape = torch.tensor(local_shard.shape, dtype=torch.int64) + + # Get DTensor properties + global_shape = list(input.shape) + placements = input.placements + device_mesh = input.device_mesh + + # Calculate target shape for the local shard + # Start with global shape and modify only sharded dimensions + target_shape = torch.tensor(global_shape, dtype=torch.int64) + + for mesh_dim_idx, placement in enumerate(placements): + if isinstance(placement, Partial): + raise ValueError(f"Partial placements are not supported: {placement}") + if isinstance(placement, Shard): + # For sharded dimensions, target size is global_size / mesh_size_for_this_mesh_dim + tensor_dim = placement.dim # This is the tensor dimension being sharded + mesh_size = device_mesh.size(mesh_dim_idx) + target_shape[tensor_dim] = global_shape[tensor_dim] // mesh_size + + # Calculate padding amounts using tensor operations (all ranks must compute this even if they don't need padding) + padding_amounts = target_shape - local_shape + + # Check for invalid cases where local shard is larger than target + if (padding_amounts < 0).any(): + raise ValueError(f"Local shard shape {local_shape} has axes larger than target shape {target_shape}") + + # Check if this rank needs padding + needs_padding = (padding_amounts > 0).any() + + if needs_padding: + # Apply padding to local shard + # torch.nn.functional.pad expects padding in reverse order (last dim first) + # Create pad_values as (ndim, 2) tensor: [pad_left, pad_right] for each dimension + pad_values = torch.zeros((local_shard.ndim, 2), dtype=torch.int64) + # always pad towards the end (or the last element) along each tensor axis + pad_values[:, 1] = padding_amounts # Set right padding amounts + + # Flatten and reverse for torch.nn.functional.pad format + pad_values = pad_values.flip(0).flatten().tolist() + + if value_to_pad is None: + value_to_pad = torch.tensor(0, dtype=local_shard.dtype) + + padded_local_shard = torch.nn.functional.pad(local_shard, pad_values, value=value_to_pad) + else: + # No padding needed for this rank, but still participate in collective operation + padded_local_shard = local_shard + + # By definition, this function ensures homogeneous shape across ranks + shape_output_global, stride_output_global = map( + tuple, + compute_global_tensor_info(padded_local_shard, device_mesh, placements), + ) + + # Create new DTensor from padded local shard + return DTensor.from_local( + padded_local_shard, + device_mesh=device_mesh, + placements=placements, + shape=shape_output_global, + stride=stride_output_global, + run_check=False, # Skip validation for performance + ) + + +def concat_data(out_dir: Path, *datas: Path) -> Path: + out_dir.mkdir(parents=True, exist_ok=True) + # manifest.json msa structures + # 1. copy the msa contents into msa, raising an error if there are duplicate filenames + # 2. copy the structures contents into structures, raising an error if there are duplicate filenames + # 3. merge the manifest.json files, raising an error if there are duplicate filenames + # 4. write the merged manifest.json + msa_dir = out_dir / "msa" + msa_dir.mkdir(parents=True, exist_ok=True) + copied = set() + if isinstance(datas, (Path, str)): + data_lst: list[Path] = [Path(datas)] + else: + data_lst = [Path(data) for data in datas] + for data in data_lst: + for file in (data / "msa").glob("*"): + if file.name in copied: + raise ValueError(f"Duplicate MSA file {file.name}") + shutil.copy(file, msa_dir / file.name) + copied.add(file.name) + structures_dir = out_dir / "structures" + structures_dir.mkdir(parents=True, exist_ok=True) + copied_structures = set() + manifests = [] + for data in data_lst: + for file in (data / "structures").glob("*"): + if file.name in copied_structures: + raise ValueError(f"Duplicate structure file {file.name}") + shutil.copy(file, structures_dir / file.name) + copied_structures.add(file.name) + with open(data / "manifest.json", "r") as f: + manifest = json.load(f) + manifests.append(manifest) + assert all(manifest.keys() == manifests[0].keys() for manifest in manifests), "Manifest keys do not match" + assert all(set(manifest.keys()) == {"records"} for manifest in manifests), "Manifest keys do not match" + records = [] + for manifest in manifests: + records.extend(manifest["records"]) + manifest = {"records": records} + manifest_file = out_dir / "manifest.json" + manifest_file.write_text(json.dumps(manifest)) + return out_dir + + +@contextmanager +def pytorch_use_deterministic_ops(): + """Context manager to enable PyTorch deterministic algorithms in spawn processes. + + This context manager enables deterministic behavior in PyTorch operations + for reproducibility in testing or debugging. It sets the CUBLAS_WORKSPACE_CONFIG + environment variable and enables PyTorch's deterministic algorithms mode. + + Important: This context manager must only be used in spawn parallel processes + (not the main process) because the CUBLAS_WORKSPACE_CONFIG environment variable + affects the underlying CUDA context for the entire process lifetime. Restricting + usage to spawned processes prevents side effects on other tests. + + The context manager automatically restores the previous deterministic setting + and CUBLAS_WORKSPACE_CONFIG value upon exit. + + Yields: + None + + Raises: + RuntimeError: If called in the main process (not a spawn parallel process). + + Example: + >>> # In a spawned test process + >>> with pytorch_use_deterministic_ops(): + ... # All PyTorch operations use deterministic algorithms + ... result = model(input_tensor) + """ + # technically, the CUBLAS_WORKSPACE_CONFIG must be set before the + # "import torch" statement and its effect on the underlying CUDA + # context will last until the process ends. For that reason, to + # exclude this env variable's side effects on other tests, we + # need to restrict this context manager to a spawn parallel + # processes + is_spawn_process = multiprocessing.parent_process() is not None + if is_spawn_process: + raise RuntimeError("pytorch_use_deterministic_ops() can only be used in spawn parallel processes") + deterministic_restore = torch.are_deterministic_algorithms_enabled() + original_env = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None) + try: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.use_deterministic_algorithms(True) + yield + finally: + torch.use_deterministic_algorithms(deterministic_restore) + if original_env is not None: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = original_env + else: + os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None) + + +@contextmanager +def benchmark_peak_memory_and_runtime(): + """Context manager to benchmark peak memory usage and runtime of a code block. + + This context manager tracks the peak CUDA memory allocated during the execution + of the wrapped code block and measures the wall-clock time execution duration. + It yields a dictionary that will be populated with 'peak_mem' (in MB) and 'time' + (in ms) keys after the block execution completes. + + The memory measurement tracks the *peak allocated memory* relative to the memory + allocated at the start of the context, attempting to isolate the memory usage + of the specific operations within the block. + + Yields: + dict: A dictionary that will contain results after context exit: + - "peak_mem" (float): Peak memory usage in MB. + - "time" (float): Execution time in milliseconds. + + Note: + - Requires CUDA to be available. + - Performs `torch.cuda.synchronize()` before stopping the timer to ensure + accurate GPU timing. + - Clears cache and resets peak memory stats at the beginning. + + Example: + >>> with benchmark_peak_memory_and_runtime() as stats: + ... model(input) + >>> print(f"Memory: {stats['peak_mem']} MB, Time: {stats['time']} ms") + """ + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start_mem = torch.cuda.memory_allocated() + start_time = time.time() + + stats = {} + yield stats + + torch.cuda.synchronize() + end_time = time.time() + peak_mem = torch.cuda.max_memory_allocated() + + # We want the peak memory *induced* by the function, relative to start. + peak_usage_mb = (peak_mem - start_mem) / 1024 / 1024 + duration_ms = (end_time - start_time) * 1000 + + stats["peak_mem"] = peak_usage_mb + stats["time"] = duration_ms + + +def pad_or_shrink_to_length( + tensor: torch.Tensor, axis: int, target_length: int, pad_value: float = 0.0 +) -> torch.Tensor: + """Pad or shrink tensor along the specified axis to target_length. + + When shrinking, the function slices from the beginning of the axis. + When padding, zeros (or pad_value) are appended at the end. + + Args: + tensor: Input tensor to resize. + axis: The dimension along which to pad or shrink. + target_length: The desired length along the axis. + pad_value: Value to use for padding (default: 0.0). + + Returns: + Tensor with shape[axis] == target_length. + """ + current_length = tensor.shape[axis] + if current_length == target_length: + return tensor + elif current_length > target_length: + # Shrink: slice to target_length + slices = [slice(None)] * tensor.ndim + slices[axis] = slice(0, target_length) + return tensor[tuple(slices)] + else: + # Pad: append zeros + pad_shape = list(tensor.shape) + pad_shape[axis] = target_length - current_length + padding = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat([tensor, padding], dim=axis) + + +def distribute_atom_features( + inputs: dict[str, Tensor], + placements_cp: dict[str, tuple], + placements_dp_cp: dict[str, tuple], + device_mesh: DeviceMesh, + cp_group: dist.ProcessGroup, + cp_mesh_dim_names: tuple[str, ...] = ("cp_axis_0", "cp_axis_1"), + dp_mesh_dim_name: str = "dp", + multiplicities: dict[str, int] | None = None, +) -> dict[str, DTensor]: + """Distribute atom features across a device mesh with intersperse padding. + + This utility abstracts the common workflow of: + 1. Calling pad_and_scatter_atom_features_dtensor() per DP rank to shard across CP + 2. Collating the batch along DP ranks + 3. Combining per-multiplicity DTensors into single DTensor with flattened batch*mult + + Parameters + ---------- + inputs : dict[str, Tensor] + Input tensors on host (CPU). Shape [batch, ...] for each feature. + For multiplicity features, use separate keys like "feat_0", "feat_1". + Batch size must equal the DP world size. + May include auxiliary features (e.g., "atom_counts_per_token") needed + by pad_and_scatter_atom_features_dtensor but not returned in output. + placements_cp : dict[str, tuple] + Placements for CP submesh (e.g., (Shard(0), Replicate())). + Must have the same keys as inputs. + placements_dp_cp : dict[str, tuple] + Placements for full mesh (e.g., (Shard(0), Shard(1), Replicate())). + Keys must be a subset of inputs.keys(). Only features with keys in + placements_dp_cp will be returned in the output. + device_mesh : DeviceMesh + Full device mesh (e.g., 3D: dp, cp_0, cp_1). + cp_group : dist.ProcessGroup + Flattened CP process group for this DP slice. + cp_mesh_dim_names : tuple[str, ...], optional + Names of CP dimensions in the mesh. Default ("cp_axis_0", "cp_axis_1"). + dp_mesh_dim_name : str, optional + Name of DP dimension in the mesh. Default "dp". + multiplicities : dict[str, int] | None, optional + Features with multiplicity. Key is base name (without _0, _1 suffix), + value is the multiplicity count. These will be combined in output. + + Returns + ------- + dict[str, DTensor] + DTensors distributed across the full mesh. Only features with keys in + placements_dp_cp are returned. Multiplicity features are combined with + shape [batch*mult, ...]. + + Raises + ------ + ValueError + If keys don't match requirements, batch size != DP world size, + or local batch size != 1. + + Examples + -------- + >>> # For features with multiplicity=2: + >>> inputs = { + ... "atom_counts_per_token": counts, # [B, n_tokens] - auxiliary, not in output + ... "token_to_rep_atom": tensor_a, # [B, n_tokens, n_atoms] + ... "resolved_mask_0": tensor_b0, # [B, n_atoms] + ... "resolved_mask_1": tensor_b1, # [B, n_atoms] + ... } + >>> placements_cp = { + ... "atom_counts_per_token": (Shard(0), Replicate()), + ... "token_to_rep_atom": (Shard(0), Replicate()), + ... "resolved_mask_0": (Shard(0), Replicate()), + ... "resolved_mask_1": (Shard(0), Replicate()), + ... } + >>> placements_dp_cp = { + ... "token_to_rep_atom": (Shard(0), Shard(1), Replicate()), + ... "resolved_mask_0": (Shard(0), Shard(1), Replicate()), + ... "resolved_mask_1": (Shard(0), Shard(1), Replicate()), + ... } + >>> multiplicities = {"resolved_mask": 2} + >>> result = distribute_atom_features( + ... inputs, placements_cp, placements_dp_cp, device_mesh, cp_group, + ... multiplicities=multiplicities + ... ) + >>> # result["token_to_rep_atom"] has shape [B, n_tokens_padded, n_atoms_padded] + >>> # result["resolved_mask"] has shape [B*2, n_atoms_padded] + >>> # "atom_counts_per_token" is NOT in result + """ + from boltz.distributed.data.feature.featurizer import pad_and_scatter_atom_features_dtensor + + multiplicities = multiplicities or {} + + # Validate key consistency + # inputs and placements_cp must have the same keys + if inputs.keys() != placements_cp.keys(): + raise ValueError( + f"inputs and placements_cp must have the same keys. " + f"inputs: {set(inputs.keys())}, placements_cp: {set(placements_cp.keys())}" + ) + # placements_dp_cp keys must be a subset of inputs keys + if not placements_dp_cp.keys() <= inputs.keys(): + raise ValueError( + f"placements_dp_cp keys must be a subset of inputs keys. " + f"placements_dp_cp: {set(placements_dp_cp.keys())}, inputs: {set(inputs.keys())}" + ) + + # Validate multiplicities keys exist in placements_dp_cp (since they'll be in output) + for base_name, mult in multiplicities.items(): + for i in range(mult): + key = f"{base_name}_{i}" + if key not in placements_dp_cp: + raise ValueError(f"Multiplicity key '{key}' not found in placements_dp_cp") + + # Get mesh info + cp_mesh = device_mesh[cp_mesh_dim_names] + dp_dim_idx = device_mesh.mesh_dim_names.index(dp_mesh_dim_name) + dp_rank = device_mesh.get_coordinate()[dp_dim_idx] + dp_world_size = device_mesh.size(dp_dim_idx) + + # Get batch size and validate + sample_key = next(iter(inputs.keys())) + batch_size = inputs[sample_key].shape[0] + if batch_size != dp_world_size: + raise ValueError(f"Batch size ({batch_size}) must equal DP world size ({dp_world_size})") + batch_size_per_dp = batch_size // dp_world_size + if batch_size_per_dp != 1: + raise ValueError( + f"Local batch size must be 1, got {batch_size_per_dp}. " + f"pad_and_scatter_atom_features_dtensor can only process one sample at a time." + ) + + # Get CP group info + cp_rank = dist.get_rank(cp_group) + cp_src_rank = dist.get_process_group_ranks(cp_group)[0] + + # Determine device + if device_mesh.device_type == "cuda": + local_device_idx = torch.cuda.current_device() + device = torch.device("cuda", local_device_idx) + else: + device = torch.device(device_mesh.device_type) + + # Prepare inputs for scatter: select sample for this DP rank + # Only CP rank 0 provides inputs; others pass None + _ENSEMBLE_EXPECTED_NDIM = {"frames_idx": 3, "frame_resolved_mask": 2} + if cp_rank == 0: + inputs_for_scatter = {} + for k, v in inputs.items(): + val = v[dp_rank].to(device=device) + expected_ndim = _ENSEMBLE_EXPECTED_NDIM.get(k) + if expected_ndim is not None and val.ndim < expected_ndim: + val = val.unsqueeze(0) + inputs_for_scatter[k] = val + else: + inputs_for_scatter = None + + # Scatter across CP ranks with intersperse padding + feats_cp = pad_and_scatter_atom_features_dtensor( + inputs_for_scatter, + placements_cp, + cp_group, + cp_src_rank, + cp_mesh, + ) + + # Compute global shape/stride and create DTensors on full mesh + # Only process keys that are in placements_dp_cp (output features) + def _local_for_dp_wrap(k, v): + """Get local tensor ready for DP wrapping (unsqueeze batch dim). + + For ensemble-aware features (frames_idx, frame_resolved_mask), the + featurizer requires an E=1 ensemble dim on input but downstream + consumers expect (T, ...) per sample. Squeeze the ensemble dim + before adding the DP batch dim. + """ + local = v.to_local() + if k in _ENSEMBLE_EXPECTED_NDIM: + local = local.squeeze(0) # (1, T_shard, ...) -> (T_shard, ...) + return local.unsqueeze(0) # Add batch dim + + feats_shape_stride = { + k: tuple( + map( + tuple, + compute_global_tensor_info(_local_for_dp_wrap(k, v), device_mesh, placements_dp_cp[k]), + ) + ) + for k, v in feats_cp.items() + if k in placements_dp_cp + } + + feats: dict[str, DTensor] = { + k: DTensor.from_local( + _local_for_dp_wrap(k, v), # Add batch dim (squeeze ensemble dim first if needed) + device_mesh, + placements_dp_cp[k], + shape=feats_shape_stride[k][0], + stride=feats_shape_stride[k][1], + ) + for k, v in feats_cp.items() + if k in placements_dp_cp + } + + # Combine per-multiplicity DTensors into single DTensor with flattened batch*mult + for base_name, mult in multiplicities.items(): + # Pop the per-multiplicity DTensors + mult_tensors = [feats.pop(f"{base_name}_{i}") for i in range(mult)] + + # Validate all tensors have same shape, placements, device_mesh + if not all(t.shape == mult_tensors[0].shape for t in mult_tensors): + raise ValueError(f"All multiplicity tensors for '{base_name}' must have the same shape") + if not all(t.placements == mult_tensors[0].placements for t in mult_tensors): + raise ValueError(f"All multiplicity tensors for '{base_name}' must have the same placements") + if not all(t.device_mesh == mult_tensors[0].device_mesh for t in mult_tensors): + raise ValueError(f"All multiplicity tensors for '{base_name}' must have the same device_mesh") + + # Combine: stack along dim 1, then flatten dims 0 and 1 + # [B, ...] -> stack -> [B, mult, ...] -> [B*mult, ...] + local_cat = torch.cat([t.to_local().unsqueeze(1) for t in mult_tensors], dim=1) + local_flat = local_cat.flatten(0, 1) + + feats[base_name] = DTensor.from_local( + local_flat, + mult_tensors[0].device_mesh, + mult_tensors[0].placements, + ) + + return feats + + +def make_random_contact_conditioning_features( + B: int, + N: int, + num_cc_types: int, + dtype: torch.dtype = torch.float32, + device: str = "cpu", + seed: int = 42, +) -> tuple[Tensor, Tensor]: + """Create random contact conditioning features for testing. + + Exercises all three masking branches (UNSPECIFIED, UNSELECTED, active) + so that both ``encoding_unspecified`` and ``encoding_unselected`` + parameters receive non-zero gradients. + + Parameters + ---------- + B : int + Batch size. + N : int + Number of tokens. + num_cc_types : int + Number of contact conditioning types (``len(const.contact_conditioning_info)``). + dtype : torch.dtype + Data type for the tensors. + device : str + Device for the tensors. + seed : int + Random seed for reproducibility. + + Returns + ------- + cc : Tensor + Contact conditioning features, shape ``(B, N, N, num_cc_types)``. + ct : Tensor + Contact threshold values, shape ``(B, N, N)``. + """ + with torch.random.fork_rng(): + torch.manual_seed(seed) + rand_type = torch.rand(B, N, N, device=device) + cc = torch.zeros(B, N, N, num_cc_types, dtype=dtype, device=device) + cc[:, :, :, 0] = (rand_type < 0.3).to(dtype) # UNSPECIFIED (~30%) + cc[:, :, :, 1] = ((rand_type >= 0.3) & (rand_type < 0.5)).to(dtype) # UNSELECTED (~20%) + if num_cc_types > 4: # noqa: PLR2004 + cc[:, :, :, 4] = (rand_type >= 0.5).to(dtype) # CONTACT (~50%) + ct = (torch.rand(B, N, N, device=device) * 22.0).to(dtype) + return cc, ct + + +def _extract_output_dtypes( + output: Any, + prefix: str, +) -> dict[str, torch.dtype]: + """Extract dtypes from a module's forward output. + + Handles plain tensors, DTensors, tuples/lists of tensors, and dicts + mapping to tensors. Returns ``{qualified_name: dtype}`` entries. + """ + if isinstance(output, Tensor): + return {prefix: output.dtype} + if isinstance(output, dict): + result: dict[str, torch.dtype] = {} + for k, v in output.items(): + if isinstance(v, Tensor): + result[f"{prefix}/{k}"] = v.dtype + return result + if isinstance(output, (tuple, list)): + result = {} + for i, v in enumerate(output): + if isinstance(v, Tensor): + result[f"{prefix}/{i}"] = v.dtype + return result + return {} + + +class DtypeProfiler: + """Capture dtypes of module outputs, parameters, and parameter gradients. + + Attach to a model before a forward + backward pass. After the pass, + the three dictionaries ``fwd_dtypes``, ``param_dtypes``, and + ``param_grad_dtypes`` contain ``{qualified_name: torch.dtype}`` entries + for every module output, parameter, and parameter gradient respectively. + + Usage:: + + profiler = DtypeProfiler(model) + loss = model(batch).sum() + loss.backward() + profiler.collect_grad_dtypes(model) + profiler.remove_hooks() + + Works transparently with both plain ``torch.Tensor`` and ``DTensor`` + outputs (``DTensor.dtype`` returns the element dtype). + """ + + def __init__(self, model: torch.nn.Module) -> None: + self.fwd_dtypes: dict[str, torch.dtype] = {} + self.param_dtypes: dict[str, torch.dtype] = {} + self.param_grad_dtypes: dict[str, torch.dtype] = {} + self._handles: list[torch.utils.hooks.RemovableHook] = [] + self._register(model) + + def _make_fwd_hook(self, name: str): + def hook(_module: torch.nn.Module, _input: Any, output: Any) -> None: + self.fwd_dtypes.update(_extract_output_dtypes(output, name)) + + return hook + + def _register(self, model: torch.nn.Module) -> None: + for name, param in model.named_parameters(): + self.param_dtypes[name] = param.dtype + for name, module in model.named_modules(): + self._handles.append(module.register_forward_hook(self._make_fwd_hook(name))) + + def collect_grad_dtypes(self, model: torch.nn.Module) -> None: + """Snapshot ``param.grad.dtype`` for every parameter with a gradient.""" + for name, param in model.named_parameters(): + if param.grad is not None: + self.param_grad_dtypes[name] = param.grad.dtype + + def remove_hooks(self) -> None: + """Remove all registered forward hooks.""" + for h in self._handles: + h.remove() + self._handles.clear() + + def __enter__(self) -> "DtypeProfiler": + return self + + def __exit__(self, *args: Any) -> None: + self.remove_hooks() + + +class RecomputeProfiler: + """Count how many times each module's forward is invoked. + + When activation checkpointing is active, ``torch.utils.checkpoint.checkpoint`` + re-runs the checkpointed function during backward. Because + ``nn.Module.__call__`` fires registered forward **pre**-hooks on every + invocation, modules inside a checkpointed region will have + ``fwd_counts[name] >= 2`` (once in the original forward, once during + recomputation) while modules outside checkpointed regions will have + ``fwd_counts[name] == 1``. + + Pre-hooks (``register_forward_pre_hook``) are used instead of post-forward + hooks because ``use_reentrant=False`` checkpointing raises an internal + ``_StopRecomputationError`` via ``early_stop`` to halt recomputation as + soon as all needed tensors are regenerated. This exception interrupts + ``Module.forward()``, so post-forward hooks never fire for modules at or + after the stop point. Pre-hooks fire at the *start* of + ``Module.__call__`` — before ``forward()`` — so they are unaffected. + + Usage:: + + profiler = RecomputeProfiler(model) + loss = model(batch).sum() + loss.backward() + profiler.remove_hooks() + print(profiler.recomputed_modules) + + Works transparently with both plain ``torch.nn.Module`` and DTensor-wrapped + models (hooks are registered on the local module hierarchy). + """ + + def __init__(self, model: torch.nn.Module) -> None: + self.fwd_counts: dict[str, int] = {} + self._handles: list[torch.utils.hooks.RemovableHook] = [] + for name, module in model.named_modules(): + self._handles.append(module.register_forward_pre_hook(self._make_hook(name))) + + def _make_hook(self, name: str): + def hook(_module: torch.nn.Module, _input: Any) -> None: + self.fwd_counts[name] = self.fwd_counts.get(name, 0) + 1 + + return hook + + @property + def recomputed_modules(self) -> frozenset[str]: + """Module names whose forward was called >= 2 times (recomputed by checkpoint).""" + return frozenset(n for n, c in self.fwd_counts.items() if c >= 2) + + def remove_hooks(self) -> None: + """Remove all registered forward hooks.""" + for h in self._handles: + h.remove() + self._handles.clear() + + +# --------------------------------------------------------------------------- +# CCD ligand feature loading for PB metric tests +# --------------------------------------------------------------------------- + +LIGAND_KEYS = ( + "ligand_edge_index", + "ligand_edge_lower_bounds", + "ligand_edge_upper_bounds", + "ligand_edge_bond_mask", + "ligand_edge_angle_mask", + "ligand_chiral_atom_index", + "ligand_chiral_check_mask", + "ligand_chiral_atom_orientations", + "ligand_stereo_bond_index", + "ligand_stereo_check_mask", + "ligand_stereo_bond_orientations", + "ligand_aromatic_5_ring_index", + "ligand_aromatic_6_ring_index", + "ligand_planar_double_bond_index", +) + + +def load_ligand_features_from_ccd( + mols_dir: str, + mol_name: str = "PHE", + atom_offset: int = 0, +) -> tuple[torch.Tensor, dict]: + """Load ligand geometry features from a CCD pickle file. + + Uses ``load_molecules`` + ``get_symmetries`` from ``boltz.data.mol`` + to read the pre-computed PB feature arrays (edge index, distance + bounds, chirality, stereo, aromatics) stored inside the RDKit Mol. + + Parameters + ---------- + mols_dir : str + Path to the directory containing per-residue ``.pkl`` files + (e.g. ``tests/test_data/data/mols``). + mol_name : str + CCD component name (default ``"PHE"``). + atom_offset : int + Offset added to all atom-index features so they refer to the + correct position within the full-structure atom array. + + Returns + ------- + coords : torch.Tensor + Ideal 3-D coordinates, shape ``(n_atoms, 3)``, ``float32``. + features : dict[str, torch.Tensor] + Dictionary keyed by the strings in ``LIGAND_KEYS``. + """ + from boltz.data.mol import get_symmetries, load_molecules + + mols = load_molecules(mols_dir, [mol_name]) + mol = mols[mol_name] + syms = get_symmetries(mols) + ( + _syms_ccd, _names_ccd, + edge_index, lower_bounds, upper_bounds, bond_mask, angle_mask, + chiral_atom_index, chiral_check_mask, chiral_atom_orientations, + stereo_bond_index, stereo_check_mask, stereo_bond_orientations, + aromatic_5_ring_index, aromatic_6_ring_index, planar_double_bond_index, + ) = syms[mol_name] + + conf = mol.GetConformer(0) + coords = torch.tensor( + [[conf.GetAtomPosition(i).x, conf.GetAtomPosition(i).y, conf.GetAtomPosition(i).z] + for i in range(mol.GetNumAtoms())], + dtype=torch.float32, + ) + + features = { + "ligand_edge_index": torch.tensor(edge_index, dtype=torch.long) + atom_offset, + "ligand_edge_lower_bounds": torch.tensor(lower_bounds, dtype=torch.float32), + "ligand_edge_upper_bounds": torch.tensor(upper_bounds, dtype=torch.float32), + "ligand_edge_bond_mask": torch.tensor(bond_mask), + "ligand_edge_angle_mask": torch.tensor(angle_mask), + "ligand_chiral_atom_index": torch.tensor(chiral_atom_index, dtype=torch.long) + atom_offset, + "ligand_chiral_check_mask": torch.tensor(chiral_check_mask), + "ligand_chiral_atom_orientations": torch.tensor(chiral_atom_orientations), + "ligand_stereo_bond_index": torch.tensor(stereo_bond_index, dtype=torch.long) + atom_offset, + "ligand_stereo_check_mask": torch.tensor(stereo_check_mask), + "ligand_stereo_bond_orientations": torch.tensor(stereo_bond_orientations), + "ligand_aromatic_5_ring_index": torch.tensor(aromatic_5_ring_index, dtype=torch.long) + atom_offset, + "ligand_aromatic_6_ring_index": torch.tensor(aromatic_6_ring_index, dtype=torch.long) + atom_offset, + "ligand_planar_double_bond_index": torch.tensor(planar_double_bond_index, dtype=torch.long) + atom_offset, + } + return coords, features + + +def make_pb_test_data(n_tok, n_atom, mols_dir, mol_name="PHE", batch_size=1, n_samples=2, seed=42): + """Create test data with a real CCD ligand for PB metric tests. + + Loads the ligand from the CCD pickle file at ``mols_dir/mol_name.pkl`` + via :func:`load_ligand_features_from_ccd`. + + Layout per batch element (``atoms_per_tok = n_atom // n_tok``): + Tokens 0 .. n_tok-4 : PROTEIN (asym_id=0) + Tokens n_tok-3 .. n_tok-1 : NONPOLYMER / ligand (asym_id=1) + + The 3 ligand tokens x ``atoms_per_tok`` atoms must be >= the number + of heavy atoms in the CCD component. CCD ideal coordinates are + placed at the ligand atom positions in ``sample_atom_coords``; + remaining protein positions are filled with random noise. + """ + atoms_per_tok = n_atom // n_tok + rng = torch.Generator().manual_seed(seed) + + feats = random_features( + size_batch=batch_size, + n_tokens=n_tok, + n_atoms=n_atom, + n_msa=1, + atom_counts_per_token_range=(atoms_per_tok, atoms_per_tok), + device=torch.device("cpu"), + float_value_range=(-1.0, 1.0), + selected_keys=["atom_to_token", "mol_type", "atom_pad_mask", "atom_counts_per_token", "asym_id"], + rng=rng, + ) + + batch: dict = {} + for k, v in feats.items(): + batch[k] = v.to(torch.float32) if v.is_floating_point() else v + batch["atom_to_token"] = batch["atom_to_token"].to(torch.float32) + + lig_start_tok = n_tok - 3 + batch["mol_type"][:, :lig_start_tok] = const.chain_type_ids["PROTEIN"] + batch["mol_type"][:, lig_start_tok:] = const.chain_type_ids["NONPOLYMER"] + + batch["asym_id"][:, :lig_start_tok] = 0 + batch["asym_id"][:, lig_start_tok:] = 1 + + lig_atom_start = lig_start_tok * atoms_per_tok + + lig_coords, lig_feats = load_ligand_features_from_ccd(mols_dir, mol_name, atom_offset=lig_atom_start) + n_lig_atoms = lig_coords.shape[0] + + sample_coords = torch.randn(batch_size * n_samples, n_atom, 3, generator=rng) + for s in range(batch_size * n_samples): + sample_coords[s, lig_atom_start : lig_atom_start + n_lig_atoms] = lig_coords + + for k in LIGAND_KEYS: + batch[k] = [lig_feats[k].clone() for _ in range(batch_size)] + + out = {"sample_atom_coords": sample_coords} + return batch, out diff --git a/src/boltz/workflow/__init__.py b/src/boltz/workflow/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/src/boltz/workflow/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/src/boltz/workflow/utils.py b/src/boltz/workflow/utils.py new file mode 100644 index 000000000..374aece3e --- /dev/null +++ b/src/boltz/workflow/utils.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import copy +import warnings +from pathlib import Path + +import omegaconf +import pytorch_lightning as pl +import torch + + +def convert_datasets_dict_to_list_config( + base: omegaconf.ListConfig, + override: omegaconf.DictConfig, + keys_to_override: set[str], + *, + remove_null_datasets: bool = False, +) -> omegaconf.ListConfig: + """Convert a DictConfig with string integer keys to a ListConfig by merging into base. + + This function provides a workaround for OmegaConf's limitation of partially overriding + entries nested in a ListConfig by enabling the user to specify the partial overrides + using a DictConfig with string integer keys so that entry under those string integer keys + are merged into the base ListConfig[index] structure. + + Base items may have extra keys beyond keys_to_override. Override entries may only + contain keys that are both in keys_to_override and present in the corresponding + base item. + + When remove_null_datasets is True, an override value of None (e.g. data.datasets.1=null + on the CLI) removes that list entry from the result instead of merging. + + Args: + base: The base ListConfig to merge overrides into. Items may have any keys; + override keys must exist in the target base item. + override: A DictConfig with string integer keys (e.g., {"0": {...}, "1": null}). + Each key is an index in the base ListConfig. Values are either a DictConfig + of keys to merge, or None to remove that entry (only when remove_null_datasets + is True). + keys_to_override: Set of keys that are allowed in override items (whitelist). + Used for validation; override keys must be a subset of this set and of + the target base item's keys. + remove_null_datasets: If True, override entries with value None cause that + list index to be removed from the result. If False, a None value raises + ValueError. + + Returns: + A new ListConfig that is a deep copy of base with the override values merged in at + the specified indices, and with null-marked entries removed when remove_null_datasets + is True. The user can use OmegaConf.merge to merge the returned ListConfig + with the base ListConfig to get the effect of partial overrides as described above. + + Raises: + ValueError: If base is not a ListConfig, override is not a DictConfig, override is + empty, override contains invalid index keys, override items have keys not in + keys_to_override, override items have keys not present in the corresponding + base item, or an override value is None while remove_null_datasets is False. + + Example: + >>> base = OmegaConf.create([{"target_dir": "/path1", "prob": 0.5}]) + >>> override = OmegaConf.create({"0": {"target_dir": "/new/path"}}) + >>> result = convert_datasets_dict_to_list_config(base, override, {"target_dir", "prob"}) + >>> result[0].target_dir + '/new/path' + """ + if not isinstance(base, omegaconf.ListConfig): + raise ValueError(f"base must be a ListConfig, got {type(base)}") + + if not isinstance(override, omegaconf.DictConfig): + raise ValueError(f"override must be a DictConfig, got {type(override)}") + + # expecting command line override data.datasets. where is an integer + # in the range [0, len(base) - 1] + keys_dataset_ids_override = set(map(str, range(len(base)))) + + if len(override.keys()) == 0: + raise ValueError( + "Input DictConfig override is empty. " + "Please specify at least one item using as its key " + f"where is in the set of {keys_dataset_ids_override}" + ) + + if not (override.keys() <= keys_dataset_ids_override): + raise ValueError(f"Invalid keys in override: {override.keys()}. Valid keys are: {keys_dataset_ids_override}") + + ans = copy.deepcopy(base) + indices_to_remove = set() + for i in range(len(base)): + i_str = str(i) + if i_str in override: + if override[i_str] is None: + if remove_null_datasets: + indices_to_remove.add(i) + continue + raise ValueError(f"Override for item {i_str} is null but remove_null_datasets is False") + if not (override[i_str].keys() <= keys_to_override): + raise ValueError( + f"Invalid keys in override of item {i_str}: " + f"{override[i_str].keys()}. Valid keys are: {keys_to_override}" + ) + if not (override[i_str].keys() <= base[i].keys()): + raise ValueError( + f"Override keys {override[i_str].keys()} for item {i_str} " + f"contain keys not present in base item: " + f"{override[i_str].keys() - base[i].keys()}" + ) + for k, v in override[i_str].items(): + ans[i][k] = v + if indices_to_remove: + ans = omegaconf.OmegaConf.create([ans[i] for i in range(len(ans)) if i not in indices_to_remove]) + return ans + + +# Default whitelist for CLI overrides of data.datasets[*], aligned with DatasetConfig (trainingv2). +# Used by train entrypoints when calling convert_datasets_dict_to_list_config. +_DATASET_KEYS_TO_OVERRIDE = { + "_target_", + "target_dir", + "msa_dir", + "prob", + "sampler", + "cropper", + "template_dir", + "filters", + "split", + "symmetry_correction", + "val_group", + "use_train_subset", + "moldir", + "override_bfactor", + "override_method", +} + + +class CUDAMemoryProfile(pl.Callback): + """PyTorch Lightning callback for profiling CUDA memory usage. + + Captures a detailed history of CUDA memory allocations and deallocations + throughout training or prediction, then dumps a memory snapshot at the end. + The snapshot can be analyzed with the PyTorch Memory Visualizer. + + Uses ``torch.cuda.memory._record_memory_history`` / + ``torch.cuda.memory._dump_snapshot`` under the hood. + + Parameters + ---------- + output_path : Path or str + Path where the memory snapshot pickle file will be saved. + Parent directories are created automatically. + *args + Forwarded to ``torch.cuda.memory._record_memory_history()``. + **kwargs + Forwarded to ``torch.cuda.memory._record_memory_history()``. + Common kwargs include ``max_entries`` (default 100 000). + + Examples + -------- + >>> profiler = CUDAMemoryProfile("profiling/mem_rank0.pickle", max_entries=300000) + >>> trainer = pl.Trainer(callbacks=[profiler]) + """ + + def __init__(self, output_path: Path | str, *args, **kwargs): + super().__init__() + self._output_path = Path(output_path) if isinstance(output_path, str) else output_path + self._args = args + self._kwargs = kwargs + self._output_path.parent.mkdir(parents=True, exist_ok=True) + + def _start_recording(self) -> None: + torch.cuda.memory._record_memory_history(*self._args, **self._kwargs) + + def _stop_and_dump(self) -> None: + try: + torch.cuda.memory._dump_snapshot(str(self._output_path)) + except Exception as e: + warnings.warn(f"CUDAMemoryProfile: Failed to capture memory snapshot: {e}") + torch.cuda.memory._record_memory_history(enabled=None) + + def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._start_recording() + + def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._stop_and_dump() + + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._start_recording() + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._stop_and_dump() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..87410ceb0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,591 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import os +import pickle +import re +import shutil +import tarfile +import tempfile +import urllib.request +import warnings +from collections import OrderedDict +from copy import deepcopy +from dataclasses import asdict +from math import prod +from pathlib import Path + +import gdown +import pytest +import torch +from torch import Tensor + +from boltz.data.feature.featurizer import BoltzFeaturizer +from boltz.data.load import CACHE_DIR +from boltz.data.load.load import load +from boltz.data.module.inference import BoltzInferenceDataModule +from boltz.data.tokenize.boltz import BoltzTokenizer +from boltz.data.types import Input, Manifest, Target, Tokenized +from boltz.distributed.data.types import PairMaskMode +from boltz.distributed.manager import DistributedManager +from boltz.main import ( + MOL_URL, + BoltzDiffusionParams, + BoltzProcessedInput, + BoltzSteeringParams, + parse_yaml, +) +from boltz.model.models.boltz1 import Boltz1 +from boltz.testing.utils import concat_data, download_model_ckpt + + +def pytest_addoption(parser): + """Add custom CLI options for pytest. + + Adds the --name_regex option to filter tests by matching their node IDs + against a regular expression pattern. + + Args: + parser: The pytest argument parser to add options to. + """ + parser.addoption("--name_regex", action="store", default=None, help="Run tests matching a regex") + + +def pytest_collection_modifyitems(config, items): + """Filter collected test items based on the --name_regex option. + + If --name_regex is provided, only tests whose node ID matches the given + regular expression pattern will be selected for execution. Non-matching + tests are deselected and reported via the pytest_deselected hook. + + Args: + config: The pytest config object containing CLI options. + items: List of collected test items to filter (modified in-place). + """ + regex = config.getoption("--name_regex") + if regex: + r = re.compile(regex) + selected = [] + deselected = [] + for item in items: + if r.search(item.nodeid): + selected.append(item) + else: + deselected.append(item) + items[:] = selected + config.hook.pytest_deselected(items=deselected) + + +ROOT_DIR = Path(__file__).resolve().parents[1] +EXAMPLE_DIR = ROOT_DIR / "examples" + +EXAMPLE_PROT_YAML = EXAMPLE_DIR / "prot.yaml" +EXAMPLE_PROT_CUSTOM_MSA_YAML = EXAMPLE_DIR / "prot_custom_msa.yaml" +EXAMPLE_MULTIMER_YAML = EXAMPLE_DIR / "multimer.yaml" # can lead to SIGKILL on cp = (4, 4) +EXAMPLE_CYCLIC_PROT_YAML = EXAMPLE_DIR / "cyclic_prot.yaml" +EXAMPLE_YAMLS = [ + EXAMPLE_PROT_YAML, + EXAMPLE_MULTIMER_YAML, +] + +TEST_INFERENCE_DIR = ROOT_DIR / "tests" / "data" / "inference" +TEST_YAML = TEST_INFERENCE_DIR / "test_input.yaml" +TEST_POCKET_YAMLS = [ + TEST_INFERENCE_DIR / "pocket_small.yaml", +] + +SEED = 42 +CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl" + + +def download_ccd() -> Path: + ccd_path = CACHE_DIR / "ccd.pkl" + if not ccd_path.exists(): + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, "ccd.pkl") + gdown.download(CCD_URL, zip_path) + shutil.move(zip_path, ccd_path) + + return ccd_path + + +@pytest.fixture(scope="session") +def ccd(): + ccd_path = download_ccd() + with ccd_path.open("rb") as file: + return pickle.load(file) # noqa: S301 + + +@pytest.fixture(scope="function") +def example_cyclic_prot_input(ccd): + target: Target = parse_yaml(EXAMPLE_CYCLIC_PROT_YAML, ccd) + input = Input(target.structure, {}, target.record) + return input + + +@pytest.fixture(scope="function") +def example_multimer_input(ccd): + target: Target = parse_yaml(EXAMPLE_MULTIMER_YAML, ccd) + input = Input(target.structure, {}, target.record) + return input + + +@pytest.fixture(scope="function", params=EXAMPLE_YAMLS) +def example_input(ccd, request): + yaml = request.param + target: Target = parse_yaml(yaml, ccd) + input = Input(target.structure, {}, target.record) + return input + + +@pytest.fixture(scope="function") +def example_multimer_tokenized(example_multimer_input: Input) -> Tokenized: + tokenizer = BoltzTokenizer() + return tokenizer.tokenize(example_multimer_input) + + +@pytest.fixture(scope="function") +def example_tokenized(example_input: Input) -> Tokenized: + tokenizer = BoltzTokenizer() + return tokenizer.tokenize(example_input) + + +@pytest.fixture(scope="function") +def example_features(example_tokenized: Tokenized) -> dict[str, Tensor]: + featurizer = BoltzFeaturizer() + feats = featurizer.process( + example_tokenized, + training=False, + augmentation=False, + pair_mask_mode=PairMaskMode.SEQUENCE_LOCAL_ATTENTION, + ) + return {k: v.unsqueeze(0) for k, v in feats.items()} # create batch dimension + + +@pytest.fixture(scope="function") +def example_prot_input(ccd): + target: Target = parse_yaml(EXAMPLE_PROT_YAML, ccd) + input = Input(target.structure, {}, target.record) + return input + + +@pytest.fixture(scope="function") +def example_prot_tokenized(example_prot_input: Input) -> Tokenized: + tokenizer = BoltzTokenizer() + return tokenizer.tokenize(example_prot_input) + + +@pytest.fixture(scope="function", params=TEST_POCKET_YAMLS) +def example_pocket_input(ccd, request): + yaml = request.param + target: Target = parse_yaml(yaml, ccd) + input = Input(target.structure, {}, target.record) + return input + + +@pytest.fixture(scope="function") +def setup_env(request, monkeypatch): + (n_procs_dp, n_procs_cp), specify_method, device_type, method_init = request.param + if isinstance(n_procs_cp, tuple) and all(isinstance(n, int) for n in n_procs_cp): + world_size = n_procs_dp * prod(n_procs_cp) + elif isinstance(n_procs_cp, int): + world_size = n_procs_dp * n_procs_cp + else: + raise ValueError(f"Invalid type for CP ranks: {type(n_procs_cp)}") + env_per_rank = None + if specify_method: + if method_init is None: + raise ValueError("method_init must be specified if specify_method is True") + # this emulates the behavior in the DistributedManager in the + # case where the user explicitly sets the init method. In this + # case, we don't need to clean up the env variables not in the + # the scope of the "method_init" because only the "method_init" + # in question is used to initialize the distributed manager + monkeypatch.setenv("BOLTZ_DISTRIBUTED_INIT_METHOD", method_init) + elif not specify_method: + # this emulates the behavior in the DistributedManager in the + # case where the user does not specify the init method but rely + # on available environment variables to initialize the DistributedManager. + # We need to clean up the existing environment variables so that + # the tests truly respect the method_init. There are two sub-cases: + # 1. the user does not have any environment variables set, in which case + # the DistributedManager will default initialize, which falls back + # to the single-device case + # 2. the user has set the environment variables related to the + # ENV or SLURM init method, in which case the DistributedManager + # will use the environment variables to initialize + # Sub-case 2 is handled by the code in the subsequent if statements + monkeypatch.delenv("MASTER_ADDR", raising=False) + monkeypatch.delenv("MASTER_PORT", raising=False) + monkeypatch.delenv("WORLD_SIZE", raising=False) + monkeypatch.delenv("RANK", raising=False) + monkeypatch.delenv("LOCAL_RANK", raising=False) + monkeypatch.delenv("SLURM_LAUNCH_NODE_IPADDR", raising=False) + monkeypatch.delenv("SLURM_NPROCS", raising=False) + monkeypatch.delenv("SLURM_PROCID", raising=False) + monkeypatch.delenv("SLURM_LOCALID", raising=False) + if method_init == "ENV": + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "29500") + monkeypatch.setenv("WORLD_SIZE", f"{world_size}") + env_per_rank = {"RANK": "", "LOCAL_RANK": ""} + elif method_init == "SLURM": + monkeypatch.setenv("SLURM_LAUNCH_NODE_IPADDR", "localhost") + monkeypatch.setenv("SLURM_NPROCS", f"{world_size}") + env_per_rank = {"SLURM_PROCID": "", "SLURM_LOCALID": ""} + backend = DistributedManager.backend_for_device()[device_type] + grid_group_sizes = OrderedDict(dp=n_procs_dp, cp=n_procs_cp) + yield grid_group_sizes, world_size, device_type, backend, method_init, env_per_rank + + +@pytest.fixture(scope="session") +def get_inference_golden_value_dir_v1() -> Path: + return load("unittests/test_inference_pipeline_golden_values") + + +@pytest.fixture(scope="session") +def test_cp_integration_data_dir_v1() -> Path: + base_data_dir = load("unittests/test_cp_integration") + return base_data_dir + + +@pytest.fixture(scope="session") +def test_cp_integration_data_dir_boltz1_v1() -> Path: + base_data_dir = load("unittests/test_cp_integration_boltz1") + return base_data_dir + + +@pytest.fixture(scope="session", params=["7ylz", "7z64", "8ayv", "8b2e"]) +def get_preprocessed_boltz1_v1(test_cp_integration_data_dir_boltz1_v1, request): + name = request.param + return test_cp_integration_data_dir_boltz1_v1 / f"processed_{name}" + + +@pytest.fixture(scope="session", params=["7ylz", "7z64", "8ayv", "8b2e"]) +def get_preprocessed_v1(test_cp_integration_data_dir_v1, request): + name = request.param + return test_cp_integration_data_dir_v1 / f"processed_{name}" + + +@pytest.fixture(scope="session") +def create_preprocessed_handle_boltz1_v1( + get_preprocessed_boltz1_v1: Path, +) -> BoltzProcessedInput: + f_manifest = get_preprocessed_boltz1_v1 / "manifest.json" + dir_structure = get_preprocessed_boltz1_v1 / "structures" + dir_msa = get_preprocessed_boltz1_v1 / "msa" + assert f_manifest.is_file(), f"Manifest file {f_manifest} does not exist" + assert dir_structure.is_dir(), f"Structure directory {dir_structure} does not exist" + assert dir_msa.is_dir(), f"MSA directory {dir_msa} does not exist" + processed = BoltzProcessedInput( + manifest=Manifest.load(f_manifest), + targets_dir=dir_structure, + msa_dir=dir_msa, + ) + return processed + + +@pytest.fixture(scope="session") +def create_preprocessed_handle_v1(get_preprocessed_v1: Path) -> BoltzProcessedInput: + f_manifest = get_preprocessed_v1 / "manifest.json" + dir_structure = get_preprocessed_v1 / "structures" + dir_msa = get_preprocessed_v1 / "msa" + assert f_manifest.is_file(), f"Manifest file {f_manifest} does not exist" + assert dir_structure.is_dir(), f"Structure directory {dir_structure} does not exist" + assert dir_msa.is_dir(), f"MSA directory {dir_msa} does not exist" + processed = BoltzProcessedInput( + manifest=Manifest.load(f_manifest), + targets_dir=dir_structure, + msa_dir=dir_msa, + ) + return processed + + +@pytest.fixture(scope="function") +def create_datamodule_serial_v1(create_preprocessed_handle_v1): + data_module = BoltzInferenceDataModule( + manifest=create_preprocessed_handle_v1.manifest, + target_dir=create_preprocessed_handle_v1.targets_dir, + msa_dir=create_preprocessed_handle_v1.msa_dir, + augmentation=False, + num_workers=0, + ) # default use_cache=False + return data_module + + +@pytest.fixture(scope="session") +def get_model_v1_ckpt(): + f_ckpt = download_model_ckpt() + assert f_ckpt.is_file(), f"Checkpoint file {f_ckpt} does not exist" + return f_ckpt + + +@pytest.fixture(scope="session") +def get_predict_args_v1(): + predict_args = { + "recycling_steps": 10, # Boltz uses 10 for evaluation (https://github.com/jwohlwend/boltz/blob/main/docs/evaluation.md#evaluation-setup) + "sampling_steps": 200, + "diffusion_samples": 1, + "write_confidence_summary": False, + "write_full_pae": False, + "write_full_pde": False, + } + return predict_args + + +@pytest.fixture(scope="session") +def get_diffusion_params_v1(): + diffusion_params = BoltzDiffusionParams() + return asdict(diffusion_params) + + +@pytest.fixture(scope="session") +def get_steering_params_no_potentials_v1(): + steering_params = BoltzSteeringParams() + steering_params.fk_steering = False + steering_params.guidance_update = False + return asdict(steering_params) + + +@pytest.fixture(scope="session") +def get_score_model_args_v1(): + return { + "sigma_data": 16, + "dim_fourier": 256, + "atom_encoder_depth": 3, + "atom_encoder_heads": 4, + "token_transformer_depth": 24, + "token_transformer_heads": 16, + "atom_decoder_depth": 3, + "atom_decoder_heads": 4, + "activation_checkpointing": False, + } + + +@pytest.fixture(scope="session") +def _load_model_v1_with_ckpt_serial( + get_model_v1_ckpt, + get_predict_args_v1, + get_diffusion_params_v1, + get_steering_params_no_potentials_v1, + get_score_model_args_v1, +): + kwargs_model = { + "confidence_prediction": False, + "predict_args": get_predict_args_v1, + "diffusion_process_args": get_diffusion_params_v1, + "steering_args": get_steering_params_no_potentials_v1, + "score_model_args": get_score_model_args_v1, + "ema": False, + } + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + model = Boltz1.load_from_checkpoint(get_model_v1_ckpt, strict=False, map_location="cpu", **kwargs_model) + ckpt_dict = torch.load(get_model_v1_ckpt, map_location="cpu", weights_only=False) + # merge the model kwargs into the hyper_parameters entry + ckpt_dict["hyper_parameters"].update(kwargs_model) + # drop deprecated args + ckpt_dict["hyper_parameters"].pop("chain_sampling_args") + ckpt_dict["hyper_parameters"].pop("recycling_detach") + ckpt_dict["hyper_parameters"].pop("run_trunk_and_structure") + # the CP-related settings must be absent to prevent mix-and-matching + # data sharding and model distributed computation settings + assert "dist_manager" not in ckpt_dict["hyper_parameters"] + assert "layout_map_cp" not in ckpt_dict["hyper_parameters"] + return model, ckpt_dict + + +@pytest.fixture(scope="function") +def load_model_v1_with_ckpt_serial(_load_model_v1_with_ckpt_serial): + _, ckpt_dict = _load_model_v1_with_ckpt_serial + ckpt_dict = deepcopy(ckpt_dict) + model = Boltz1(**ckpt_dict["hyper_parameters"]) + model.load_state_dict(ckpt_dict["state_dict"], strict=False) + return model, ckpt_dict + + +@pytest.fixture(scope="function", params=[False, True], ids=["no_tf32", "tf32"]) +def setup_tf32(request): + """Configure TF32 settings and reset them after the test.""" + use_tf32 = request.param + + original_env = os.environ.get("NVIDIA_TF32_OVERRIDE", None) + original_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32 + original_cudnn_tf32 = torch.backends.cudnn.allow_tf32 + + tf32_value = "1" if use_tf32 else "0" + os.environ["NVIDIA_TF32_OVERRIDE"] = tf32_value + torch.backends.cuda.matmul.allow_tf32 = use_tf32 + torch.backends.cudnn.allow_tf32 = use_tf32 + + yield use_tf32 + + if original_env is not None: + os.environ["NVIDIA_TF32_OVERRIDE"] = original_env + else: + os.environ.pop("NVIDIA_TF32_OVERRIDE", None) + torch.backends.cuda.matmul.allow_tf32 = original_matmul_tf32 + torch.backends.cudnn.allow_tf32 = original_cudnn_tf32 + + +@pytest.fixture +def get_training_data_v1(): + """Download and return the path to the training data truncated set. + + Returns the path to boltz_training_truncated_set directory which contains: + - openfold_processed_targets/ + - openfold_processed_msa/ + - train_ids.txt + - validation_ids.txt + + The parent directory (training_data) contains symmetry.pkl. + """ + return load("unittests/training_data_truncated_set") / "training_data" / "boltz_training_truncated_set" + + +#################################################################################################### +# Boltz 2 UTILITIES +#################################################################################################### + + +def _build_processed_input_boltz2(base_dir: Path) -> BoltzProcessedInput: + """Build a BoltzProcessedInput from a processed data directory.""" + manifest = Manifest.load(base_dir / "manifest.json") + targets_dir = base_dir / "structures" + msa_dir = base_dir / "msa" + template_dir = base_dir / "templates" if (base_dir / "templates").exists() else None + extra_mols_dir = base_dir / "extra_mols" if (base_dir / "extra_mols").exists() else None + return BoltzProcessedInput( + manifest=manifest, + targets_dir=targets_dir, + msa_dir=msa_dir, + constraints_dir=None, + template_dir=template_dir, + extra_mols_dir=extra_mols_dir, + ) + + +def _concat_data_with_records(out_dir: Path, *datas: Path) -> Path: + """Concatenate processed Boltz2 dirs, including records.""" + merged = concat_data(out_dir, *datas) + records_dir = merged / "records" + records_dir.mkdir(parents=True, exist_ok=True) + + copied_records = set() + for data in datas: + src_records_dir = Path(data) / "records" + for record_file in src_records_dir.glob("*.json"): + if record_file.name in copied_records: + raise ValueError(f"Duplicate record file {record_file.name}") + shutil.copy(record_file, records_dir / record_file.name) + copied_records.add(record_file.name) + + return merged + + +@pytest.fixture(scope="session") +def canonical_mols_dir() -> Path: + """Download canonical molecules to cache if needed, return the mols directory. + + The mols directory contains per-residue pickle files of RDKit Mol objects derived + from the Chemical Component Dictionary (CCD). Each file (e.g. ALA.pkl) provides + the reference 3D structure, atom names, bonds, and other chemical metadata for a + single CCD component. + + This is the Boltz2 equivalent of the `ccd.pkl` file used by Boltz1. Where + Boltz1 deserializes the entire CCD dictionary into memory at once, Boltz2 splits it + into individual per-component files so that only the needed subset is loaded. + `load_canonicals` eagerly loads the 20 standard amino acid residues plus UNK, + while `load_molecules` and `get_mol` load additional non-standard components + lazily on demand during tokenization, featurization, and structure parsing. + + Returns + ------- + Path + The path to the mols directory. + """ + cache_mols = CACHE_DIR / "mols" + cache_tar = CACHE_DIR / "mols.tar" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + if not cache_tar.exists(): + with tempfile.TemporaryDirectory() as temp_dir: + tmp_tar = os.path.join(temp_dir, "mols.tar") + gdown.download(MOL_URL, tmp_tar) + shutil.move(tmp_tar, cache_tar) + if not cache_mols.exists(): + with tarfile.open(str(cache_tar), "r") as tar: + tar.extractall(CACHE_DIR) # noqa: S202 + return cache_mols + + +@pytest.fixture(scope="session") +def test_cp_training_base_data_dir_boltz2() -> Path: + """Raw Boltz2 training archive root containing processed_{id} dirs.""" + return load("unittests/test_cp_training_data_boltz2") + + +@pytest.fixture(scope="session", params=["processed_7ylz", "processed_7z64", "processed_8ayv", "processed_8b2e"]) +def get_preprocessed_boltz2(test_cp_training_base_data_dir_boltz2: Path, request: pytest.FixtureRequest) -> Path: + """Per-sample preprocessed directory for Boltz-2 predict tests.""" + return test_cp_training_base_data_dir_boltz2 / request.param + + +@pytest.fixture(scope="session") +def test_cp_training_data_dir_boltz2( + tmp_path_factory: pytest.TempPathFactory, + test_cp_training_base_data_dir_boltz2: Path, +) -> Path: + """Merged Boltz2 training directory with records for the training data module.""" + names = ["7ylz", "7z64", "8ayv", "8b2e"] + source_dirs = [test_cp_training_base_data_dir_boltz2 / f"processed_{name}" for name in names] + out_dir = tmp_path_factory.mktemp("cp_training_data_boltz2") / "processed_training" + return _concat_data_with_records(out_dir, *source_dirs) + + +@pytest.fixture(scope="session") +def create_preprocessed_handle_boltz2(test_cp_training_data_dir_boltz2: Path) -> BoltzProcessedInput: + """Build a BoltzProcessedInput from the merged Boltz2 training data directory.""" + return _build_processed_input_boltz2(test_cp_training_data_dir_boltz2) + + +@pytest.fixture(scope="session") +def get_inference_golden_value_dir_v2() -> Path: + return load("unittests/test_inference_pipeline_golden_values_boltz2") + + +@pytest.fixture(scope="session") +def get_model_ckpt_v2() -> Path: + from boltz.main import BOLTZ2_URL_WITH_FALLBACK + + cache = CACHE_DIR + cache.mkdir(parents=True, exist_ok=True) + checkpoint = cache / "boltz2_conf.ckpt" + if not checkpoint.exists(): + for i, url in enumerate(BOLTZ2_URL_WITH_FALLBACK): + try: + urllib.request.urlretrieve(url, str(checkpoint)) # noqa: S310 + break + except Exception as e: # noqa: BLE001 + if i == len(BOLTZ2_URL_WITH_FALLBACK) - 1: + msg = f"Failed to download Boltz-2 checkpoint: {e}" + raise RuntimeError(msg) from e + assert checkpoint.is_file(), f"Checkpoint {checkpoint} does not exist" + return checkpoint diff --git a/tests/data/feature/test_featurizerv2.py b/tests/data/feature/test_featurizerv2.py new file mode 100644 index 000000000..56866d96d --- /dev/null +++ b/tests/data/feature/test_featurizerv2.py @@ -0,0 +1,480 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for featurizerv2.py (inference featurizer) bug fixes. + +Covers: + 1. `visited` set construction: taxonomy-assigned sequences must be excluded + from the unpaired `available` pool. + 2. In-place mutation: construct_paired_msa must not mutate caller's MSA data. + 3. Deletion indexing: the deletion extraction loop must slice from the + ORIGINAL full array on every iteration, not from the previous slice. +""" + +from types import SimpleNamespace + +import numpy as np +import pytest + +from boltz.data import const +from boltz.data.feature.featurizerv2 import construct_paired_msa +from boltz.data.types import ( + MSA, + Chain, + MSADeletion, + MSAResidue, + MSASequence, + Residue, + StructureV2, + Token, +) + +# --------------------------------------------------------------------------- +# Helpers -- same synthetic builders as the train tests, but using Token dtype +# (inference featurizer operates on Tokenized which uses Token, not TokenV2) +# --------------------------------------------------------------------------- + + +def _make_chain(chain_id, res_idx, res_num): + """Create a single-element Chain structured array. + + Parameters + ---------- + chain_id : int + Value for the ``asym_id`` field. + res_idx : int + Starting index of this chain's residues in the global residue array. + res_num : int + Number of residues in the chain. + """ + arr = np.zeros(1, dtype=Chain) + arr[0]["asym_id"] = chain_id + arr[0]["res_idx"] = res_idx + arr[0]["res_num"] = res_num + return arr + + +def _make_msa(residue_types, taxonomies, deletions_per_seq=None): + """Build an MSA from explicit residue types, taxonomies, and deletions. + + Parameters + ---------- + residue_types : list[list[int]] + Outer = sequences, inner = residue token IDs. + taxonomies : list[int] + Taxonomy ID per sequence (-1 for none). + deletions_per_seq : list[list[tuple[int, int]]], optional + Per-sequence list of (res_idx, count) deletion entries. + """ + if deletions_per_seq is None: + deletions_per_seq = [[] for _ in residue_types] + + all_residues, all_deletions, sequences = [], [], [] + for seq_idx, (res_types, taxon, dels) in enumerate(zip(residue_types, taxonomies, deletions_per_seq)): + res_start = len(all_residues) + all_residues.extend(res_types) + res_end = len(all_residues) + + del_start = len(all_deletions) + all_deletions.extend(dels) + del_end = len(all_deletions) + + sequences.append((seq_idx, taxon, res_start, res_end, del_start, del_end)) + + return MSA( + residues=np.array(all_residues, dtype=MSAResidue), + deletions=np.array(all_deletions, dtype=MSADeletion), + sequences=np.array(sequences, dtype=MSASequence), + ) + + +def _make_tokens(asym_ids, res_idxs): + """Create a Token array with ``asym_id``, ``res_idx``, and ``token_idx`` populated. + + Parameters + ---------- + asym_ids : list[int] + Per-token chain identifier. + res_idxs : list[int] + Per-token residue index (0-based within its chain). + """ + n = len(asym_ids) + tokens = np.zeros(n, dtype=Token) + tokens["asym_id"] = asym_ids + tokens["res_idx"] = res_idxs + tokens["token_idx"] = np.arange(n) + return tokens + + +def _make_data(chain_specs, msas): + """Build a minimal data object accepted by construct_paired_msa. + + Parameters + ---------- + chain_specs : list[list[int]] + Per-chain structure residue types (must match MSA query, seq 0). + msas : dict[int, MSA] + """ + chains_list, residues_list, asym_ids, res_idxs = [], [], [], [] + offset = 0 + for chain_id, res_types in enumerate(chain_specs): + n_res = len(res_types) + chains_list.append(_make_chain(chain_id, res_idx=offset, res_num=n_res)) + res = np.zeros(n_res, dtype=Residue) + res["res_type"] = res_types + for i in range(n_res): + res[i]["res_idx"] = i + residues_list.append(res) + for r in range(n_res): + asym_ids.append(chain_id) + res_idxs.append(r) + offset += n_res + + chains = np.concatenate(chains_list) + residues = np.concatenate(residues_list) + tokens = _make_tokens(asym_ids, res_idxs) + structure = SimpleNamespace(chains=chains, residues=residues) + record = SimpleNamespace(id="test_sample") + return SimpleNamespace(tokens=tokens, structure=structure, msa=msas, record=record) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def two_chain_taxonomy_data(): + """2 chains x 3 residues, each with 3 MSA sequences. + + seq 0 = query, seq 1 = Human (9606), seq 2 = unique per chain. + """ + chain0_query = [1, 2, 3] + chain1_query = [6, 7, 8] + + msa_chain0 = _make_msa( + residue_types=[chain0_query, [1, 2, 4], [5, 2, 3]], + taxonomies=[-1, 9606, 7227], + ) + msa_chain1 = _make_msa( + residue_types=[chain1_query, [6, 7, 9], [10, 7, 8]], + taxonomies=[-1, 9606, 7955], + ) + data = _make_data( + chain_specs=[chain0_query, chain1_query], + msas={0: msa_chain0, 1: msa_chain1}, + ) + return data + + +@pytest.fixture() +def met_unk_mismatch_data(): + """Chain 0: structure [MET, ALA, GLY], MSA query [UNK, ALA, GLY].""" + met_id = const.token_ids["MET"] + unk_id = const.token_ids["UNK"] + ala_id = const.token_ids["ALA"] + gly_id = const.token_ids["GLY"] + + msa_chain0 = _make_msa( + residue_types=[[unk_id, ala_id, gly_id]], + taxonomies=[-1], + ) + data = _make_data( + chain_specs=[[met_id, ala_id, gly_id]], + msas={0: msa_chain0}, + ) + return data, msa_chain0 + + +@pytest.fixture() +def data_with_deletions(): + """1 chain, 3 sequences: query (no deletions), seq 1 (1 del), seq 2 (2 dels).""" + query = [1, 2, 3, 4, 5] + msa = _make_msa( + residue_types=[ + query, + [1, 2, 4, 4, 5], + [2, 2, 3, 4, 6], + ], + taxonomies=[-1, -1, -1], + deletions_per_seq=[ + [], + [(2, 3)], + [(0, 1), (4, 5)], + ], + ) + return _make_data(chain_specs=[query], msas={0: msa}) + + +# --------------------------------------------------------------------------- +# Test 1: visited set -- taxonomy-assigned seqs excluded from available pool +# +# Bug: the old comprehension {(c, s) for c, items in taxonomy_map for s in items} +# produced {(taxon, (chain_id, seq_idx)), ...}. The downstream membership +# check (chain_id, seq_idx) not in visited never matched, so every sequence +# leaked into the unpaired pool. +# --------------------------------------------------------------------------- + + +def test_visited_unpaired_fill_excludes_taxonomy_assigned(two_chain_taxonomy_data): + """Sequences assigned to taxonomy groups should NOT appear as + unpaired fill in other rows.""" + gap = const.token_ids["-"] + rng = np.random.default_rng(42) + msa_data, _, _ = construct_paired_msa(two_chain_taxonomy_data, random=rng, max_seqs=100) + + n_chain0_tokens = 3 + n_rows = msa_data.shape[1] + human_chain0 = (1, 2, 4) + human_chain1 = (6, 7, 9) + + for row_idx in range(2, n_rows): + chain0_row = tuple(msa_data[:n_chain0_tokens, row_idx].tolist()) + chain1_row = tuple(msa_data[n_chain0_tokens:, row_idx].tolist()) + is_gap_chain0 = all(v == gap for v in chain0_row) + is_gap_chain1 = all(v == gap for v in chain1_row) + if not is_gap_chain0: + assert chain0_row != human_chain0, f"Row {row_idx}: chain 0 Human seq reused as unpaired fill" + if not is_gap_chain1: + assert chain1_row != human_chain1, f"Row {row_idx}: chain 1 Human seq reused as unpaired fill" + + +# --------------------------------------------------------------------------- +# Test 2: in-place mutation -- caller's MSA data must not be modified +# +# Bug: msa_residues was a direct view into data.msa[chain_id].residues. +# Modifying msa_residues[...]["res_type"] = ... mutated the caller's data. +# --------------------------------------------------------------------------- + + +def test_copy_on_write_original_msa_residues_unchanged(met_unk_mismatch_data): + """After construct_paired_msa, the caller's MSA residue array must + still contain the original UNK token, not the patched MET.""" + data, original_msa = met_unk_mismatch_data + original_residues_before = original_msa.residues["res_type"].copy() + + rng = np.random.default_rng(42) + construct_paired_msa(data, random=rng, max_seqs=10) + + np.testing.assert_array_equal( + original_msa.residues["res_type"], + original_residues_before, + err_msg="construct_paired_msa mutated the caller's MSA residues in-place.", + ) + + +def test_copy_on_write_idempotent_on_double_call(met_unk_mismatch_data): + """Calling construct_paired_msa twice on the same data must produce + identical results, proving the first call didn't corrupt the input.""" + data, _ = met_unk_mismatch_data + + rng1 = np.random.default_rng(42) + msa1, del1, paired1 = construct_paired_msa(data, random=rng1, max_seqs=10) + + rng2 = np.random.default_rng(42) + msa2, del2, paired2 = construct_paired_msa(data, random=rng2, max_seqs=10) + + np.testing.assert_array_equal(msa1, msa2, err_msg="Double-call MSA mismatch") + np.testing.assert_array_equal(del1, del2, err_msg="Double-call deletion mismatch") + np.testing.assert_array_equal(paired1, paired2, err_msg="Double-call paired mismatch") + + +# --------------------------------------------------------------------------- +# Test 3: deletion indexing -- must slice from original array every iteration +# +# Bug (inference-specific): chain_deletions = chain_deletions[del_start:del_end] +# progressively shrank the array. After the query (del_start=0, del_end=0), +# chain_deletions became empty; all subsequent deletions were silently lost. +# With the bug, n_nonzero == 0. With the fix, n_nonzero == 3. +# --------------------------------------------------------------------------- + + +def test_deletion_all_present_in_output(data_with_deletions): + """Every deletion entry must appear in the output tensor.""" + rng = np.random.default_rng(42) + _, del_data, _ = construct_paired_msa(data_with_deletions, random=rng, max_seqs=10) + + assert del_data[2, 1].item() == 3, ( + f"Seq 1 deletion at res_idx=2: expected 3, got {del_data[2, 1].item()}. " + "This fails if chain_deletions is progressively shrunk." + ) + assert del_data[0, 2].item() == 1, f"Seq 2 deletion at res_idx=0: expected 1, got {del_data[0, 2].item()}" + assert del_data[4, 2].item() == 5, f"Seq 2 deletion at res_idx=4: expected 5, got {del_data[4, 2].item()}" + + +def test_deletion_query_row_has_no_deletions(data_with_deletions): + """The query row (row 0) should have zero deletions everywhere.""" + rng = np.random.default_rng(42) + _, del_data, _ = construct_paired_msa(data_with_deletions, random=rng, max_seqs=10) + + query_dels = del_data[:, 0] + assert (query_dels == 0).all(), f"Query row should have zero deletions, got {query_dels}" + + +def test_deletion_nonzero_count_matches_expected(data_with_deletions): + """With the bug, n_nonzero would be 0 (all deletions lost). + With the fix, n_nonzero should be exactly 3.""" + rng = np.random.default_rng(42) + _, del_data, _ = construct_paired_msa(data_with_deletions, random=rng, max_seqs=10) + + n_nonzero = (del_data != 0).sum().item() + assert n_nonzero == 3, ( + f"Expected 3 nonzero deletion entries, got {n_nonzero}. " + "A value of 0 indicates the progressive-shrink bug is still present." + ) + + +# --------------------------------------------------------------------------- +# Real-data integration tests +# +# Load the 8ayv homodimer (2 protein chains sharing one MSA with 4611 seqs, +# 2309 with taxonomy, 716 with deletions) and run construct_paired_msa on +# it to verify the fixes work on production-format data. +# --------------------------------------------------------------------------- + + +def _load_real_sample(sample_dir): + """Build a data object from a processed sample directory. + + Uses Token dtype (inference featurizer) instead of TokenV2. + """ + import json + + manifest = json.loads((sample_dir / "manifest.json").read_text()) + rec = manifest["records"][0] + + structure = StructureV2.load(sample_dir / "structures" / f"{rec['id']}.npz") + + msa_dir = sample_dir / "msa" + msas = {} + for chain_info in rec["chains"]: + msa_id = chain_info.get("msa_id", -1) + chain_id = chain_info["chain_id"] + if msa_id != -1: + msa_path = msa_dir / f"{msa_id}.npz" + if msa_path.exists(): + msas[chain_id] = MSA.load(msa_path) + + asym_ids, res_idxs = [], [] + for chain in structure.chains: + cid = int(chain["asym_id"]) + if cid not in msas: + continue + for r in range(int(chain["res_num"])): + asym_ids.append(cid) + res_idxs.append(r) + + tokens = _make_tokens(asym_ids, res_idxs) + record = SimpleNamespace(id=rec["id"]) + data = SimpleNamespace(tokens=tokens, structure=structure, msa=msas, record=record) + return data, msas + + +@pytest.fixture() +def real_8ayv_data(test_cp_training_base_data_dir_boltz2): + """Load the 8ayv homodimer from the real test data cache.""" + return _load_real_sample(test_cp_training_base_data_dir_boltz2 / "processed_8ayv") + + +def test_real_data_no_mutation(real_8ayv_data): + """MSA residue arrays must not be modified in-place by construct_paired_msa.""" + data, original_msas = real_8ayv_data + + snapshots = {cid: msa.residues["res_type"].copy() for cid, msa in original_msas.items()} + + rng = np.random.default_rng(42) + construct_paired_msa(data, random=rng, max_seqs=512) + + for cid, before in snapshots.items(): + np.testing.assert_array_equal( + original_msas[cid].residues["res_type"], + before, + err_msg=f"Chain {cid}: MSA residues mutated in-place", + ) + + +def test_real_data_deletions_preserved(real_8ayv_data): + """The output deletion matrix must contain nonzero entries when the + input MSA has sequences with deletions (8ayv has 716).""" + data, _ = real_8ayv_data + + rng = np.random.default_rng(42) + _, del_data, _ = construct_paired_msa(data, random=rng, max_seqs=512) + + n_nonzero = (del_data != 0).sum().item() + assert n_nonzero > 0, ( + "del_data is all zeros despite 716 input sequences having deletions. " + "This indicates the deletion extraction loop is broken." + ) + + +def test_real_data_idempotent(real_8ayv_data): + """Two calls with the same seed must produce identical output.""" + data, _ = real_8ayv_data + + rng1 = np.random.default_rng(42) + msa1, del1, paired1 = construct_paired_msa(data, random=rng1, max_seqs=512) + + rng2 = np.random.default_rng(42) + msa2, del2, paired2 = construct_paired_msa(data, random=rng2, max_seqs=512) + + np.testing.assert_array_equal(msa1, msa2, err_msg="Idempotency: MSA mismatch") + np.testing.assert_array_equal(del1, del2, err_msg="Idempotency: deletion mismatch") + np.testing.assert_array_equal(paired1, paired2, err_msg="Idempotency: paired mismatch") + + +def test_real_data_output_shape(real_8ayv_data): + """Output arrays must have consistent shapes: (n_tokens, n_rows).""" + data, _ = real_8ayv_data + n_tokens = len(data.tokens) + + rng = np.random.default_rng(42) + max_seqs = 128 + msa_data, del_data, paired_data = construct_paired_msa(data, random=rng, max_seqs=max_seqs) + + assert msa_data.shape[0] == n_tokens, f"msa_data dim 0 should be n_tokens={n_tokens}, got {msa_data.shape[0]}" + assert del_data.shape == msa_data.shape, f"del_data shape {del_data.shape} != msa_data shape {msa_data.shape}" + n_rows = msa_data.shape[1] + assert n_rows >= 1, "Must have at least the query row" + assert n_rows <= max_seqs, f"n_rows={n_rows} exceeds max_seqs={max_seqs}" + + +def test_real_data_query_row_matches_structure(real_8ayv_data): + """Row 0 (query) should contain residue types matching the structure.""" + data, _ = real_8ayv_data + + rng = np.random.default_rng(42) + msa_data, _, _ = construct_paired_msa(data, random=rng, max_seqs=128) + + for chain_id in sorted(data.msa.keys()): + chain = data.structure.chains[chain_id] + res_start = int(chain["res_idx"]) + res_num = int(chain["res_num"]) + expected = data.structure.residues[res_start : res_start + res_num]["res_type"] + + token_mask = data.tokens["asym_id"] == chain_id + query_row = msa_data[token_mask, 0] + + np.testing.assert_array_equal( + query_row, + expected, + err_msg=f"Chain {chain_id}: query row doesn't match structure residues", + ) diff --git a/tests/data/feature/test_featurizerv2_train.py b/tests/data/feature/test_featurizerv2_train.py new file mode 100644 index 000000000..c72e2e29f --- /dev/null +++ b/tests/data/feature/test_featurizerv2_train.py @@ -0,0 +1,525 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for featurizerv2_train.py bug fixes. + +Covers: + 1. `visited` set construction: taxonomy-assigned sequences must be excluded + from the unpaired `available` pool. + 2. In-place mutation: construct_paired_msa must not mutate caller's MSA data. + 3. Deletion extraction: all sequences' deletions must be captured. +""" + +from types import SimpleNamespace + +import numpy as np +import pytest + +from boltz.data import const +from boltz.data.feature.featurizerv2_train import construct_paired_msa +from boltz.data.types import ( + MSA, + Chain, + MSADeletion, + MSAResidue, + MSASequence, + Residue, + StructureV2, + TokenV2, +) + +# --------------------------------------------------------------------------- +# Helpers to build minimal synthetic data +# --------------------------------------------------------------------------- + + +def _make_chain(chain_id, res_idx, res_num): + """Create a single-element Chain structured array. + + Parameters + ---------- + chain_id : int + Value for the ``asym_id`` field. + res_idx : int + Starting index of this chain's residues in the global residue array. + res_num : int + Number of residues in the chain. + """ + arr = np.zeros(1, dtype=Chain) + arr[0]["asym_id"] = chain_id + arr[0]["res_idx"] = res_idx + arr[0]["res_num"] = res_num + return arr + + +def _make_msa(residue_types, taxonomies, deletions_per_seq=None): + """Build an MSA from explicit residue types, taxonomies, and deletions.""" + if deletions_per_seq is None: + deletions_per_seq = [[] for _ in residue_types] + + all_residues, all_deletions, sequences = [], [], [] + for seq_idx, (res_types, taxon, dels) in enumerate(zip(residue_types, taxonomies, deletions_per_seq)): + res_start = len(all_residues) + all_residues.extend(res_types) + res_end = len(all_residues) + + del_start = len(all_deletions) + all_deletions.extend(dels) + del_end = len(all_deletions) + + sequences.append((seq_idx, taxon, res_start, res_end, del_start, del_end)) + + return MSA( + residues=np.array(all_residues, dtype=MSAResidue), + deletions=np.array(all_deletions, dtype=MSADeletion), + sequences=np.array(sequences, dtype=MSASequence), + ) + + +def _make_tokens(asym_ids, res_idxs): + """Create a TokenV2 array with ``asym_id``, ``res_idx``, and ``token_idx`` populated. + + Parameters + ---------- + asym_ids : list[int] + Per-token chain identifier. + res_idxs : list[int] + Per-token residue index (0-based within its chain). + """ + n = len(asym_ids) + tokens = np.zeros(n, dtype=TokenV2) + tokens["asym_id"] = asym_ids + tokens["res_idx"] = res_idxs + tokens["token_idx"] = np.arange(n) + return tokens + + +def _make_data(chain_specs, msas): + """Build a minimal data object accepted by construct_paired_msa. + + Parameters + ---------- + chain_specs : list[list[int]] + Per-chain list of structure residue types. The length of each + inner list determines the chain's residue count. These MUST match + the MSA query sequence (seq 0) for each chain to avoid the dummy-MSA + fallback. + msas : dict[int, MSA] + MSA per chain_id. + """ + chains_list, residues_list, asym_ids, res_idxs = [], [], [], [] + offset = 0 + for chain_id, res_types in enumerate(chain_specs): + n_res = len(res_types) + chains_list.append(_make_chain(chain_id, res_idx=offset, res_num=n_res)) + res = np.zeros(n_res, dtype=Residue) + res["res_type"] = res_types + for i in range(n_res): + res[i]["res_idx"] = i + residues_list.append(res) + for r in range(n_res): + asym_ids.append(chain_id) + res_idxs.append(r) + offset += n_res + + chains = np.concatenate(chains_list) + residues = np.concatenate(residues_list) + tokens = _make_tokens(asym_ids, res_idxs) + structure = SimpleNamespace(chains=chains, residues=residues) + record = SimpleNamespace(id="test_sample") + return SimpleNamespace(tokens=tokens, structure=structure, msa=msas, record=record) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def two_chain_taxonomy_data(): + """2 chains x 3 residues, each with 3 MSA sequences. + + seq 0 = query (no taxonomy) -- structure residues MUST match this + seq 1 = Human (taxonomy 9606) -- appears in both chains + seq 2 = unique to chain (no taxonomy match across chains) + """ + chain0_query = [1, 2, 3] + chain1_query = [6, 7, 8] + + msa_chain0 = _make_msa( + residue_types=[ + chain0_query, # seq 0: query + [1, 2, 4], # seq 1: Human + [5, 2, 3], # seq 2: Fly (only chain 0) + ], + taxonomies=[-1, 9606, 7227], + ) + msa_chain1 = _make_msa( + residue_types=[ + chain1_query, # seq 0: query + [6, 7, 9], # seq 1: Human + [10, 7, 8], # seq 2: Zebrafish (only chain 1) + ], + taxonomies=[-1, 9606, 7955], + ) + + data = _make_data( + chain_specs=[chain0_query, chain1_query], + msas={0: msa_chain0, 1: msa_chain1}, + ) + return data + + +@pytest.fixture() +def met_unk_mismatch_data(): + """Chain 0 has 3 residues. Structure says [MET, ALA, GLY]. + MSA query says [UNK, ALA, GLY]. The MET/UNK fix should patch the + MSA query to match, but only on a copy.""" + met_id = const.token_ids["MET"] + unk_id = const.token_ids["UNK"] + ala_id = const.token_ids["ALA"] + gly_id = const.token_ids["GLY"] + + structure_res_types = [met_id, ala_id, gly_id] + msa_query_types = [unk_id, ala_id, gly_id] + + msa_chain0 = _make_msa( + residue_types=[msa_query_types], + taxonomies=[-1], + ) + + data = _make_data( + chain_specs=[structure_res_types], + msas={0: msa_chain0}, + ) + return data, msa_chain0 + + +@pytest.fixture() +def data_with_deletions(): + """1 chain, 3 sequences with known deletions. + + seq 0 (query): no deletions (del_start=0, del_end=0) + seq 1: 1 deletion at res_idx=2, count=3 + seq 2: 2 deletions at res_idx=0 (count=1) and res_idx=4 (count=5) + + Structure residues match seq 0 so the MSA is not replaced by a dummy. + """ + query = [1, 2, 3, 4, 5] + msa = _make_msa( + residue_types=[ + query, # seq 0: query + [1, 2, 4, 4, 5], # seq 1 + [2, 2, 3, 4, 6], # seq 2 + ], + taxonomies=[-1, -1, -1], + deletions_per_seq=[ + [], # seq 0: no deletions + [(2, 3)], # seq 1: 3 insertions before col 2 + [(0, 1), (4, 5)], # seq 2: 1 before col 0, 5 before col 4 + ], + ) + return _make_data(chain_specs=[query], msas={0: msa}) + + +# --------------------------------------------------------------------------- +# Test 1: visited set -- taxonomy-assigned seqs excluded from available pool +# +# Bug: the old comprehension {(c, s) for c, items in taxonomy_map for s in items} +# produced {(taxon, (chain_id, seq_idx)), ...} because after sorted(), c was +# the taxon key and s was a (chain_id, seq_idx) tuple. The downstream +# (chain_id, seq_idx) not in visited check never matched, so every sequence +# leaked into available and could appear as both a paired and an unpaired row. +# --------------------------------------------------------------------------- + + +def test_visited_no_duplicate_seq_indices_in_pairing(two_chain_taxonomy_data): + """Each (chain_id, seq_idx) pair must appear at most once across all + rows of the pairing matrix.""" + rng = np.random.RandomState(42) + msa_data, _, _ = construct_paired_msa(two_chain_taxonomy_data, random=rng, max_seqs=100) + + n_chain0_tokens = 3 + + chain0_row1 = tuple(msa_data[:n_chain0_tokens, 1].tolist()) + chain1_row1 = tuple(msa_data[n_chain0_tokens:, 1].tolist()) + assert chain0_row1 == (1, 2, 4), "Row 1 chain 0 should be Human seq" + assert chain1_row1 == (6, 7, 9), "Row 1 chain 1 should be Human seq" + + +def test_visited_unpaired_fill_excludes_taxonomy_assigned(two_chain_taxonomy_data): + """Sequences assigned to taxonomy groups should NOT appear as + unpaired fill in other rows.""" + gap = const.token_ids["-"] + rng = np.random.RandomState(42) + msa_data, _, _ = construct_paired_msa(two_chain_taxonomy_data, random=rng, max_seqs=100) + + n_chain0_tokens = 3 + n_rows = msa_data.shape[1] + + human_chain0 = (1, 2, 4) + human_chain1 = (6, 7, 9) + + for row_idx in range(2, n_rows): + chain0_row = tuple(msa_data[:n_chain0_tokens, row_idx].tolist()) + chain1_row = tuple(msa_data[n_chain0_tokens:, row_idx].tolist()) + is_gap_chain0 = all(v == gap for v in chain0_row) + is_gap_chain1 = all(v == gap for v in chain1_row) + if not is_gap_chain0: + assert chain0_row != human_chain0, f"Row {row_idx}: chain 0 Human seq reused as unpaired fill" + if not is_gap_chain1: + assert chain1_row != human_chain1, f"Row {row_idx}: chain 1 Human seq reused as unpaired fill" + + +# --------------------------------------------------------------------------- +# Test 2: in-place mutation -- caller's MSA data must not be modified +# +# Bug: msa_residues was a direct view into data.msa[chain_id].residues. +# Modifying msa_residues[...]["res_type"] = ... mutated the caller's data. +# If construct_paired_msa were called twice (e.g. retry), the second call +# would see residues already patched by the first. +# --------------------------------------------------------------------------- + + +def test_copy_on_write_original_msa_residues_unchanged(met_unk_mismatch_data): + """After construct_paired_msa, the caller's MSA residue array must + still contain the original UNK token, not the patched MET.""" + data, original_msa = met_unk_mismatch_data + original_residues_before = original_msa.residues["res_type"].copy() + + rng = np.random.RandomState(42) + construct_paired_msa(data, random=rng, max_seqs=10) + + np.testing.assert_array_equal( + original_msa.residues["res_type"], + original_residues_before, + err_msg=( + "construct_paired_msa mutated the caller's MSA residues in-place. " + "The first residue should still be UNK, not patched to MET." + ), + ) + + +def test_copy_on_write_output_msa_has_patched_values(met_unk_mismatch_data): + """The output MSA rows should reflect the patched MET, not UNK.""" + data, _ = met_unk_mismatch_data + met_id = const.token_ids["MET"] + + rng = np.random.RandomState(42) + msa_data, _, _ = construct_paired_msa(data, random=rng, max_seqs=10) + + assert msa_data[0, 0].item() == met_id, f"Query row should have MET (patched from UNK), got {msa_data[0, 0].item()}" + + +def test_copy_on_write_idempotent_on_double_call(met_unk_mismatch_data): + """Calling construct_paired_msa twice on the same data must produce + identical results, proving the first call didn't corrupt the input.""" + data, _ = met_unk_mismatch_data + + rng1 = np.random.RandomState(42) + msa1, del1, paired1 = construct_paired_msa(data, random=rng1, max_seqs=10) + + rng2 = np.random.RandomState(42) + msa2, del2, paired2 = construct_paired_msa(data, random=rng2, max_seqs=10) + + np.testing.assert_array_equal(msa1, msa2, err_msg="Double-call MSA mismatch") + np.testing.assert_array_equal(del1, del2, err_msg="Double-call deletion mismatch") + np.testing.assert_array_equal(paired1, paired2, err_msg="Double-call paired mismatch") + + +# --------------------------------------------------------------------------- +# Test 3: deletion extraction -- all sequences' deletions must be captured +# +# The train featurizer's inner loop was correct but had a dead-code outer +# assignment (`chain_deletions = chain_msa.deletions`) that was removed. +# These tests verify the deletion data is fully preserved. +# --------------------------------------------------------------------------- + + +def test_deletion_all_present_in_output(data_with_deletions): + """The deletion tensor must be nonzero at the expected positions.""" + rng = np.random.RandomState(42) + _, del_data, _ = construct_paired_msa(data_with_deletions, random=rng, max_seqs=10) + + assert del_data[2, 1].item() == 3, f"Seq 1 deletion at res_idx=2 should be 3, got {del_data[2, 1].item()}" + assert del_data[0, 2].item() == 1, f"Seq 2 deletion at res_idx=0 should be 1, got {del_data[0, 2].item()}" + assert del_data[4, 2].item() == 5, f"Seq 2 deletion at res_idx=4 should be 5, got {del_data[4, 2].item()}" + + +def test_deletion_query_row_has_no_deletions(data_with_deletions): + """The query row (row 0) should have zero deletions everywhere.""" + rng = np.random.RandomState(42) + _, del_data, _ = construct_paired_msa(data_with_deletions, random=rng, max_seqs=10) + + query_dels = del_data[:, 0] + assert (query_dels == 0).all(), f"Query row should have zero deletions, got {query_dels}" + + +def test_deletion_total_nonzero_entries(data_with_deletions): + """Exactly 3 nonzero deletion entries (1 from seq 1 + 2 from seq 2).""" + rng = np.random.RandomState(42) + _, del_data, _ = construct_paired_msa(data_with_deletions, random=rng, max_seqs=10) + + n_nonzero = (del_data != 0).sum().item() + assert n_nonzero == 3, f"Expected 3 nonzero deletion entries, got {n_nonzero}" + + +# --------------------------------------------------------------------------- +# Real-data integration tests +# +# Load the 8ayv homodimer (2 protein chains sharing one MSA with 4611 seqs, +# 2309 with taxonomy, 716 with deletions) and run construct_paired_msa on +# it to verify the fixes work on production-format data. +# --------------------------------------------------------------------------- + + +def _load_real_sample(sample_dir): + """Build a data object from a processed sample directory. + + Loads the manifest, structure, and MSAs, then creates minimal tokens + (TokenV2 with asym_id and res_idx) so construct_paired_msa can run. + + Returns (data, original_msas) where original_msas is a dict of + independently loaded MSA objects (mirrors the real pipeline where + each chain gets its own MSA.load() call). + """ + import json + + manifest = json.loads((sample_dir / "manifest.json").read_text()) + rec = manifest["records"][0] + + structure = StructureV2.load(sample_dir / "structures" / f"{rec['id']}.npz") + + msa_dir = sample_dir / "msa" + msas = {} + for chain_info in rec["chains"]: + msa_id = chain_info.get("msa_id", -1) + chain_id = chain_info["chain_id"] + if msa_id != -1: + msa_path = msa_dir / f"{msa_id}.npz" + if msa_path.exists(): + msas[chain_id] = MSA.load(msa_path) + + asym_ids, res_idxs = [], [] + for chain in structure.chains: + cid = int(chain["asym_id"]) + if cid not in msas: + continue + for r in range(int(chain["res_num"])): + asym_ids.append(cid) + res_idxs.append(r) + + tokens = _make_tokens(asym_ids, res_idxs) + record = SimpleNamespace(id=rec["id"]) + data = SimpleNamespace(tokens=tokens, structure=structure, msa=msas, record=record) + return data, msas + + +@pytest.fixture() +def real_8ayv_data(test_cp_training_base_data_dir_boltz2): + """Load the 8ayv homodimer from the real test data cache.""" + return _load_real_sample(test_cp_training_base_data_dir_boltz2 / "processed_8ayv") + + +def test_real_data_no_mutation(real_8ayv_data): + """MSA residue arrays must not be modified in-place by construct_paired_msa.""" + data, original_msas = real_8ayv_data + + snapshots = {cid: msa.residues["res_type"].copy() for cid, msa in original_msas.items()} + + rng = np.random.RandomState(42) + construct_paired_msa(data, random=rng, max_seqs=512) + + for cid, before in snapshots.items(): + np.testing.assert_array_equal( + original_msas[cid].residues["res_type"], + before, + err_msg=f"Chain {cid}: MSA residues mutated in-place", + ) + + +def test_real_data_deletions_preserved(real_8ayv_data): + """The output deletion matrix must contain nonzero entries when the + input MSA has sequences with deletions (8ayv has 716).""" + data, _ = real_8ayv_data + + rng = np.random.RandomState(42) + _, del_data, _ = construct_paired_msa(data, random=rng, max_seqs=512) + + n_nonzero = (del_data != 0).sum().item() + assert n_nonzero > 0, ( + "del_data is all zeros despite 716 input sequences having deletions. " + "This indicates the deletion extraction loop is broken." + ) + + +def test_real_data_idempotent(real_8ayv_data): + """Two calls with the same seed must produce identical output.""" + data, _ = real_8ayv_data + + rng1 = np.random.RandomState(42) + msa1, del1, paired1 = construct_paired_msa(data, random=rng1, max_seqs=512) + + rng2 = np.random.RandomState(42) + msa2, del2, paired2 = construct_paired_msa(data, random=rng2, max_seqs=512) + + np.testing.assert_array_equal(msa1, msa2, err_msg="Idempotency: MSA mismatch") + np.testing.assert_array_equal(del1, del2, err_msg="Idempotency: deletion mismatch") + np.testing.assert_array_equal(paired1, paired2, err_msg="Idempotency: paired mismatch") + + +def test_real_data_output_shape(real_8ayv_data): + """Output arrays must have consistent shapes: (n_tokens, n_rows).""" + data, _ = real_8ayv_data + n_tokens = len(data.tokens) + + rng = np.random.RandomState(42) + max_seqs = 128 + msa_data, del_data, paired_data = construct_paired_msa(data, random=rng, max_seqs=max_seqs) + + assert msa_data.shape[0] == n_tokens, f"msa_data dim 0 should be n_tokens={n_tokens}, got {msa_data.shape[0]}" + assert del_data.shape == msa_data.shape, f"del_data shape {del_data.shape} != msa_data shape {msa_data.shape}" + n_rows = msa_data.shape[1] + assert n_rows >= 1, "Must have at least the query row" + assert n_rows <= max_seqs, f"n_rows={n_rows} exceeds max_seqs={max_seqs}" + + +def test_real_data_query_row_matches_structure(real_8ayv_data): + """Row 0 (query) should contain residue types matching the structure.""" + data, _ = real_8ayv_data + + rng = np.random.RandomState(42) + msa_data, _, _ = construct_paired_msa(data, random=rng, max_seqs=128) + + for chain_id in sorted(data.msa.keys()): + chain = data.structure.chains[chain_id] + res_start = int(chain["res_idx"]) + res_num = int(chain["res_num"]) + expected = data.structure.residues[res_start : res_start + res_num]["res_type"] + + token_mask = data.tokens["asym_id"] == chain_id + query_row = msa_data[token_mask, 0] + + np.testing.assert_array_equal( + query_row, + expected, + err_msg=f"Chain {chain_id}: query row doesn't match structure residues", + ) diff --git a/tests/data/write/test_writer.py b/tests/data/write/test_writer.py new file mode 100644 index 000000000..cfa49a6cb --- /dev/null +++ b/tests/data/write/test_writer.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for BoltzWriter multi-diffusion-sample correctness. + +Verifies that the current BoltzWriter produces identical CIF output to the +dev-v2 reference version. The dev-v2 writer crashes with ``KeyError`` when +``diffusion_samples > 1`` and no confidence scores (``idx_to_rank`` is sized +by ``len(records)`` instead of the number of diffusion samples). We work +around this by calling the dev-v2 reference one sample at a time, then verify +the current writer (called once with all samples) produces identical output. +""" + +from dataclasses import replace +from pathlib import Path + +import numpy as np +import pytest +import torch +from torch import Tensor + +from boltz.data.types import ( + AtomV2, + BondV2, + Chain, + ChainInfo, + Coords, + Ensemble, + Interface, + Record, + Residue, + StructureInfo, + StructureV2, +) +from boltz.data.write.mmcif import to_mmcif +from boltz.data.write.writer import BoltzWriter + +# Atom names used to populate multi-atom residues without duplicates. +_ATOM_NAMES = ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "OG"] + + +# --------------------------------------------------------------------------- +# Reference: verbatim dev-v2 write_on_batch_end logic +# --------------------------------------------------------------------------- +def _dev_v2_write_on_batch_end( + data_dir: Path, + output_dir: Path, + prediction: dict[str, Tensor], + batch: dict[str, Tensor], +) -> None: + """Faithful reproduction of the dev-v2 BoltzWriter.write_on_batch_end. + + Extracted as a plain function so we can call it with single-sample slices + to avoid the ``idx_to_rank`` KeyError. + """ + if prediction["exception"]: + return + + records: list[Record] = batch["record"] + + coords = prediction["coords"] + coords = coords.unsqueeze(0) + + pad_masks = prediction["masks"] + + if "confidence_score" in prediction: + argsort = torch.argsort(prediction["confidence_score"], descending=True) + idx_to_rank = {idx.item(): rank for rank, idx in enumerate(argsort)} + else: + idx_to_rank = {i: i for i in range(len(records))} + + for record, coord, pad_mask in zip(records, coords, pad_masks): + path = data_dir / f"{record.id}.npz" + structure: StructureV2 = StructureV2.load(path) + + chain_map = {} + for i, mask in enumerate(structure.mask): + if mask: + chain_map[len(chain_map)] = i + + structure = structure.remove_invalid_chains() + + for model_idx in range(coord.shape[0]): + model_coord = coord[model_idx] + coord_unpad = model_coord[pad_mask.bool()] + coord_unpad = coord_unpad.cpu().numpy() + + atoms = structure.atoms + atoms["coords"] = coord_unpad + atoms["is_present"] = True + structure: StructureV2 + coord_unpad = [(x,) for x in coord_unpad] + coord_unpad = np.array(coord_unpad, dtype=Coords) + + residues = structure.residues + residues["is_present"] = True + + interfaces = np.array([], dtype=Interface) + new_structure: StructureV2 = replace( + structure, + atoms=atoms, + residues=residues, + interfaces=interfaces, + coords=coord_unpad, + ) + + chain_info = [] + for chain in new_structure.chains: + old_chain_idx = chain_map[chain["asym_id"]] + old_chain_info = record.chains[old_chain_idx] + new_chain_info = replace( + old_chain_info, + chain_id=int(chain["asym_id"]), + valid=True, + ) + chain_info.append(new_chain_info) + + struct_dir = output_dir / record.id + struct_dir.mkdir(parents=True, exist_ok=True) + + plddts = None + if "plddt" in prediction: + plddts = prediction["plddt"][model_idx] + + outname = f"{record.id}_model_{idx_to_rank[model_idx]}" + cif_path = struct_dir / f"{outname}.cif" + with cif_path.open("w") as f: + f.write(to_mmcif(new_structure, plddts=plddts, boltz2=True)) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_structure_v2(n_atoms: int, n_residues: int, rng: np.random.Generator) -> StructureV2: + """Create a minimal valid StructureV2 with one chain. + + Each residue gets ``n_atoms // n_residues`` atoms with distinct names + drawn from ``_ATOM_NAMES`` so that ``to_mmcif`` does not reject them. + """ + atoms_per_res = n_atoms // n_residues + assert atoms_per_res * n_residues == n_atoms + assert atoms_per_res <= len(_ATOM_NAMES) + + atoms = np.zeros(n_atoms, dtype=AtomV2) + atoms["coords"] = rng.standard_normal((n_atoms, 3)).astype(np.float32) + atoms["is_present"] = True + for i in range(n_atoms): + atoms[i]["name"] = _ATOM_NAMES[i % atoms_per_res] + + residues = np.zeros(n_residues, dtype=Residue) + for i in range(n_residues): + residues[i]["name"] = "ALA" + residues[i]["res_type"] = 0 + residues[i]["res_idx"] = i + residues[i]["atom_idx"] = i * atoms_per_res + residues[i]["atom_num"] = atoms_per_res + residues[i]["atom_center"] = i * atoms_per_res + residues[i]["atom_disto"] = i * atoms_per_res + residues[i]["is_standard"] = True + residues[i]["is_present"] = True + + chains = np.zeros(1, dtype=Chain) + chains[0]["name"] = "A" + chains[0]["mol_type"] = 0 + chains[0]["entity_id"] = 0 + chains[0]["sym_id"] = 0 + chains[0]["asym_id"] = 0 + chains[0]["atom_idx"] = 0 + chains[0]["atom_num"] = n_atoms + chains[0]["res_idx"] = 0 + chains[0]["res_num"] = n_residues + chains[0]["cyclic_period"] = 0 + + bonds = np.array([], dtype=BondV2) + interfaces = np.array([], dtype=Interface) + mask = np.array([True]) + coords = np.array([(c,) for c in atoms["coords"]], dtype=Coords) + ensemble = np.zeros(1, dtype=Ensemble) + ensemble[0]["atom_coord_idx"] = 0 + ensemble[0]["atom_num"] = n_atoms + + return StructureV2( + atoms=atoms, + bonds=bonds, + residues=residues, + chains=chains, + interfaces=interfaces, + mask=mask, + coords=coords, + ensemble=ensemble, + ) + + +def _make_record(record_id: str, n_residues: int) -> Record: + chain_info = ChainInfo( + chain_id=0, + chain_name="A", + mol_type=0, + cluster_id=0, + msa_id=0, + num_residues=n_residues, + valid=True, + ) + return Record( + id=record_id, + structure=StructureInfo(), + chains=[chain_info], + interfaces=[], + ) + + +def _save_structure_npz(structure: StructureV2, path: Path) -> None: + """Save StructureV2 to npz, omitting None-valued optional fields.""" + save_dict = {k: v for k, v in vars(structure).items() if v is not None} + np.savez(str(path), **save_dict) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_diffusion_samples", [1, 5]) +def test_boltz_writer_multi_sample_parity(tmp_path, n_diffusion_samples): + """Current writer (all samples at once) == dev-v2 reference (per sample). + + For each diffusion sample we call the dev-v2 reference with a single-sample + ``coords`` tensor (shape ``(1, N, 3)``) so ``unsqueeze(0)`` yields + ``(1, 1, N, 3)`` and ``idx_to_rank = {0: 0}`` — no crash. The current + writer receives all samples at once and writes ``model_0 … model_{k-1}``. + """ + rng = np.random.default_rng(42) + n_atoms = 30 + n_residues = 10 + n_pad = 2 + n_atoms_padded = n_atoms + n_pad + record_id = "test_struct" + + # Persist structure to disk + data_dir = tmp_path / "structures" + data_dir.mkdir() + structure = _make_structure_v2(n_atoms, n_residues, rng) + _save_structure_npz(structure, data_dir / record_id) + + record = _make_record(record_id, n_residues) + + # Deterministic random coords: (n_diffusion_samples, n_atoms_padded, 3) + torch.manual_seed(123) + all_coords = torch.randn(n_diffusion_samples, n_atoms_padded, 3) + pad_mask = torch.zeros(n_atoms_padded, dtype=torch.bool) + pad_mask[:n_atoms] = True + + # ---- Reference: call dev-v2 logic once per diffusion sample ---- + ref_dir = tmp_path / "ref_output" + ref_cifs: dict[int, str] = {} + for sample_idx in range(n_diffusion_samples): + sample_out = ref_dir / f"_sample_{sample_idx}" + _dev_v2_write_on_batch_end( + data_dir=data_dir, + output_dir=sample_out, + prediction={ + "exception": False, + "coords": all_coords[sample_idx : sample_idx + 1], + "masks": pad_mask.unsqueeze(0), + }, + batch={"record": [record]}, + ) + cif_path = sample_out / record_id / f"{record_id}_model_0.cif" + assert cif_path.exists(), f"Reference CIF missing: {cif_path}" + ref_cifs[sample_idx] = cif_path.read_text() + + # ---- Current writer: single call with all samples ---- + cur_dir = tmp_path / "cur_output" + cur_writer = BoltzWriter( + data_dir=str(data_dir), + output_dir=str(cur_dir), + output_format="mmcif", + boltz2=True, + ) + cur_writer.write_on_batch_end( + trainer=None, + pl_module=None, + prediction={ + "exception": False, + "coords": all_coords, + "masks": pad_mask.unsqueeze(0), + }, + batch_indices=None, + batch={"record": [record]}, + batch_idx=0, + dataloader_idx=0, + ) + + # ---- Compare ---- + cur_struct_dir = cur_dir / record_id + assert cur_struct_dir.exists(), f"Current output dir missing: {cur_struct_dir}" + cur_cif_files = sorted(cur_struct_dir.glob("*.cif")) + assert ( + len(cur_cif_files) == n_diffusion_samples + ), f"Expected {n_diffusion_samples} CIF files, found {len(cur_cif_files)}" + + for sample_idx in range(n_diffusion_samples): + cur_cif_path = cur_struct_dir / f"{record_id}_model_{sample_idx}.cif" + assert cur_cif_path.exists(), f"Current CIF missing: {cur_cif_path}" + cur_text = cur_cif_path.read_text() + assert cur_text == ref_cifs[sample_idx], ( + f"CIF mismatch for diffusion sample {sample_idx}:\n" + f" current file : {cur_cif_path}\n" + f" reference : dev-v2 per-sample call" + ) diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/tests/distributed/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/tests/distributed/data/test_dtensor_minimum_lddt_symmetry_coords.py b/tests/distributed/data/test_dtensor_minimum_lddt_symmetry_coords.py new file mode 100644 index 000000000..f21a6c6b4 --- /dev/null +++ b/tests/distributed/data/test_dtensor_minimum_lddt_symmetry_coords.py @@ -0,0 +1,355 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from __future__ import annotations + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard + +from boltz.data.mol import minimum_lddt_symmetry_coords as serial_minimum_lddt_symmetry_coords +from boltz.distributed.data.feature.symmetry import ( + minimum_lddt_symmetry_coords as dtensor_minimum_lddt_symmetry_coords, +) +from boltz.distributed.manager import DistributedManager +from boltz.testing.utils import ( + distribute_atom_features, + get_feature_placements, + random_features, + spawn_multiprocessing, +) + + +def _build_symmetry_features_for_batch( + all_coords_batch: torch.Tensor, + all_resolved_mask_batch: torch.Tensor, + crop_to_all_atom_map_batch: torch.Tensor, + chain_swaps_batch: list | None = None, +): + """Build symmetry features for a batch of samples. + + Parameters + ---------- + chain_swaps_batch : list or None + If None, each sample gets a single identity (no-op) chain swap ``[[]]``. + Otherwise, provide a per-sample list of swap combinations. + + """ + batch_size = all_coords_batch.shape[0] + if chain_swaps_batch is None: + chain_swaps_batch = [[[]] for _ in range(batch_size)] + amino_acids_symmetries_batch = [[] for _ in range(batch_size)] + ligand_symmetries_batch = [[] for _ in range(batch_size)] + + feats = { + "all_coords": all_coords_batch, + "all_resolved_mask": all_resolved_mask_batch, + "crop_to_all_atom_map": crop_to_all_atom_map_batch, + "chain_swaps": chain_swaps_batch, + "amino_acids_symmetries": amino_acids_symmetries_batch, + "ligand_symmetries": ligand_symmetries_batch, + } + return feats + + +def _make_two_chain_swaps(n_atoms: int) -> list: + """Build chain_swaps for one sample with two equal-length chains that can be swapped. + + Splits the atom range [0, n_atoms) into two halves (chain A and chain B) + and returns identity + the A<->B swap. Each swap entry is + (start1, end1, start2, end2, chainidx1, chainidx2). + """ + half = n_atoms // 2 + identity = [] + swap_ab = [ + (0, half, half, 2 * half, 0, 1), + (half, 2 * half, 0, half, 1, 0), + ] + return [identity, swap_ab] + + +_atom_keys = {"atom_pad_mask"} +_placements = get_feature_placements(atom_keys=_atom_keys, token_keys=set()) +_placements_atom_features = _placements["atom_features"] +_placements_cp_atom_features = _placements["cp_atom_features"] + +_placements_sample_coords = {"sample_coords": (Shard(0), Shard(1), Replicate())} +_placements_cp_sample_coords = {"sample_coords": (Shard(0), Replicate())} + + +def parallel_assert_minimum_lddt_symmetry_coords(rank, payload): + """Test distributed minimum_lddt_symmetry_coords against serial Boltz-2 version. + + With DP sharding, each DP rank processes its own local batch of symmetry features. + The coords DTensor is sharded along (DP, CP_0, CP_1) axes. + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + input_feats_global_host, + feats_symmetry_global_host, + expected_true_coords_per_sample, + expected_true_mask_per_sample, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device = manager.device + dtype = torch.float32 + + size_batch = input_feats_global_host["atom_pad_mask"].shape[0] + rank_dp = manager.group_rank["dp"] + + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_host.items() + if k in _placements_cp_atom_features + } + inputs_atom["sample_coords"] = input_feats_global_host["sample_coords"].to(dtype=dtype) + + placements_cp = _placements_cp_atom_features | _placements_cp_sample_coords + placements_dp_cp = _placements_atom_features | _placements_sample_coords + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + + coords_dtensor = feats_atom["sample_coords"] + atom_pad_mask_dtensor = feats_atom["atom_pad_mask"] + + coords_placements = coords_dtensor.placements + + num_dp_ranks = grid_group_sizes["dp"] + global_batch_size = size_batch + local_batch_size = global_batch_size // num_dp_ranks + local_start = rank_dp * local_batch_size + local_end = local_start + local_batch_size + + feats_local = { + "all_coords": feats_symmetry_global_host["all_coords"][local_start:local_end].to(device), + "all_resolved_mask": feats_symmetry_global_host["all_resolved_mask"][local_start:local_end].to(device), + "crop_to_all_atom_map": feats_symmetry_global_host["crop_to_all_atom_map"][local_start:local_end].to(device), + "chain_swaps": feats_symmetry_global_host["chain_swaps"][local_start:local_end], + "amino_acids_symmetries": feats_symmetry_global_host["amino_acids_symmetries"][local_start:local_end], + "ligand_symmetries": feats_symmetry_global_host["ligand_symmetries"][local_start:local_end], + "atom_pad_mask": atom_pad_mask_dtensor, + } + + for i_batch_local in range(local_batch_size): + global_batch_idx = local_start + i_batch_local + + true_coords_dtensor, true_mask_dtensor = dtensor_minimum_lddt_symmetry_coords( + coords=coords_dtensor, + feats=feats_local, + index_batch_local=i_batch_local, + i_batch_multiplicity_local=i_batch_local, + ) + + assert ( + true_coords_dtensor.placements == coords_placements + ), f"Sample {i_batch_local}: true_coords_dtensor.placements mismatch" + assert ( + true_mask_dtensor.placements == coords_placements + ), f"Sample {i_batch_local}: true_mask_dtensor.placements mismatch" + + expected_coords = expected_true_coords_per_sample[global_batch_idx] + expected_mask = expected_true_mask_per_sample[global_batch_idx] + + coords_cp_gathered = true_coords_dtensor.redistribute( + true_coords_dtensor.device_mesh, + (true_coords_dtensor.placements[0], Replicate(), Replicate()), + ).to_local() + mask_cp_gathered = true_mask_dtensor.redistribute( + true_mask_dtensor.device_mesh, + (true_mask_dtensor.placements[0], Replicate(), Replicate()), + ).to_local() + + atom_pad_mask_gathered = atom_pad_mask_dtensor.redistribute( + atom_pad_mask_dtensor.device_mesh, + (atom_pad_mask_dtensor.placements[0], Replicate(), Replicate()), + ).to_local() + + real_atom_mask = atom_pad_mask_gathered[i_batch_local].bool() + coords_no_pad = coords_cp_gathered[i_batch_local, real_atom_mask, :] + mask_no_pad = mask_cp_gathered[i_batch_local, real_atom_mask] + + torch.testing.assert_close( + coords_no_pad.cpu(), + expected_coords, + msg=f"Sample {i_batch_local} (global {global_batch_idx}): true_coords mismatch", + ) + torch.testing.assert_close( + mask_no_pad.cpu(), + expected_mask.squeeze(0) if expected_mask.ndim > 1 else expected_mask, + msg=f"Sample {i_batch_local} (global {global_batch_idx}): true_mask mismatch", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("use_nontrivial_swaps", [False, True], ids=["identity_swaps", "nontrivial_swaps"]) +def test_serial_minimum_lddt_symmetry_coords(use_nontrivial_swaps): + """Verify serial minimum_lddt_symmetry_coords runs correctly on a single GPU. + + Tests both identity (no-op) and non-trivial (two-chain A<->B) swap cases. + """ + batch_size = 2 + n_atoms = 24 + + torch.manual_seed(123) + sample_coords = torch.randn((batch_size, n_atoms, 3), dtype=torch.float32) + all_coords = sample_coords.clone() + all_resolved_mask = torch.ones((batch_size, n_atoms), dtype=torch.bool) + crop_to_all_atom_map = torch.arange(n_atoms, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() + + chain_swaps_batch = None + if use_nontrivial_swaps: + chain_swaps_batch = [_make_two_chain_swaps(n_atoms) for _ in range(batch_size)] + + feats = _build_symmetry_features_for_batch( + all_coords, all_resolved_mask, crop_to_all_atom_map, chain_swaps_batch=chain_swaps_batch + ) + + for i in range(batch_size): + true_coords, true_mask = serial_minimum_lddt_symmetry_coords( + coords=sample_coords[i : i + 1], + feats=feats, + index_batch=i, + ) + assert true_coords.shape[-1] == 3, f"Sample {i}: unexpected coords shape {true_coords.shape}" + assert true_mask.dtype == torch.bool, f"Sample {i}: unexpected mask dtype {true_mask.dtype}" + assert true_mask.any(), f"Sample {i}: all-zero mask unexpected with all-resolved input" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("use_nontrivial_swaps", [False, True], ids=["identity_swaps", "nontrivial_swaps"]) +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), # dp=1, cp=(2,2), world_size=4 + ((2, (2, 2)), True, "cuda", "ENV"), # dp=2, cp=(2,2), world_size=8 + ], + indirect=("setup_env",), +) +def test_dtensor_minimum_lddt_symmetry_coords(setup_env, use_nontrivial_swaps): + """Test distributed symmetry correction against Boltz-2 serial implementation. + + Parametrized for 4-GPU and 8-GPU configs, with both identity (no-op) and + non-trivial (two-chain A<->B) chain swaps. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + size_ring = grid_group_sizes["cp"][0] + num_dp_ranks = grid_group_sizes["dp"] + batch_size_per_dp_rank = 1 + batch_size_global = batch_size_per_dp_rank * num_dp_ranks + + n_atoms_per_token = 3 + n_tokens = size_ring * 4 + n_atoms = n_atoms_per_token * n_tokens + + torch.manual_seed(42) + feats_from_random = random_features( + size_batch=batch_size_global, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, n_atoms_per_token), + device=torch.device("cpu"), + float_value_range=(-1.0, 1.0), + selected_keys=["atom_pad_mask", "coords", "atom_counts_per_token"], + ) + + sample_coords_global = torch.randn((batch_size_global, n_atoms, 3), dtype=torch.float32) + + atom_pad_mask_global = feats_from_random["atom_pad_mask"] + atom_counts_per_token_global = feats_from_random["atom_counts_per_token"] + + input_feats_global = { + "atom_pad_mask": atom_pad_mask_global, + "atom_counts_per_token": atom_counts_per_token_global, + "sample_coords": sample_coords_global, + } + + coords_for_symmetry = sample_coords_global + all_coords_global = coords_for_symmetry.clone() + all_resolved_mask_global = torch.ones((batch_size_global, n_atoms), dtype=torch.bool) + crop_to_all_atom_map_global = ( + torch.arange(n_atoms, dtype=torch.long).unsqueeze(0).expand(batch_size_global, -1).contiguous() + ) + + chain_swaps_batch = None + if use_nontrivial_swaps: + chain_swaps_batch = [_make_two_chain_swaps(n_atoms) for _ in range(batch_size_global)] + + feats_symmetry_global = _build_symmetry_features_for_batch( + all_coords_global, + all_resolved_mask_global, + crop_to_all_atom_map_global, + chain_swaps_batch=chain_swaps_batch, + ) + + expected_true_coords_per_sample = [] + expected_true_mask_per_sample = [] + + for i in range(batch_size_global): + expected_coords, expected_mask = serial_minimum_lddt_symmetry_coords( + coords=coords_for_symmetry[i : i + 1], + feats=feats_symmetry_global, + index_batch=i, + ) + expected_true_coords_per_sample.append(expected_coords.squeeze(0).detach().clone().cpu()) + expected_true_mask_per_sample.append(expected_mask.detach().clone().cpu()) + + input_feats_global_host = {k: v.detach().clone().cpu() for k, v in input_feats_global.items()} + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + input_feats_global_host, + feats_symmetry_global, + expected_true_coords_per_sample, + expected_true_mask_per_sample, + ) + + spawn_multiprocessing(parallel_assert_minimum_lddt_symmetry_coords, world_size, payload) diff --git a/tests/distributed/data/test_dtensor_pack_and_pad_atom_features.py b/tests/distributed/data/test_dtensor_pack_and_pad_atom_features.py new file mode 100644 index 000000000..97fc5203a --- /dev/null +++ b/tests/distributed/data/test_dtensor_pack_and_pad_atom_features.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for pack_atom_features function. + +This module tests the pack_atom_features function which removes per-shard trailing +padding from pad_and_scatter_atom_features_dtensor and creates a packed DTensor +with global trailing padding (multiple of W * size_cp). + +Tests verify that: +1. Packed features match serial reference in the valid (non-padding) region +2. The packing operation correctly handles variable atoms per token +3. atom_to_token is converted to global indices (atom_to_token_ids_global) +""" + +import warnings + +import pytest +import torch +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._utils import compute_global_tensor_info + +from boltz.distributed.data.feature.featurizer import pack_atom_features, pad_and_scatter_atom_features_dtensor +from boltz.distributed.manager import DistributedManager +from boltz.testing.utils import ( + assert_tensors_close_with_pad, + assert_tensors_identical, + get_feature_placements, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +# Subset of keys needed for pack_atom_features test +_selected_atom_keys = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + "atom_counts_per_token", # Required by pad_and_scatter_atom_features_dtensor +} + +# Get feature placements from centralized utility function with atom key subset +# Pass empty sets for unused categories to suppress irrelevant placements +_placements = get_feature_placements( + token_keys=set(), + msa_keys=set(), + atom_keys=_selected_atom_keys, + model_io_keys=set(), + model_io_fp32_keys=set(), +) +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_atom_features = _placements["atom_features"] + + +def parallel_assert_pack_and_pad_atom_features( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + dtype: torch.dtype, + W: int, + # Inputs on host (global tensors) + feats_global_host: dict[str, torch.Tensor], +): + """Parallel worker function for testing pack_atom_features.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # ======================================================================== + # Distribute atom features using pad_and_scatter_atom_features_dtensor + # ======================================================================== + # This follows the pattern from test_dtensor_model.py: + # 1. Each dp rank gets one sample from the batch + # 2. pad_and_scatter_atom_features_dtensor distributes within cp group + # 3. Results are collated to form full DTensors with dp+cp sharding + + size_batch = feats_global_host["atom_pad_mask"].shape[0] + assert size_batch == len(manager.group_ranks["dp"]), "size_batch must equal number of dp ranks" + size_batch_per_dp = size_batch // len(manager.group_ranks["dp"]) + rank_dp = manager.group_rank["dp"] + i_sample_begin = rank_dp * size_batch_per_dp + + # Prepare inputs for pad_and_scatter_atom_features_dtensor + # Only cp rank 0 provides the input (it gets scattered to all cp ranks) + if manager.group_rank["cp"] == 0: + inputs = { + k: v[i_sample_begin].to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + } + else: + inputs = None + + # Distribute atom features using pad_and_scatter_atom_features_dtensor + feats_atom_dtensor_cp = pad_and_scatter_atom_features_dtensor( + inputs, + _placements_cp_atom_features, + manager.group["cp"], + manager.group_ranks["cp"][0], + manager.device_mesh_subgroups["cp_axis_0", "cp_axis_1"], + ) + + # Collate along batch dimension: per-dp rank atom features -> single DTensor with dp+cp sharding + feats_atom_shape_stride_global = { + k: tuple( + map( + tuple, + compute_global_tensor_info( + v.to_local().unsqueeze(0), manager.device_mesh_subgroups, _placements_atom_features[k] + ), + ) + ) + for k, v in feats_atom_dtensor_cp.items() + } + feats_dt = { + k: DTensor.from_local( + v.to_local().unsqueeze(0), + manager.device_mesh_subgroups, + _placements_atom_features[k], + shape=feats_atom_shape_stride_global[k][0], + stride=feats_atom_shape_stride_global[k][1], + ) + for k, v in feats_atom_dtensor_cp.items() + } + + # ======================================================================== + # Pack and pad atom features using pack_atom_features + # This removes per-shard trailing padding from pad_and_scatter_atom_features_dtensor + # and creates a packed DTensor with global trailing padding (multiple of W * size_cp) + # ======================================================================== + # Suppress the expected warning about atom_to_token not being packed + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="pack_atom_features: 'atom_to_token'") + feats_dt_packed = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + + # ======================================================================== + # Verify consistency between packed features and serial reference + # Packed features should match serial except for trailing padding along N_atoms axis + # ======================================================================== + N_atoms_actual = feats_global_host["atom_pad_mask"].shape[1] + + # Check that packed features match serial in valid region + # Note: atom_to_token is excluded because pack_atom_features only outputs + # atom_to_token_ids_global (global indices), not the original one-hot matrix. + # The global indices are computed via shardwise_argmax + shardwise_offset before packing. + atom_feature_keys_to_check = [ + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_pad_mask", + ] + for key in atom_feature_keys_to_check: + if key not in feats_dt_packed: + continue + packed_full = feats_dt_packed[key].full_tensor() + serial_ref = feats_global_host[key].to(device=manager.device, dtype=packed_full.dtype) + assert_tensors_close_with_pad( + packed_full, + serial_ref, + axis=1, + pad_val=0, + msg=lambda m, k=key: f"Packed feature {k} mismatch: {m}", + ) + + # Verify atom_to_token_ids_global is present and has correct shape + assert "atom_to_token_ids_global" in feats_dt_packed, "atom_to_token_ids_global should be in packed features" + atom_to_token_ids_global_packed = feats_dt_packed["atom_to_token_ids_global"] + # Shape should be (B, N_atoms_packed) where N_atoms_packed >= N_atoms_actual + assert atom_to_token_ids_global_packed.shape[0] == size_batch + assert atom_to_token_ids_global_packed.shape[1] >= N_atoms_actual + + # 'atom_to_token' is returned as it is with a different feature name + assert "atom_to_token_local_onehot" in feats_dt_packed, "atom_to_token_local_onehot should be in packed features" + atom_to_token_local_onehot_packed = feats_dt_packed["atom_to_token_local_onehot"] + atom_to_token_local_onehot_packed_full = atom_to_token_local_onehot_packed.full_tensor() + atom_to_token_expected_full = feats_dt["atom_to_token"].full_tensor() + assert_tensors_identical(atom_to_token_local_onehot_packed_full, atom_to_token_expected_full) + + # Verify the global indices match the serial reference in valid region + # Serial atom_to_token is one-hot (B, N_atoms, N_tokens), extract indices via argmax + atom_to_token_serial = feats_global_host["atom_to_token"].to(device=manager.device) + atom_to_token_ids_serial = atom_to_token_serial.argmax(dim=-1) # (B, N_atoms) + atom_to_token_ids_global_packed_full = atom_to_token_ids_global_packed.full_tensor() + assert_tensors_close_with_pad( + atom_to_token_ids_global_packed_full, + atom_to_token_ids_serial, + axis=1, + pad_val=0, + msg=lambda m: f"atom_to_token_ids_global mismatch: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "n_atoms_per_token_range", + [(8, 20)], # can use (1, 1) for debugging purpose + ids=lambda x: f"atoms_per_token:{x[0]}-{x[1]}", +) +def test_pack_and_pad_atom_features(setup_env, n_atoms_per_token_range: tuple[int, int]): + """Test pack_atom_features function. + + Verifies that: + 1. Packed features match serial reference in valid region + 2. Variable atoms per token are handled correctly + 3. atom_to_token_ids_global contains correct global token indices + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + # Configuration + dtype = torch.float32 + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] # batch size per rank = 1 + + # Test parameters + n_atoms_per_token_min, n_atoms_per_token_max = n_atoms_per_token_range + W = 32 # atoms per window for queries + # N_tokens must be divisible by size_cp for even token sharding + N_tokens = 1000 * size_cp + # With max atoms per token, N_atoms = N_tokens * n_atoms_per_token_max + N_atoms = N_tokens * n_atoms_per_token_max + N_msa = 1 # minimal MSA + + val_init_min_max = (-0.5, 0.5) + + # Verify constraints + assert N_tokens % size_cp == 0, f"N_tokens ({N_tokens}) must be divisible by size_cp ({size_cp})" + + # ======================================================================== + # Generate features using random_features + # This subset of features is for AtomAttentionEncoder usage + # ======================================================================== + selected_keys = [ + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + "atom_counts_per_token", # Required by pad_and_scatter_atom_features_dtensor + ] + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + + # Convert float64 to float32 for consistency + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + # Prepare inputs for distributed test + feats_global_host = {k: v.detach().cpu() for k, v in feats.items()} + + # Launch multiprocess test + spawn_multiprocessing( + parallel_assert_pack_and_pad_atom_features, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + W, + feats_global_host, + ) diff --git a/tests/distributed/data/test_dtensor_scatter_features.py b/tests/distributed/data/test_dtensor_scatter_features.py new file mode 100644 index 000000000..cb13e2ef1 --- /dev/null +++ b/tests/distributed/data/test_dtensor_scatter_features.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard + +from boltz.distributed.data.feature.featurizer import pad_and_scatter_atom_features_dtensor +from boltz.distributed.data.feature.featurizer_utils import remap_atom_indices_repad +from boltz.distributed.manager import DistributedManager +from boltz.testing.utils import seed_by_rank, spawn_multiprocessing + + +@pytest.mark.parametrize( + "old_stride,new_stride,n_shards", + [ + (5, 8, 3), + (4, 4, 2), + (6, 10, 4), + ], +) +def test_remap_atom_indices_repad(old_stride, new_stride, n_shards): + """Verify remap_atom_indices_repad correctly remaps padded atom indices.""" + # Build indices that cover multiple shards: + # for each shard s, place a few valid offsets within [0, old_stride). + indices = [] + expected = [] + for s in range(n_shards): + for offset in [0, 1, old_stride - 1]: + indices.append(s * old_stride + offset) + expected.append(s * new_stride + offset) + + indices_t = torch.tensor(indices, dtype=torch.int64) + expected_t = torch.tensor(expected, dtype=torch.int64) + + result = remap_atom_indices_repad(indices_t, old_stride, new_stride) + torch.testing.assert_close(result, expected_t) + + +def _make_sample_features(n_tokens, atom_counts_per_token, device): + """Build minimal feature/placement dicts for a single sample.""" + total_atoms = atom_counts_per_token.sum().item() + features = { + "atom_counts_per_token": atom_counts_per_token.to(device), + "atom_pad_mask": torch.ones(total_atoms, device=device), + "frames_idx": torch.randint(0, total_atoms, (1, n_tokens, 3), device=device, dtype=torch.int64), + } + placements = { + "atom_counts_per_token": (Shard(0), Replicate()), + "atom_pad_mask": (Shard(0), Replicate()), + "frames_idx": (Shard(1), Replicate()), + } + return features, placements + + +def parallel_assert_collate_dtensor_atom_index_remap( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """Verify CollateDTensor remaps atom-index features when samples differ in max_atoms_per_shard.""" + from boltz.distributed.data.utils import CollateDTensor, map_subgroup_mesh_to_cpu + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0, seed=99) + + cp_submesh = manager.device_mesh_subgroups["cp_axis_0", "cp_axis_1"] + n_cp = cp_submesh.shape[0] + cp_group = manager.group["cp"] + src_rank = cp_submesh.mesh.flatten()[0].item() + is_src = manager.group_rank["world"] == src_rank + device = manager.device + + # Two samples with different atom counts → different max_atoms_per_shard + n_tokens = 4 * n_cp + counts_a = torch.tensor([2] * n_tokens, device=device) + counts_b = torch.tensor([4] * n_tokens, device=device) + + feats_a, place_a = _make_sample_features(n_tokens, counts_a, device) + feats_b, place_b = _make_sample_features(n_tokens, counts_b, device) + + dtensors_a = pad_and_scatter_atom_features_dtensor( + feats_a if is_src else None, place_a, cp_group, src_rank, cp_submesh + ) + dtensors_b = pad_and_scatter_atom_features_dtensor( + feats_b if is_src else None, place_b, cp_group, src_rank, cp_submesh + ) + + # Record per-sample atom dim before collation + atoms_per_shard_a = dtensors_a["atom_pad_mask"].to_local().shape[0] + atoms_per_shard_b = dtensors_b["atom_pad_mask"].to_local().shape[0] + assert atoms_per_shard_a != atoms_per_shard_b, "samples must differ in max_atoms_per_shard" + smaller, larger = sorted([atoms_per_shard_a, atoms_per_shard_b]) + + # Save frames_idx local values before collation + fidx_local_a = dtensors_a["frames_idx"].to_local().clone() + fidx_local_b = dtensors_b["frames_idx"].to_local().clone() + + # Collate + dp_cp_mesh = map_subgroup_mesh_to_cpu(manager) + collator = CollateDTensor(dp_cp_mesh) + batch = collator([dtensors_a, dtensors_b]) + + # After collation, atom dim should be the larger of the two + final_atoms_per_shard = batch["atom_pad_mask"].to_local().shape[1] + assert final_atoms_per_shard == larger + + # frames_idx for the smaller sample should have been remapped + fidx_batch = batch["frames_idx"].to_local() + fidx_collated_a = fidx_batch[0] + fidx_collated_b = fidx_batch[1] + + expected_a = remap_atom_indices_repad(fidx_local_a, atoms_per_shard_a, final_atoms_per_shard) + expected_b = remap_atom_indices_repad(fidx_local_b, atoms_per_shard_b, final_atoms_per_shard) + + torch.testing.assert_close(fidx_collated_a, expected_a) + torch.testing.assert_close(fidx_collated_b, expected_b) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +def test_collate_dtensor_atom_index_remap(setup_env): + """Verify CollateDTensor remaps atom-index features when batch samples have different atom counts.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + spawn_multiprocessing( + parallel_assert_collate_dtensor_atom_index_remap, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/dtensor_train_harness.py b/tests/distributed/dtensor_train_harness.py new file mode 100644 index 000000000..13c2043fc --- /dev/null +++ b/tests/distributed/dtensor_train_harness.py @@ -0,0 +1,365 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Shared harness for lightweight DTensor distributed training tests.""" + +from __future__ import annotations + +from pathlib import Path + +import pytorch_lightning as pl +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor +from torch.utils.data import DataLoader, Dataset + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.distogram import distogram_loss as distogram_loss_dtensor +from boltz.distributed.model.modules.trunkv2 import DistogramModule as DistogramModuleDTensor +from boltz.distributed.model.optim.ema import DistributedEMA +from boltz.model.loss.distogramv2 import distogram_loss as distogram_loss_serial +from boltz.model.modules.trunkv2 import DistogramModule as SerialDistogramModule +from boltz.testing.utils import init_module_params_uniform + + +class SyntheticDistogramDataset(Dataset): + """Deterministic synthetic dataset for CP distogram training tests.""" + + def __init__( + self, + *, + seq_len: int, + token_z: int, + num_bins: int, + num_conformers: int, + num_samples: int, + seed: int, + ) -> None: + super().__init__() + generator = torch.Generator() + generator.manual_seed(seed) + + samples: list[dict[str, torch.Tensor]] = [] + for _ in range(num_samples): + z = torch.rand((seq_len, seq_len, token_z), generator=generator) + target = torch.softmax( + torch.randn((seq_len, seq_len, num_conformers, num_bins), generator=generator), + dim=-1, + ) + mask = torch.randint(0, 2, (seq_len,), generator=generator, dtype=torch.bool) + samples.append({"z": z, "target": target, "mask": mask}) + self.samples = samples + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: + return self.samples[index] + + +class TinyDistogramCPModel(pl.LightningModule): + """Tiny LightningModule that exercises DTensor distogram forward/loss in CP.""" + + def __init__( + self, + *, + dist_manager: DistributedManager, + token_z: int, + num_bins: int, + num_distograms: int, + num_conformers: int, + serial_state_dict: dict[str, torch.Tensor], + learning_rate: float, + adam_betas: tuple[float, float] = (0.9, 0.999), + ) -> None: + super().__init__() + self.dist_manager = dist_manager + self.num_conformers = num_conformers + self.learning_rate = learning_rate + self.adam_betas = adam_betas + self.last_pred_local: torch.Tensor | None = None + + serial_module = SerialDistogramModule( + token_z=token_z, + num_bins=num_bins, + num_distograms=num_distograms, + ) + serial_module.load_state_dict(serial_state_dict) + serial_module.to(dist_manager.device) # Must be on mesh device before DTensor wrapping + self.loss_comm = TransposeComm(dist_manager.group["cp"], dist_manager.layout_subgroups["cp"]) + self.distogram_module = DistogramModuleDTensor( + module=serial_module, + dist_manager=dist_manager, + distogram_comm=self.loss_comm, + ) + + def on_train_start(self) -> None: + # Loss log keyed by (epoch, batch_idx) for stop/go trajectory comparison. + self._loss_log: dict[tuple[int, int], float] = {} + + def _assert_sharding_active( + self, + global_tensor: torch.Tensor, + dtensor: torch.Tensor, + label: str, + ) -> None: + """Assert local shape is smaller than global for each Shard dim. + + Catches bugs where CP silently falls back to replication. + """ + from torch.distributed.tensor import DTensor as _DTensor + + if not isinstance(dtensor, _DTensor): + raise AssertionError(f"{label}: expected DTensor, got {type(dtensor)}") + + local = dtensor.to_local() + for mesh_dim, placement in enumerate(dtensor.placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + mesh_size = dtensor.device_mesh.size(mesh_dim) + if mesh_size > 1: + assert local.shape[shard_dim] < global_tensor.shape[shard_dim], ( + f"{label}: Shard({shard_dim}) on mesh dim {mesh_dim} (size={mesh_size}) " + f"did not reduce local shape — global {global_tensor.shape[shard_dim]}, " + f"local {local.shape[shard_dim]}. CP may not be active." + ) + + def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + z = batch["z"].to(self.dist_manager.device) + target = batch["target"].to(self.dist_manager.device) + mask = batch["mask"].to(self.dist_manager.device) + + # Mesh dims: (dp, cp_row, cp_col). Shard(0) splits the batch across + # dp ranks; DataLoader supplies global batch (micro_batch * dp_size). + z_dtensor = distribute_tensor( + z, + device_mesh=self.dist_manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Shard(2)), + ) + target_dtensor = distribute_tensor( + target, + device_mesh=self.dist_manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Shard(2)), + ) + mask_dtensor = distribute_tensor( + mask, + device_mesh=self.dist_manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Replicate()), + ) + + # Verify DTensor sharding is active — proves CP is real, not silently + # replicated. Each mesh dim shrinks the corresponding tensor dim. + self._assert_sharding_active(z, z_dtensor, "z") + + pred_dtensor = self.distogram_module(z_dtensor) + # Capture local output for tests without altering Trainer internals. + self.last_pred_local = pred_dtensor.to_local().detach().cpu() + global_loss_dtensor, _ = distogram_loss_dtensor( + {"pdistogram": pred_dtensor}, + {"disto_target": target_dtensor, "token_disto_mask": mask_dtensor}, + self.loss_comm, + aggregate_distogram=False, + ) + loss = global_loss_dtensor.to_local() + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False) + self._loss_log[(self.current_epoch, batch_idx)] = loss.detach().item() + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.learning_rate, betas=self.adam_betas) + + +class TinyDistogramCPModelWithEMA(TinyDistogramCPModel): + """TinyDistogramCPModel with ``DistributedEMA`` via ``configure_callbacks()``.""" + + def __init__( + self, + *, + ema_decay: float = 0.999, + **kwargs: object, + ) -> None: + super().__init__(**kwargs) # type: ignore[arg-type] + self.ema_decay = ema_decay + + def configure_callbacks(self) -> list[pl.Callback]: + return [DistributedEMA(decay=self.ema_decay)] + + +class TinyDistogramSerialModel(pl.LightningModule): + """Serial (plain-tensor) counterpart of ``TinyDistogramCPModel``. + + Same loss, hyperparameters, and state-dict keys — used for cross-mode + checkpoint tests (serial ↔ distributed). + """ + + def __init__( + self, + *, + token_z: int, + num_bins: int, + num_distograms: int, + num_conformers: int, + serial_state_dict: dict[str, torch.Tensor], + learning_rate: float, + adam_betas: tuple[float, float] = (0.9, 0.999), + ) -> None: + super().__init__() + self.num_conformers = num_conformers + self.learning_rate = learning_rate + self.adam_betas = adam_betas + + self.distogram_module = SerialDistogramModule( + token_z=token_z, + num_bins=num_bins, + num_distograms=num_distograms, + ) + self.distogram_module.load_state_dict(serial_state_dict) + + def on_train_start(self) -> None: + self._loss_log: dict[tuple[int, int], float] = {} + + def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + z = batch["z"].to(self.device) + target = batch["target"].to(self.device) + mask = batch["mask"].to(self.device) + + pred = self.distogram_module(z) + global_loss, _ = distogram_loss_serial( + {"pdistogram": pred}, + {"disto_target": target, "token_disto_mask": mask}, + aggregate_distogram=False, + ) + self.log("train/loss", global_loss, on_step=True, on_epoch=True, prog_bar=False) + self._loss_log[(self.current_epoch, batch_idx)] = global_loss.detach().item() + return global_loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.learning_rate, betas=self.adam_betas) + + +class DistogramTrainDataModule(pl.LightningDataModule): + """LightningDataModule wrapping :class:`SyntheticDistogramDataset`. + + Used by stop/go tests that route through ``train.py``, which requires a + ``LightningDataModule`` (not a bare ``DataLoader``). + """ + + def __init__( + self, + *, + seq_len: int, + token_z: int, + num_bins: int, + num_conformers: int, + num_samples: int, + seed: int, + batch_size: int = 1, + dp_size: int = 1, + ) -> None: + super().__init__() + self._seq_len = seq_len + self._token_z = token_z + self._num_bins = num_bins + self._num_conformers = num_conformers + self._num_samples = num_samples + self._seed = seed + self._batch_size = batch_size + self._dp_size = dp_size + + def train_dataloader(self) -> DataLoader: + return create_train_dataloader( + seq_len=self._seq_len, + token_z=self._token_z, + num_bins=self._num_bins, + num_conformers=self._num_conformers, + num_samples=self._num_samples, + seed=self._seed, + batch_size=self._batch_size, + dp_size=self._dp_size, + ) + + +def create_initial_serial_state_dict( + *, + token_z: int, + num_bins: int, + num_distograms: int, + seed: int, +) -> dict[str, torch.Tensor]: + """Create deterministic initial serial distogram parameters.""" + with torch.random.fork_rng(devices=[]): + torch.manual_seed(seed) + serial_module = SerialDistogramModule( + token_z=token_z, + num_bins=num_bins, + num_distograms=num_distograms, + ) + init_module_params_uniform(serial_module, low=-0.25, high=0.25) + return {key: value.detach().clone().cpu() for key, value in serial_module.state_dict().items()} + + +def create_train_dataloader( + *, + seq_len: int, + token_z: int, + num_bins: int, + num_conformers: int, + num_samples: int, + seed: int, + batch_size: int = 1, + dp_size: int = 1, +) -> DataLoader: + """Create deterministic dataloader for distributed train tests. + + Produces global batches of ``batch_size * dp_size``; the test harness + splits them across dp ranks via ``distribute_tensor(Shard(0))``. + """ + global_batch = batch_size * dp_size + dataset = SyntheticDistogramDataset( + seq_len=seq_len, + token_z=token_z, + num_bins=num_bins, + num_conformers=num_conformers, + num_samples=num_samples, + seed=seed, + ) + return DataLoader(dataset, batch_size=global_batch, shuffle=False, num_workers=0) + + +def state_dicts_differ( + baseline_state: dict[str, torch.Tensor], + candidate_state: dict[str, torch.Tensor], +) -> bool: + """Return True if at least one key differs.""" + for key in baseline_state: + if key not in candidate_state: + return True + if not torch.equal(baseline_state[key], candidate_state[key].cpu()): + return True + return False + + +def save_rank_state_dict(output_dir: Path, rank: int, state_dict: dict[str, torch.Tensor]) -> Path: + """Save rank-local serialized state dict for cross-rank checks.""" + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"rank_{rank}_state.pt" + torch.save(state_dict, path) + return path diff --git a/tests/distributed/model/__init__.py b/tests/distributed/model/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/tests/distributed/model/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/tests/distributed/model/layers/test_attention_with_dtensor_for_pairformer_use_case.py b/tests/distributed/model/layers/test_attention_with_dtensor_for_pairformer_use_case.py new file mode 100755 index 000000000..57e97e27d --- /dev/null +++ b/tests/distributed/model/layers/test_attention_with_dtensor_for_pairformer_use_case.py @@ -0,0 +1,548 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests a single instance of AttentionPairBiasWithDTensor + +Verification requirements + + V1: single-proc FW input tensor values unchanged by FW and BW + V2: single-proc BW input tensor values unchanged by BW + V3: single-proc FW input tensor grads are zero at padded locations (virtual atoms) + - for input tensors that require grads + + V4: multi-proc version of V1 + V5: multi-proc version of V2 + V6: multi-proc version of V3: implied by V3 and V9 + + V7: multi-proc FW input tensor values and meta match single-proc inputs + V8: multi-proc FW output tensor values close-to single-proc + V9: multi-proc FW input gradient values close-to single-proc + V10: multi-proc parameter gradient values close-to single-proc + V11: multi-proc parameter gradident values identical across proc's + +Implementation status + V1: + V2: + V3: NA + V4: implemented + V5: implemented + V6: implied by V3 and V9 + V7: same data + V8: implemented + V9: implemented + V10: implemented + V11: + +Assertion threshold defaults for pytorch + +dtype rtol atol +-------- ------- ------ +float16 1e-3 1e-5 +bfloat16 1.6e-2 1e-5 +float32 1.3e-6 1e-5 + +""" + +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Union + +import pytest +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.testing import assert_close + +from boltz.distributed.comm import AttentionPairBiasComm +from boltz.distributed.manager import DistributedManager, _GridGroupSizesType +from boltz.distributed.model.layers.attention import AttentionPairBias as AttentionPairBiasWithDTensor +from boltz.model.layers.attention import AttentionPairBias as AttentionPairBiasSerialV1 +from boltz.model.layers.attentionv2 import AttentionPairBias as AttentionPairBiasSerialV2 +from boltz.testing.utils import ( + assert_all_identical, + assert_tensors_identical, + seed_by_rank, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + +SEED = 42 + + +def assert_attention_pair_bias_with_dtensor_fw_bw( + rank: int, + input_example: OrderedDict[str, Tensor], + output_ref: Tensor, + output_grad_example: Tensor, + input_grads_ref: OrderedDict[str, Tensor], + c_s: int, + c_z: int, + parameter_grads_ref_as_tensors: OrderedDict[str, Tensor], + num_heads: int, + layer_state_dict: OrderedDict[str, Tensor], + inf: float, + grid_group_sizes: _GridGroupSizesType, + device_type: str, + backend: str, + env_per_rank: dict[str, str], + serial_version: str, + apply_initial_norm: bool, + use_model_cache: bool, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # ------------------------------------------------------------- + # Setup comm objects + # -------------------------------------------------------------- + ring_comm = AttentionPairBiasComm( + process_group=manager.group["cp"], + group_layout=manager.layout_subgroups["cp"], + cp_axis_0_group=manager.subgroups["cp"][0], + cp_axis_1_group=manager.subgroups["cp"][1], + ) + # ------------------------------------------------------------- + # Move inputs and ref outputs to device + # -------------------------------------------------------------- + input_example_device = OrderedDict( + [ + (k, v.detach().to(manager.device) if isinstance(v, Tensor) else deepcopy(v)) + for (k, v) in input_example.items() + ] + ) + output_ref = output_ref.to(manager.device) + + inputs_grad_ref_device = OrderedDict( + [ + (k, v.detach().to(manager.device) if isinstance(v, Tensor) else deepcopy(v)) + for (k, v) in input_grads_ref.items() + ] + ) + s_grad_ref = inputs_grad_ref_device.get("s", None) + z_grad_ref = inputs_grad_ref_device.get("z", None) + + output_grad_example: Tensor = output_grad_example.detach().to(manager.device) + + # ------------------------------------------------------------- + # Create module to test using the appropriate serial version + # -------------------------------------------------------------- + if serial_version == "v1": + single_proc_module = AttentionPairBiasSerialV1( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + initial_norm=apply_initial_norm, + ) + else: + single_proc_module = AttentionPairBiasSerialV2( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + compute_pair_bias=True, + ) + single_proc_module.load_state_dict(state_dict=layer_state_dict) + single_proc_module = single_proc_module.train() + single_proc_module = single_proc_module.to(manager.device) + + multi_proc_module = AttentionPairBiasWithDTensor( + attn_pair_bias=single_proc_module, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + apply_initial_norm=apply_initial_norm, + compute_pair_bias=True, # PairFormer always computes pair bias + use_model_cache=use_model_cache, + ) + # ----------------------------------------------------- + # Create input DTensors + # s is on device, the whole example input + # mask is on device, the whole example input + # s_grad is on device, the whole reference tensor + # z_grad is on device + # ---------------------------------------------------- + placements_for_single_rep_nonparam = (Shard(0), Shard(1), Replicate()) + placements_for_pair_rep_nonparam = (Shard(0), Shard(1), Shard(2)) + # Note: pair_mask is not used in Boltz-2 serial AttentionPairBias - only 1D mask is used + input_meta: OrderedDict[str, dict] = OrderedDict( + [ + ("s", dict(placements=placements_for_single_rep_nonparam, requires_grad=True)), # noqa: C408 + ("mask", dict(placements=placements_for_single_rep_nonparam, requires_grad=False)), # noqa: C408 + ("z", dict(placements=placements_for_pair_rep_nonparam, requires_grad=True)), # noqa: C408 + ] + ) + input_example_as_dtensors = OrderedDict() + for name, meta in input_meta.items(): + input_example_as_dtensors[name] = distribute_tensor( + input_example_device[name], manager.device_mesh_subgroups, meta["placements"] + ).requires_grad_(meta["requires_grad"]) + + output_ref_dt = distribute_tensor( + output_ref, + manager.device_mesh_subgroups, + placements_for_single_rep_nonparam, + ).requires_grad_(False) + + output_grad_example_dt = distribute_tensor( + output_grad_example, + manager.device_mesh_subgroups, + placements_for_single_rep_nonparam, + ).requires_grad_(False) + + # ------------------------------------------------- + # Run FW + # ------------------------------------------------- + input_example_clone_as_dtensors = OrderedDict( + [(k, v.detach().clone().requires_grad_(v.requires_grad)) for k, v in input_example_as_dtensors.items()] + ) + output_actual_dt: DTensor = multi_proc_module(**input_example_as_dtensors) + + # ------------------------------------------------------- + # V8: multi-proc FW output tensor values close-to single-proc + # ------------------------------------------------------ + assert ( + output_actual_dt.shape == output_ref_dt.shape + ), f"Output shape mismatch: {output_actual_dt.shape} != {output_ref_dt.shape}" + assert ( + output_actual_dt.stride() == output_ref_dt.stride() + ), f"Output stride mismatch: {output_actual_dt.stride()} != {output_ref_dt.stride()}" + assert_close(output_actual_dt.full_tensor(), output_ref_dt.full_tensor()) + + # ------------------------------------------------- + # Run BW + # ------------------------------------------------- + output_grad_example_clone_dt = ( + output_grad_example_dt.detach().clone().requires_grad_(output_grad_example_dt.requires_grad) + ) + output_actual_clone_dt = output_actual_dt.detach().clone().requires_grad_(output_actual_dt.requires_grad) + output_actual_dt.backward(output_grad_example_dt) + + # ------------------------------------------------------- + # V4: FW input tensor values unchanged by FW and BW + # --------------------------------------------------------- + for k in input_example_as_dtensors.keys(): + assert_tensors_identical( + input_example_as_dtensors[k].full_tensor(), + input_example_clone_as_dtensors[k].full_tensor(), + check_grad=False, + check_grad_fn=False, + ) + # ----------------------------------------------------------- + # V5: BW input tensor values unchanged by BW + # ------------------------------------------------------------- + assert_tensors_identical( + output_grad_example_dt, + output_grad_example_clone_dt, + check_grad=False, + check_grad_fn=False, + ) + assert_tensors_identical( + output_actual_dt, + output_actual_clone_dt, + check_grad=False, + check_grad_fn=False, + ) + # -------------------------------------------------------------------- + # V9: multi-proc FW input gradient values close-to single-proc + # - conduct the check by materializing the full distributed tensor + # - s_grad, z_grad on device + # --------------------------------------------------------------------- + s_grad_actual_full: Tensor = input_example_as_dtensors["s"].grad.full_tensor() + assert ( + s_grad_actual_full.shape == s_grad_ref.shape + ), f"Gradient shape mismatch: {s_grad_actual_full.shape} != {s_grad_ref.shape}" + assert ( + s_grad_actual_full.stride() == s_grad_ref.stride() + ), f"Gradient stride mismatch: {s_grad_actual_full.stride()} != {s_grad_ref.stride()}" + assert_close(s_grad_actual_full, s_grad_ref) + + if z_grad_ref is not None: + z_grad_actual_full: Tensor = input_example_as_dtensors["z"].grad.full_tensor() + assert ( + z_grad_actual_full.shape == z_grad_ref.shape + ), f"Gradient shape mismatch: {z_grad_actual_full.shape} != {z_grad_ref.shape}" + assert ( + z_grad_actual_full.stride() == z_grad_ref.stride() + ), f"Gradient stride mismatch: {z_grad_actual_full.stride()} != {z_grad_ref.stride()}" + assert_close(z_grad_actual_full, z_grad_ref) + + # -------------------------------------------------------------------- + # V10: multi-proc parameter gradient values close-to single-proc + # + # (1) Trigger reductions on DTensor gradients before evaluating assert + # --------------------------------------------------------------------- + param_grads_actual_as_tensors = OrderedDict() + for name, param in multi_proc_module.named_parameters(): + if (param.grad is None) != (parameter_grads_ref_as_tensors[name] is None): + raise ValueError( + f"Inconsistent grad state for {name} on rank {rank}: " + f"result grad is {param.grad is None}, " + f"reference grad is {parameter_grads_ref_as_tensors[name] is None}" + ) + param_grads_actual_as_tensors[name] = None if param.grad is None else param.grad.full_tensor() + + for name, grad_ref in parameter_grads_ref_as_tensors.items(): + if (grad_ref is None) != (param_grads_actual_as_tensors[name] is None): + raise ValueError( + f"Inconsistent grad state for {name} on rank {rank}: " + f"result grad is {param_grads_actual_as_tensors[name] is None}, " + f"reference grad is {grad_ref is None}" + ) + grad_actual = param_grads_actual_as_tensors[name] + assert grad_actual.shape == grad_ref.shape, f"Gradient shape mismatch: {grad_actual.shape} != {grad_ref.shape}" + assert ( + grad_actual.stride() == grad_ref.stride() + ), f"Gradient stride mismatch: {grad_actual.stride()} != {grad_ref.stride()}" + assert_close( + grad_actual.cpu(), + grad_ref, + msg=lambda msg: f"Rank {rank} {name} grad mismatch\n{msg}\ngot:{grad_actual}\nwant:{grad_ref}", + ) + assert_all_identical(grad_actual, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +def get_example_input_and_reference_output( + bs: int, + N_tokens: int, + c_s: int, + c_z: int, + num_heads: int, + inf: float, + serial_version: str, + apply_initial_norm: bool, + multiplicity: int = 1, + seed: int = SEED, +) -> tuple[ + OrderedDict[str, Union[Tensor, Any]], + Tensor, + Tensor, + OrderedDict[str, Union[Tensor, Any]], + OrderedDict[str, Union[Tensor, Any]], + OrderedDict[str, Union[Tensor, Any]], +]: + # ---------------------------------------- + # (0) Check use-case requirements + # ---------------------------------------- + if multiplicity != 1: + raise ValueError("multiplicity must be 1 for this use-case") + + # ---------------------------------------- + # (1) Initialize RNG + # ---------------------------------------- + seed_by_rank(0, seed=seed) + + # ------------------------------------- + # (2) Create example inputs on host + # Both V1 and V2 pairformer use k_in=s (no to_keys transformation) + # ------------------------------------- + s = torch.rand(size=(bs, N_tokens, c_s), dtype=torch.float, requires_grad=True) + input_example: OrderedDict[str, Union[Tensor, Any]] = OrderedDict( + [ + ("s", s), + ("z", torch.rand(size=(bs, N_tokens, N_tokens, c_z), dtype=torch.float, requires_grad=True)), + ("mask", torch.randint(0, 2, size=(bs, N_tokens), dtype=torch.float)), + ("multiplicity", multiplicity), + ] + ) + input_example_clone: OrderedDict[str, Union[Tensor, int, None]] = OrderedDict( + [(k, v.detach().clone() if isinstance(v, Tensor) else deepcopy(v)) for k, v in input_example.items()] + ) + # ---------------------------------------------------------------------- + # (3) Create single proc module on cpu / host + # Run FW, BW with reference module on example inputs + # ----------------------------------------------------------------------- + if serial_version == "v1": + module_ref = AttentionPairBiasSerialV1( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + initial_norm=apply_initial_norm, + ) + else: + module_ref = AttentionPairBiasSerialV2( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + compute_pair_bias=True, + ) + module_ref.proj_o.reset_parameters() # avoid zero initialization + module_ref = module_ref.train() + state_dict_ref = module_ref.state_dict() + + # For pairformer, k_in=s (queries equal keys) + # V1: forward(s, z, mask, multiplicity) -- k_in defaults to s internally + # V2: forward(s, z, mask, k_in=s, multiplicity) + if serial_version == "v1": + output_ref = module_ref( + s=input_example["s"], + z=input_example["z"], + mask=input_example["mask"], + multiplicity=input_example["multiplicity"], + ) + else: + output_ref = module_ref( + s=input_example["s"], + z=input_example["z"], + mask=input_example["mask"], + k_in=input_example["s"], # k_in = s for pairformer + multiplicity=input_example["multiplicity"], + ) + output_grad_example = torch.rand_like(output_ref) + output_grad_example_clone = output_grad_example.detach().clone() + output_ref.backward(output_grad_example) + + input_grads_ref_as_tensors: OrderedDict[str, Union[Tensor, None]] = OrderedDict( + [ + ("s", input_example["s"].grad), + ("z", input_example["z"].grad), + ] + ) + + # ---------------------------------------------------------------- + # V1: single-proc FW input tensor values unchanged by FW and BW + # ---------------------------------------------------------------- + for k in input_example: + assert_close(input_example[k], input_example_clone[k]) + + # --------------------------------------------------------------- + # V2: single-proc BW input tensor values unchanged by BW + # --------------------------------------------------------------- + assert_close(output_grad_example, output_grad_example_clone) + + # -------------------------------------- + # (3) Get parameter gradients, on host + # -------------------------------------- + parameter_grads_ref_as_tensors = OrderedDict() + for name, parameter in module_ref.named_parameters(): + if parameter.grad is not None: + parameter_grads_ref_as_tensors[name] = parameter.grad + else: + parameter_grads_ref_as_tensors[name] = None + + return ( + input_example_clone, + output_ref.detach(), + output_grad_example, + input_grads_ref_as_tensors, + parameter_grads_ref_as_tensors, + state_dict_ref, + ) + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "version_config", + [ + # (serial_version, apply_initial_norm, use_model_cache) + # PairFormer always uses compute_pair_bias=True + ("v1", True, False), # V1 PairFormer: initial_norm=True, no cache + ("v2", False, False), # V2 PairFormer: no initial norm, no cache + ], + ids=lambda x: f"serial:{x[0]}, init_norm:{x[1]}, cache:{x[2]}", +) +def test_attention_pair_bias_with_dtensor_for_pairformer_use_case( + setup_env: dict[str, int], + version_config: tuple[str, bool, bool], + bs: int = 2, + c_s: int = 2 * 5, + num_heads: int = 5, + c_z: int = 7, + multiplicity: int = 1, + seed: int = SEED, +): + serial_version, apply_initial_norm, use_model_cache = version_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + skip_if_cuda_not_avail_or_device_count_less_than_word_size(device_type, world_size) + N_tokens: int = grid_group_sizes["cp"][0] * 32 + inf = 1e6 + + # ------------------------------------------------------------ + # (0) Check use-case requirements / implementation scope + # ------------------------------------------------------------ + if multiplicity != 1: + raise ValueError("multiplicity must be 1 for this use-case") + + # ---------------------------------------- + # (1) Get example inputs and reference outputs + # ---------------------------------------- + ( + input_example, + output_ref, + output_grad_example, + input_grads_ref_as_tensors, + parameter_grads_ref_as_tensors, + state_dict_ref, + ) = get_example_input_and_reference_output( + bs, + N_tokens, + c_s, + c_z, + num_heads, + inf, + serial_version, + apply_initial_norm, + multiplicity, + seed=seed, + ) + spawn_multiprocessing( + assert_attention_pair_bias_with_dtensor_fw_bw, + world_size, + input_example, + output_ref, + output_grad_example, + input_grads_ref_as_tensors, + c_s, + c_z, + parameter_grads_ref_as_tensors, + num_heads, + state_dict_ref, + inf, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_version, + apply_initial_norm, + use_model_cache, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/distributed/model/layers/test_dtensor_atom_to_token.py b/tests/distributed/model/layers/test_dtensor_atom_to_token.py new file mode 100644 index 000000000..093788a45 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_atom_to_token.py @@ -0,0 +1,816 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.atom_to_token import ( + pair_repr_token_to_atom, + reconstruct_atom_to_token_global, + reconstruct_r_set_to_rep_atom_global, + reconstruct_token_to_rep_atom_global, + single_repr_atom_to_token, + single_repr_token_to_atom, +) +from boltz.distributed.testing.utils import create_atom_to_token_dtensor +from boltz.testing.utils import seed_by_rank, spawn_multiprocessing + + +def create_mock_atom_to_token_tensor(batch_size: int, n_tokens: int, n_atoms: int, cp_size: int) -> Tensor: + """Create a mock atom_to_token one-hot tensor with diagonal block structure. + + Each atom maps to exactly one token within the same CP shard, producing a + (B, N_atoms, N_tokens) one-hot matrix with non-zero entries only on the + block diagonal. Multiple atoms may map to the same token. + """ + assert n_atoms >= n_tokens, "n_atoms must be greater than or equal to n_tokens" + assert n_tokens % cp_size == 0 and n_atoms % cp_size == 0, "n_tokens and n_atoms must be divisible by cp_size" + num_tokens_per_shard = n_tokens // cp_size + num_atoms_per_shard = n_atoms // cp_size + + atom_to_token_global = torch.zeros(batch_size, n_atoms, n_tokens) # block diagonal one-hot matrix + for sample_idx in range(batch_size): + for cp_idx in range(cp_size): + start_token_idx = cp_idx * num_tokens_per_shard + num_atoms_per_token = torch.randint( + 1, num_atoms_per_shard // num_tokens_per_shard + 1, (num_tokens_per_shard,) + ) + atom_indices = torch.cumsum(num_atoms_per_token, dim=0) + atom_indices = torch.clamp(atom_indices, max=num_atoms_per_shard) + cp_idx * num_atoms_per_shard + atom_indices = [cp_idx * num_atoms_per_shard] + atom_indices.tolist() + for token_idx_in_shard, (atom_start_idx, atom_end_idx) in enumerate( + zip(atom_indices[:-1], atom_indices[1:]) + ): + atom_to_token_global[sample_idx, atom_start_idx:atom_end_idx, start_token_idx + token_idx_in_shard] = 1 + + return atom_to_token_global + + +def create_mock_token_to_rep_atom_tensor(batch_size: int, n_tokens: int, n_atoms: int, cp_size: int) -> Tensor: + """Create a mock token_to_rep_atom one-hot tensor with diagonal block structure. + + Each token selects exactly one representative atom from the atoms belonging + to the same CP shard, producing a (B, N_tokens, N_atoms) one-hot matrix + with non-zero entries only on the block diagonal. + """ + assert n_atoms >= n_tokens + assert n_tokens % cp_size == 0 and n_atoms % cp_size == 0 + num_tokens_per_shard = n_tokens // cp_size + num_atoms_per_shard = n_atoms // cp_size + + token_to_rep_atom_global = torch.zeros(batch_size, n_tokens, n_atoms) + for sample_idx in range(batch_size): + for cp_idx in range(cp_size): + token_start = cp_idx * num_tokens_per_shard + atom_start = cp_idx * num_atoms_per_shard + for t in range(num_tokens_per_shard): + rep_atom = atom_start + torch.randint(0, num_atoms_per_shard, (1,)).item() + token_to_rep_atom_global[sample_idx, token_start + t, rep_atom] = 1 + + return token_to_rep_atom_global + + +def create_mock_r_set_to_rep_atom_tensor(batch_size: int, n_r_set: int, n_atoms: int, cp_size: int) -> Tensor: + """Create a mock r_set_to_rep_atom one-hot tensor with diagonal block structure. + + Each R-set element selects exactly one representative atom from the atoms + belonging to the same CP shard, producing a (B, N_R, N_atoms) one-hot matrix + with non-zero entries only on the block diagonal. n_r_set must be divisible + by cp_size. + """ + assert n_atoms >= n_r_set + assert n_r_set % cp_size == 0 and n_atoms % cp_size == 0 + num_r_per_shard = n_r_set // cp_size + num_atoms_per_shard = n_atoms // cp_size + + r_set_global = torch.zeros(batch_size, n_r_set, n_atoms) + for sample_idx in range(batch_size): + for cp_idx in range(cp_size): + r_start = cp_idx * num_r_per_shard + atom_start = cp_idx * num_atoms_per_shard + for r in range(num_r_per_shard): + rep_atom = atom_start + torch.randint(0, num_atoms_per_shard, (1,)).item() + r_set_global[sample_idx, r_start + r, rep_atom] = 1 + + return r_set_global + + +def create_single_repr_token_to_atom_global_expectation( + batch_size: int, n_tokens: int, n_atoms: int, dim: int, cp_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Create global tensors for single_repr_token_to_atom operation. + + Args: + batch_size: Batch size + n_tokens: Number of tokens + n_atoms: Number of atoms + dim: Feature dimension + device: Device to place tensors on + + Returns: + tuple: (token_single_repr_global, atom_to_token_global) + """ + # Create global tensors + token_repr_global = torch.randn(batch_size, n_tokens, dim, device=device, requires_grad=True) + atom_to_token_global = create_mock_atom_to_token_tensor(batch_size, n_tokens, n_atoms, cp_size) + atom_to_token_global = atom_to_token_global.to(device) + + # Clone inputs for distribution + token_repr_global_clone = token_repr_global.detach().clone() + atom_to_token_global_clone = atom_to_token_global.detach().clone() + + # Compute expected result + result_global_expected = torch.bmm(atom_to_token_global, token_repr_global) + + # Create gradients for backward pass + dy_global = torch.rand_like(result_global_expected) + + # Backward pass on global tensors + result_global_expected.backward(dy_global) + + return ( + token_repr_global_clone, + atom_to_token_global_clone, + result_global_expected.detach().clone(), + token_repr_global.grad.detach().clone(), + dy_global.detach().clone(), + ) + + +def create_single_repr_atom_to_token_global_expectation( + batch_size: int, n_tokens: int, n_atoms: int, dim: int, cp_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Create global tensors for single_repr_atom_to_token operation. + + Args: + batch_size: Batch size + n_tokens: Number of tokens + n_atoms: Number of atoms + dim: Feature dimension + device: Device to place tensors on + + Returns: + tuple: (atom_single_repr_global, atom_to_token_global, result_global_expected, + atom_repr_global_grad, dy_global) + """ + # Create global tensors + atom_repr_global = torch.randn(batch_size, n_atoms, dim, device=device, requires_grad=True) + atom_to_token_global = create_mock_atom_to_token_tensor(batch_size, n_tokens, n_atoms, cp_size) + atom_to_token_global = atom_to_token_global.to(device) + + # Clone inputs for distribution + atom_repr_global_clone = atom_repr_global.detach().clone() + atom_to_token_global_clone = atom_to_token_global.detach().clone() + + # Compute expected result + atom_to_token_sum = atom_to_token_global.sum(dim=1, keepdim=True) + 1e-6 + atom_to_token_mean = atom_to_token_global / atom_to_token_sum + result_global_expected = torch.bmm(atom_to_token_mean.transpose(1, 2), atom_repr_global) + + # Create gradients for backward pass + dy_global = torch.rand_like(result_global_expected) + + # Backward pass on global tensors + result_global_expected.backward(dy_global) + + return ( + atom_repr_global_clone, + atom_to_token_global_clone, + result_global_expected.detach().clone(), + atom_repr_global.grad.detach().clone(), + dy_global.detach().clone(), + ) + + +def compute_pair_repr_token_to_atom_global_expectation(batch_size, n_tokens, n_atoms, dim, cp_size, device): + """Compute expected results using global tensors. + + Args: + batch_size: Batch size + n_tokens: Number of tokens + n_atoms: Number of atoms + dim: Feature dimension + device: Device to place tensors on + Returns: + tuple: (token_repr_global, atom_to_token_global, result_global_expected, + token_repr_global_grad, dy_global) + """ + # Create global tensors + token_repr_global = torch.randn(batch_size, n_tokens, n_tokens, dim, device=device, requires_grad=True) + atom_to_token_global = create_mock_atom_to_token_tensor(batch_size, n_tokens, n_atoms, cp_size) + atom_to_token_global = atom_to_token_global.to(device) + + # Clone inputs for distribution + token_repr_global_clone = token_repr_global.detach().clone() + atom_to_token_global_clone = atom_to_token_global.detach().clone() + + # Compute expected result + result_global_expected = torch.einsum( + "bijd,bmi,bnj->bmnd", token_repr_global, atom_to_token_global, atom_to_token_global + ) + + # Create gradients for backward pass + dy_global = torch.rand_like(result_global_expected) + + # Backward pass on global tensors + result_global_expected.backward(dy_global) + + return ( + token_repr_global_clone, + atom_to_token_global_clone, + result_global_expected.detach().clone(), + token_repr_global.grad.detach().clone(), + dy_global.detach().clone(), + ) + + +def assert_single_repr_token_to_atom( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """Test distributed single_repr_token_to_atom operation in a parallel environment. + + This test validates that the single_repr_token_to_atom function produces identical + results to the equivalent global tensor computation. It verifies: + + 1. Forward pass produces the same results as global tensor computation + 2. Backward pass correctly propagates gradients through the distributed operation + 3. Results and gradients match the equivalent global tensor operations + + Args: + rank: The process rank in the distributed environment + grid_group_sizes: Dictionary mapping group names to their sizes for distributed setup + device_type: Device to run the test on ("cpu" or "cuda") + backend: The distributed backend to use (e.g., "gloo", "nccl") + env_map: Optional dictionary of environment variables to set before initialization + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + batch_size = 2 + n_tokens_per_rank = 4 + n_tokens_global = size_ring * n_tokens_per_rank + n_atoms_per_rank = n_tokens_per_rank * 3 + n_atoms_global = size_ring * n_atoms_per_rank + dim = 5 + + # Set random seed based on rank for reproducibility + seed_by_rank(0) + + # Compute global expectations + ( + token_repr_global, + atom_to_token_global, + result_global_expected, + token_repr_global_grad, + dy_global, + ) = create_single_repr_token_to_atom_global_expectation( + batch_size, n_tokens_global, n_atoms_global, dim, size_ring, manager.device + ) + + # Create distributed tensors + # token_repr: Shape (B, n_tokens, D) with placement (Shard(0), Shard(1), Replicate()) + single_repr_placements = [Shard(dim=0), Shard(dim=1), Replicate()] + device_mesh: DeviceMesh = manager.device_mesh_subgroups + token_repr_dtensor = distribute_tensor(token_repr_global, device_mesh, single_repr_placements) + token_repr_dtensor.requires_grad = True + + # atom_to_token: Shape (B, n_tokens, n_atoms) with placement (Shard(0), Shard(1), Replicate()) + atom_to_token_dtensor = create_atom_to_token_dtensor(atom_to_token_global, manager.device_mesh_subgroups) + + # Compute on distributed tensors using single_repr_token_to_atom + result_dtensor = single_repr_token_to_atom(token_repr_dtensor, atom_to_token_dtensor) + + # Distribute the upstream adjoint for backward pass + # Expected result shape: (B, n_atoms, D) + dy_dtensor = distribute_tensor(dy_global, device_mesh, single_repr_placements) + + # Perform backward pass + result_dtensor.backward(dy_dtensor) + + # Create distributed tensors from global expectations for comparison + token_repr_grad_dtensor_expected = distribute_tensor(token_repr_global_grad, device_mesh, single_repr_placements) + result_dtensor_expected = distribute_tensor(result_global_expected, device_mesh, single_repr_placements) + + # Compare results with expected local shards + torch.testing.assert_close(result_dtensor_expected, result_dtensor) + torch.testing.assert_close(token_repr_grad_dtensor_expected, token_repr_dtensor.grad) + + # Test shape and stride consistency + assert ( + result_dtensor.shape == result_dtensor_expected.shape + ), f"Output shape mismatch: {result_dtensor.shape} != {result_dtensor_expected.shape}" + assert ( + result_dtensor.stride() == result_dtensor_expected.stride() + ), f"Output stride mismatch: {result_dtensor.stride()} != {result_dtensor_expected.stride()}" + assert ( + token_repr_dtensor.grad.shape == token_repr_grad_dtensor_expected.shape + ), f"Gradient shape mismatch: {token_repr_dtensor.grad.shape} != {token_repr_grad_dtensor_expected.shape}" + assert ( + token_repr_dtensor.grad.stride() == token_repr_grad_dtensor_expected.stride() + ), f"Gradient stride mismatch: {token_repr_dtensor.grad.stride()} != {token_repr_grad_dtensor_expected.stride()}" + + # Collect results as global tensors and compare with original global tensors + result_global_result = result_dtensor.full_tensor() + token_repr_grad_global_result = token_repr_dtensor.grad.full_tensor() + + # Assert output and input gradients match the global computation + torch.testing.assert_close(result_global_result, result_global_expected) + torch.testing.assert_close(token_repr_grad_global_result, token_repr_global_grad) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_single_repr_token_to_atom(setup_env): + """Test distributed single_repr_token_to_atom operation across multiple processes.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_single_repr_token_to_atom, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def assert_single_repr_atom_to_token( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """Test distributed single_repr_atom_to_token operation in a parallel environment. + + This test validates that the single_repr_atom_to_token function produces identical + results to the equivalent global tensor computation. It verifies: + + 1. Forward pass produces the same results as global tensor computation + 2. Backward pass correctly propagates gradients through the distributed operation + 3. Results and gradients match the equivalent global tensor operations + + Args: + rank: The process rank in the distributed environment + grid_group_sizes: Dictionary mapping group names to their sizes for distributed setup + device_type: Device to run the test on ("cpu" or "cuda") + backend: The distributed backend to use (e.g., "gloo", "nccl") + env_map: Optional dictionary of environment variables to set before initialization + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + batch_size = 2 + n_tokens_per_rank = 4 + n_tokens_global = size_ring * n_tokens_per_rank + n_atoms_per_rank = n_tokens_per_rank * 3 + n_atoms_global = size_ring * n_atoms_per_rank + dim = 5 + + # Set random seed based on rank for reproducibility + seed_by_rank(0) + + # Compute global expectations + ( + atom_repr_global, + atom_to_token_global, + result_global_expected, + atom_repr_global_grad, + dy_global, + ) = create_single_repr_atom_to_token_global_expectation( + batch_size, n_tokens_global, n_atoms_global, dim, size_ring, manager.device + ) + + # Create distributed tensors + # atom_repr: Shape (B, n_atoms, D) with placement (Shard(0), Shard(1), Replicate()) + single_repr_placements = [Shard(dim=0), Shard(dim=1), Replicate()] + device_mesh: DeviceMesh = manager.device_mesh_subgroups + atom_repr_dtensor = distribute_tensor(atom_repr_global, device_mesh, single_repr_placements) + atom_repr_dtensor.requires_grad = True + + # atom_to_token: Shape (B, n_atoms, n_tokens) with placement (Shard(0), Shard(1), Replicate()) + atom_to_token_dtensor = create_atom_to_token_dtensor(atom_to_token_global, manager.device_mesh_subgroups) + + # Compute on distributed tensors using single_repr_atom_to_token + result_dtensor = single_repr_atom_to_token(atom_repr_dtensor, atom_to_token_dtensor) + + # Distribute the upstream adjoint for backward pass + # Expected result shape: (B, n_tokens, D) + dy_dtensor = distribute_tensor(dy_global, device_mesh, single_repr_placements) + + # Perform backward pass + result_dtensor.backward(dy_dtensor) + + # Create distributed tensors from global expectations for comparison + atom_repr_grad_dtensor_expected = distribute_tensor(atom_repr_global_grad, device_mesh, single_repr_placements) + result_dtensor_expected = distribute_tensor(result_global_expected, device_mesh, single_repr_placements) + + # Compare results with expected local shards + assert ( + result_dtensor.shape == result_dtensor_expected.shape + ), f"Output shape mismatch: {result_dtensor.shape} != {result_dtensor_expected.shape}" + assert ( + result_dtensor.stride() == result_dtensor_expected.stride() + ), f"Output stride mismatch: {result_dtensor.stride()} != {result_dtensor_expected.stride()}" + torch.testing.assert_close(result_dtensor_expected, result_dtensor) + + assert ( + atom_repr_dtensor.grad.shape == atom_repr_grad_dtensor_expected.shape + ), f"Gradient shape mismatch: {atom_repr_dtensor.grad.shape} != {atom_repr_grad_dtensor_expected.shape}" + assert ( + atom_repr_dtensor.grad.stride() == atom_repr_grad_dtensor_expected.stride() + ), f"Gradient stride mismatch: {atom_repr_dtensor.grad.stride()} != {atom_repr_grad_dtensor_expected.stride()}" + torch.testing.assert_close(atom_repr_grad_dtensor_expected, atom_repr_dtensor.grad) + + # Collect results as global tensors and compare with original global tensors + result_global_result = result_dtensor.full_tensor() + atom_repr_grad_global_result = atom_repr_dtensor.grad.full_tensor() + + # Assert output and input gradients match the global computation + torch.testing.assert_close(result_global_result, result_global_expected) + torch.testing.assert_close(atom_repr_grad_global_result, atom_repr_global_grad) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_single_repr_atom_to_token(setup_env): + """Test distributed single_repr_atom_to_token operation across multiple processes.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_single_repr_atom_to_token, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def assert_pair_repr_token_to_atom( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """Test distributed atom_to_token operation in a parallel environment. + + This test validates that the pair_repr_token_to_atom function produces identical + results to the equivalent global tensor computation. It verifies: + + 1. Forward pass produces the same results as global tensor computation + 2. Backward pass correctly propagates gradients through the distributed operation + 3. Results and gradients match the equivalent global tensor operations + + Args: + rank: The process rank in the distributed environment + grid_group_sizes: Dictionary mapping group names to their sizes for distributed setup + device_type: Device to run the test on ("cpu" or "cuda") + backend: The distributed backend to use (e.g., "gloo", "nccl") + env_map: Optional dictionary of environment variables to set before initialization + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + batch_size = 2 + n_tokens_per_rank = 4 + n_tokens_global = size_ring * n_tokens_per_rank + n_atoms_per_rank = n_tokens_per_rank * 3 + n_atoms_global = size_ring * n_atoms_per_rank + dim = 5 + + # Set random seed based on rank for reproducibility + seed_by_rank(0) + + # Compute global expectations + ( + token_repr_global, + atom_to_token_global, + result_global_expected, + token_repr_global_grad, + dy_global, + ) = compute_pair_repr_token_to_atom_global_expectation( + batch_size, n_tokens_global, n_atoms_global, dim, size_ring, manager.device + ) + + # Create distributed tensors + # token_repr: Shape (B, n_tokens, n_tokens, D) with placement (Shard(0), Shard(1), Shard(2)) + pair_repr_placements = [Shard(dim=0), Shard(dim=1), Shard(dim=2)] + device_mesh: DeviceMesh = manager.device_mesh_subgroups + token_repr_dtensor = distribute_tensor(token_repr_global, device_mesh, pair_repr_placements) + token_repr_dtensor.requires_grad = True + + # atom_to_token: Shape (B, n_tokens, n_atoms) with placement (Shard(0), Shard(1), Replicate()) + atom_to_token_dtensor = create_atom_to_token_dtensor(atom_to_token_global, manager.device_mesh_subgroups) + + # Create TransposeComm for communication + # The function requires a transpose communication object + cp_group = manager.group["cp"] + layout_group_cp = manager.layout_subgroups["cp"] + transpose_comm = TransposeComm(cp_group, layout_group_cp) + + # Compute on distributed tensors using pair_repr_token_to_atom + result_dtensor = pair_repr_token_to_atom(token_repr_dtensor, atom_to_token_dtensor, transpose_comm) + + # Distribute the upstream adjoint for backward pass + # Expected result shape: (B, n_atoms, n_atoms, D) + dy_dtensor = distribute_tensor(dy_global, device_mesh, pair_repr_placements) + + # Perform backward pass + result_dtensor.backward(dy_dtensor) + + # Create distributed tensors from global expectations for comparison + token_repr_grad_dtensor_expected = distribute_tensor(token_repr_global_grad, device_mesh, pair_repr_placements) + result_dtensor_expected = distribute_tensor(result_global_expected, device_mesh, pair_repr_placements) + + # Compare results with expected local shards + assert ( + result_dtensor.shape == result_dtensor_expected.shape + ), f"Output shape mismatch: {result_dtensor.shape} != {result_dtensor_expected.shape}" + # We can't guarantee same layout because two different einsum operations are used + # in the DTensor version and the serial version + # assert result_dtensor.stride() == result_dtensor_expected.stride(), ( + # f"Output stride mismatch: {result_dtensor.stride()} != {result_dtensor_expected.stride()}" + # ) + torch.testing.assert_close(result_dtensor_expected, result_dtensor) + + assert ( + token_repr_dtensor.grad.shape == token_repr_grad_dtensor_expected.shape + ), f"Gradient shape mismatch: {token_repr_dtensor.grad.shape} != {token_repr_grad_dtensor_expected.shape}" + assert ( + token_repr_dtensor.grad.stride() == token_repr_grad_dtensor_expected.stride() + ), f"Gradient stride mismatch: {token_repr_dtensor.grad.stride()} != {token_repr_grad_dtensor_expected.stride()}" + torch.testing.assert_close(token_repr_grad_dtensor_expected, token_repr_dtensor.grad) + + # Collect results as global tensors and compare with original global tensors + result_global_result = result_dtensor.full_tensor() + token_repr_grad_global_result = token_repr_dtensor.grad.full_tensor() + + # Assert output and input gradients match the global computation + torch.testing.assert_close(result_global_result, result_global_expected) + torch.testing.assert_close(token_repr_grad_global_result, token_repr_global_grad) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_pair_repr_token_to_atom(setup_env): + """Test distributed pair_repr_token_to_atom operation across multiple processes.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_pair_repr_token_to_atom, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def assert_reconstruct_onehot_diag_block_global( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """Validate all three reconstruct functions for diagonally-sharded one-hot DTensors. + + Tests: + 1. reconstruct_atom_to_token_global — round-trip matches original global tensor and + produces correct results when used with single_repr_token_to_atom. + 2. reconstruct_token_to_rep_atom_global — round-trip matches original global tensor. + 3. reconstruct_r_set_to_rep_atom_global — round-trip matches original global tensor. + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + batch_size = grid_group_sizes["dp"] + n_tokens_per_rank = 4 + n_tokens_global = size_ring * n_tokens_per_rank + n_atoms_per_rank = n_tokens_per_rank * 3 + n_atoms_global = size_ring * n_atoms_per_rank + n_r_set_per_rank = 2 + n_r_set_global = size_ring * n_r_set_per_rank + dim = 5 + + seed_by_rank(0) + device_mesh: DeviceMesh = manager.device_mesh_subgroups + + dp_rank = device_mesh.get_coordinate()[0] + dp_size = device_mesh.shape[0] + local_batch = batch_size // dp_size + + # --- reconstruct_atom_to_token_global --- + token_repr_global = torch.randn(batch_size, n_tokens_global, dim, device=manager.device) + atom_to_token_global = create_mock_atom_to_token_tensor(batch_size, n_tokens_global, n_atoms_global, size_ring).to( + manager.device + ) + + single_repr_placements = [Shard(dim=0), Shard(dim=1), Replicate()] + token_repr_dtensor = distribute_tensor(token_repr_global, device_mesh, single_repr_placements) + atom_to_token_dtensor = create_atom_to_token_dtensor(atom_to_token_global, device_mesh) + + atom_to_token_reconstructed = reconstruct_atom_to_token_global(atom_to_token_dtensor) + atom_to_token_dp_local = atom_to_token_global[dp_rank * local_batch : (dp_rank + 1) * local_batch] + torch.testing.assert_close(atom_to_token_reconstructed, atom_to_token_dp_local) + + result_dtensor = single_repr_token_to_atom(token_repr_dtensor, atom_to_token_dtensor) + result_full_local = result_dtensor.redistribute( + placements=[Shard(dim=0), Replicate(), Replicate()], + ).to_local() + token_repr_full_local = token_repr_dtensor.redistribute( + placements=[Shard(dim=0), Replicate(), Replicate()], + ).to_local() + expected_full_local = torch.bmm( + atom_to_token_reconstructed.to(dtype=token_repr_full_local.dtype), + token_repr_full_local, + ) + torch.testing.assert_close(result_full_local, expected_full_local) + + # --- reconstruct_token_to_rep_atom_global --- + token_to_rep_atom_global = create_mock_token_to_rep_atom_tensor( + batch_size, n_tokens_global, n_atoms_global, size_ring + ).to(manager.device) + token_to_rep_atom_dtensor = create_atom_to_token_dtensor(token_to_rep_atom_global, device_mesh) + + token_to_rep_atom_reconstructed = reconstruct_token_to_rep_atom_global(token_to_rep_atom_dtensor) + token_to_rep_atom_dp_local = token_to_rep_atom_global[dp_rank * local_batch : (dp_rank + 1) * local_batch] + torch.testing.assert_close(token_to_rep_atom_reconstructed, token_to_rep_atom_dp_local) + + # --- reconstruct_r_set_to_rep_atom_global --- + r_set_to_rep_atom_global = create_mock_r_set_to_rep_atom_tensor( + batch_size, n_r_set_global, n_atoms_global, size_ring + ).to(manager.device) + r_set_to_rep_atom_dtensor = create_atom_to_token_dtensor(r_set_to_rep_atom_global, device_mesh) + + r_set_to_rep_atom_reconstructed = reconstruct_r_set_to_rep_atom_global(r_set_to_rep_atom_dtensor) + r_set_to_rep_atom_dp_local = r_set_to_rep_atom_global[dp_rank * local_batch : (dp_rank + 1) * local_batch] + torch.testing.assert_close(r_set_to_rep_atom_reconstructed, r_set_to_rep_atom_dp_local) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_reconstruct_onehot_diag_block_global(setup_env): + """Test all diagonal-block reconstruction functions: atom_to_token, token_to_rep_atom, r_set_to_rep_atom.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_reconstruct_onehot_diag_block_global, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_attention.py b/tests/distributed/model/layers/test_dtensor_attention.py new file mode 100644 index 000000000..8c4b097af --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_attention.py @@ -0,0 +1,908 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from functools import partial + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import AttentionPairBiasComm +from boltz.distributed.data.feature.featurizer_utils import get_pair_mask +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.attention import AttentionPairBias, AttentionPairBiasShardwise +from boltz.distributed.model.layers.utils import convert_single_repr_window_batched_query_to_key +from boltz.distributed.model.modules.utils import SDPAWithBiasBackend +from boltz.model.layers.attention import AttentionPairBias as SerialAttentionPairBiasV1 +from boltz.model.layers.attentionv2 import AttentionPairBias as SerialAttentionPairBiasV2 +from boltz.model.modules.encodersv2 import get_indexing_matrix, single_to_keys +from boltz.testing.utils import ( + assert_tensors_identical, + get_to_keys, + init_module_params_uniform, + init_tensors_uniform, + is_a6000_gpu, + pair_global_to_window_batch, + seed_by_rank, + spawn_multiprocessing, +) + + +def assert_attention_pair_bias_for_atom_diffusion( + rank: int, + payload: tuple, +): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + multiplicity, + sdpa_with_bias_backend, + c_s, + c_z, + num_heads, + inf, + state_dict_reference, + s_global_host_fp64, + z_global_host_fp64, + mask_global_host_fp64, + pair_mask_global_host_fp64, + o_global_host_fp64, + d_o_global_host_fp64, + d_s_expected_global_host_fp64, + d_z_expected_global_host_fp64, + grad_params_fp64_expected_global_host, + serial_version, + apply_initial_norm, + compute_pair_bias, + use_model_cache, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(s_global_host_fp64.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {s_global_host_fp64.dtype}" + ) + + seed_by_rank(rank) + + # create module and copy state dict using the appropriate serial version + + if serial_version == "v1": + serial_module = SerialAttentionPairBiasV1( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + initial_norm=apply_initial_norm, + ) + else: + serial_module = SerialAttentionPairBiasV2( + c_s=c_s, + c_z=c_z if compute_pair_bias else None, + num_heads=num_heads, + inf=inf, + compute_pair_bias=compute_pair_bias, + ) + serial_module.load_state_dict(state_dict_reference) + serial_module = serial_module.to(device=manager.device) + + ring_comm = AttentionPairBiasComm( + process_group=manager.group["cp"], + group_layout=manager.layout_subgroups["cp"], + cp_axis_0_group=manager.subgroups["cp"][0], + cp_axis_1_group=manager.subgroups["cp"][1], + ) + module = AttentionPairBias( + attn_pair_bias=serial_module, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + sdpa_with_bias_backend=sdpa_with_bias_backend, + apply_initial_norm=apply_initial_norm, + compute_pair_bias=compute_pair_bias, + use_model_cache=use_model_cache, + ) + module = module.to(device=manager.device, dtype=dtype) + module = module.train() + + # Distribute input tensors + placements_single = [Shard(0), Shard(1), Replicate()] + placements_pair = [Shard(0), Shard(1), Shard(2)] + + s_dtensor = distribute_tensor( + s_global_host_fp64.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_single, + ).requires_grad_(True) + z_dtensor = distribute_tensor( + z_global_host_fp64.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_pair, + ).requires_grad_(True) + mask_dtensor = distribute_tensor( + mask_global_host_fp64.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_single, + ).requires_grad_(False) + pair_mask_dtensor = distribute_tensor( + pair_mask_global_host_fp64.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_pair, + ).requires_grad_(False) + + # Distribute output gradient + d_o_dtensor = distribute_tensor( + d_o_global_host_fp64.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_single, + ).requires_grad_(False) + + # Create copies to verify inputs/upstream adjoint aren't modified + s_dtensor_copy = s_dtensor.detach().clone().requires_grad_(True) + z_dtensor_copy = z_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone().requires_grad_(False) + pair_mask_dtensor_copy = pair_mask_dtensor.detach().clone().requires_grad_(False) + d_o_dtensor_copy = d_o_dtensor.detach().clone().requires_grad_(False) + + # Forward pass + o_dtensor = module( + s=s_dtensor, + z=z_dtensor, + mask=mask_dtensor, + pair_mask=pair_mask_dtensor, + multiplicity=multiplicity, + ) + + # Verify inputs/upstream adjoint weren't modified + assert_tensors_identical(s_dtensor_copy.to_local(), s_dtensor.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(z_dtensor_copy.to_local(), z_dtensor.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical( + mask_dtensor_copy.to_local(), mask_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + pair_mask_dtensor_copy.to_local(), pair_mask_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(d_o_dtensor_copy.to_local(), d_o_dtensor.to_local(), check_grad=False, check_grad_fn=False) + + # Test forward pass results + assert ( + o_dtensor.shape == o_global_host_fp64.shape + ), f"Output shape mismatch: {o_dtensor.shape} != {o_global_host_fp64.shape}" + assert ( + o_dtensor.stride() == o_global_host_fp64.stride() + ), f"Output stride mismatch: {o_dtensor.stride()} != {o_global_host_fp64.stride()}" + atom_pad_mask_bool = mask_dtensor.full_tensor().bool() + atom_pad_mask_bool_expanded = atom_pad_mask_bool + if atom_pad_mask_bool.shape[0] != o_dtensor.shape[0]: + atom_pad_mask_bool_expanded = atom_pad_mask_bool.repeat_interleave(multiplicity, 0) + + o_dtensor_full = o_dtensor.full_tensor() + torch.testing.assert_close( + (o_dtensor_full * atom_pad_mask_bool_expanded[:, :, None]).cpu(), + (o_global_host_fp64 * atom_pad_mask_bool_expanded[:, :, None].cpu()).to(dtype=dtype), + ) + + # Backward pass + o_dtensor.backward(d_o_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(s_dtensor_copy.to_local(), s_dtensor.to_local(), check_grad=False, check_grad_fn=False) + + # Test input gradients + s_inputs_dtensor_grad = s_dtensor.grad.full_tensor() + torch.testing.assert_close( + s_inputs_dtensor_grad[~atom_pad_mask_bool_expanded], + torch.zeros_like(s_inputs_dtensor_grad[~atom_pad_mask_bool_expanded]), + ) + torch.testing.assert_close( + s_inputs_dtensor_grad[atom_pad_mask_bool_expanded].cpu(), + d_s_expected_global_host_fp64[atom_pad_mask_bool_expanded.cpu()].to(dtype=dtype), + ) + + z_inputs_dtensor_grad = z_dtensor.grad.full_tensor() + pair_mask_dtensor_full = pair_mask_dtensor.full_tensor().bool() + + # In broadcasting mode, atom_pad_mask_bool and pair_mask_dtensor_full + # already have the original batch size (no need to undo repeat_interleave) + if mask_dtensor.shape[0] == o_dtensor.shape[0]: + pair_mask_dtensor_full_z = pair_mask_dtensor_full[::multiplicity] + else: + pair_mask_dtensor_full_z = pair_mask_dtensor_full + + # Test z gradient (window batching) + bs, num_atoms = z_dtensor.shape[:2] + z_inputs_dtensor_grad_reshaped = pair_global_to_window_batch( + z_inputs_dtensor_grad, + n_atoms_no_pads=torch.tensor([num_atoms] * bs, device=manager.device), + pair_mask_global=pair_mask_dtensor_full_z[:, :, :, None], + ) + torch.testing.assert_close( + z_inputs_dtensor_grad_reshaped.cpu(), + d_z_expected_global_host_fp64.to(dtype=dtype), + ) + + # Gather weight gradients using named_parameters + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in grad_params_fp64_expected_global_host: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + result_param_grads_dict[name] = param.grad + + # Compare parameter gradients + for name, expected_grad_global_host in grad_params_fp64_expected_global_host.items(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + torch.testing.assert_close(result_grad.full_tensor().cpu(), expected_grad_global_host.to(dtype=dtype)) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + ( + params_test := [ + ((2, (2, 2)), True, "cuda", "ENV"), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}" + for x in params_test + ], +) +@pytest.mark.parametrize( + "config", + [ + (1, False), + (3, False), + (3, True), + ], + ids=lambda x: f"multiplicity:{x[0]}, fix_window_batching:{x[1]}", +) +@pytest.mark.parametrize( + "sdpa_with_bias_backend", + [SDPAWithBiasBackend.REFERENCE, SDPAWithBiasBackend.TORCH_FLEX_ATTN], + ids=lambda x: x.value, +) +@pytest.mark.parametrize( + "version_config", + [ + # (serial_version, apply_initial_norm, compute_pair_bias, use_model_cache) + ("v1", False, True, True), # V1 DTL: initial_norm=False, compute bias, cache z + ("v2", False, False, False), # V2 DTL: no init norm, pre-computed bias, no cache + ], + ids=lambda x: f"serial:{x[0]}, init_norm:{x[1]}, cpb:{x[2]}, cache:{x[3]}", +) +def test_attention_pair_bias_for_atom_diffusion( + setup_env, + config: tuple[int, bool], + sdpa_with_bias_backend: SDPAWithBiasBackend, + version_config: tuple[str, bool, bool, bool], + dtype: torch.dtype = torch.float32, + c_s: int = 16 * 2, + c_z: int = 7, + num_heads: int = 2, + inf: float = 1e6, +): + serial_version, apply_initial_norm, compute_pair_bias, use_model_cache = version_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + multiplicity, fix_window_batching = config + + if sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_FLEX_ATTN and device_type != "cuda": + pytest.skip("torch_flex_attn requires cuda device") + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + if is_a6000_gpu() and world_size > 1: + pytest.skip("skip cuda test because distribute_tensor leads to deadlock on A6000 GPUs") + + if multiplicity > 1 and not fix_window_batching: + pytest.xfail( + "There is a bug in Boltz1's code due to the difference in the order of repeat_interleave and view calls on single representations and pair bias. The context parallel version doesn't suffer from such bug and therefore won't produce consistent results with that of the serial Boltz1" + ) + + seed_by_rank(0) + + # Use the appropriate serial module for the version + # When compute_pair_bias=False (V2 DTL), z last dim is num_heads (pre-computed bias) + # When compute_pair_bias=True, z last dim is c_z (projected through LayerNorm+Linear) + z_last_dim = c_z if compute_pair_bias else num_heads + if serial_version == "v1": + reference_module_fp64 = SerialAttentionPairBiasV1( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + initial_norm=apply_initial_norm, + ) + else: + reference_module_fp64 = SerialAttentionPairBiasV2( + c_s=c_s, + c_z=c_z if compute_pair_bias else None, + num_heads=num_heads, + inf=inf, + compute_pair_bias=compute_pair_bias, + ) + reference_module_fp64 = reference_module_fp64.to(device_type, dtype=torch.float64) + reference_module_fp64 = reference_module_fp64.train() + + val_init_range = 0.25 + init_module_params_uniform(reference_module_fp64, low=-val_init_range, high=val_init_range) + state_dict_reference = {k: v.detach().clone().cpu() for k, v in reference_module_fp64.state_dict().items()} + + # mock inputs and output + bs = 2 * grid_group_sizes["dp"] + num_atoms = 128 * 4 # multiple of 128 for window batching + + s = torch.empty( + size=(bs * multiplicity, num_atoms, c_s), + dtype=torch.float64, + requires_grad=True, + device=device_type, + ) # repeat_interleave happens in AtomAttentionEncoder + z = torch.empty( + size=(bs, num_atoms, num_atoms, z_last_dim), + dtype=torch.float64, + requires_grad=False, # z gradient not tested for window batching + device=device_type, + ) # repeat_interleave happens in AttentionPairBias + mask = torch.ones(bs, num_atoms, dtype=torch.float64, device=device_type) + mask[:, -5:] = 0 # insert padding at the end of the sequence + pair_mask = get_pair_mask(num_atoms).to(dtype=torch.float64, device=device_type) + pair_mask = pair_mask.unsqueeze(0).repeat(bs, 1, 1) + + init_tensors_uniform([s, z], low=-val_init_range, high=val_init_range) + + s_global_host_fp64 = s.detach().clone().cpu() + z_global_host_fp64 = z.detach().clone().cpu() + mask_global_host_fp64 = mask.detach().clone().cpu() + pair_mask_global_host_fp64 = pair_mask.detach().clone().cpu() + + # Run serial forward pass (window batching) + # reshape in AtomAttentionEncoder + to_keys = get_to_keys(s) + + # reshape in AtomTransformer + W, H = 32, 128 + B, N, D = s.shape + NW = N // W + + s_reshaped = s.view((B * NW, W, -1)) + to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1) # noqa: E731 + # In Boltz-2, mask is also transformed by to_keys to match key dimension (H=128) + mask_reshaped = to_keys_new(mask.repeat_interleave(multiplicity, 0).unsqueeze(-1)).squeeze(-1) + + # remap pair representation from square to window shape + z_wb = pair_global_to_window_batch( + z, + n_atoms_no_pads=torch.tensor([num_atoms] * bs, device=device_type), + pair_mask_global=pair_mask[:, :, :, None], + ).requires_grad_(True) + z_reshaped = z_wb + + # reshape in AtomTransformer + if fix_window_batching: + # repeat_interleave -> view + z_reshaped = z_reshaped.repeat_interleave(multiplicity, 0) + z_reshaped = z_reshaped.view((B * NW, W, H, -1)) + else: + # view -> repeat_interleave + z_reshaped = z_reshaped.view((B * NW // multiplicity, W, H, -1)) + + # AttentionPairBias forward pass + # V1: uses to_keys internally (mask must be query-aligned, module transforms it) + # V2: uses pre-computed k_in (mask already key-aligned) + if serial_version == "v1": + # V1 expects query-aligned mask (B*NW, W); to_keys is applied inside forward + mask_query = mask.repeat_interleave(multiplicity, 0).view(B * NW, W) + o_attn_global_fp64 = reference_module_fp64( + s=s_reshaped, + z=z_reshaped, + mask=mask_query, + to_keys=to_keys_new, + multiplicity=1 if fix_window_batching else multiplicity, + ) + else: + k_in_reshaped = to_keys_new(s_reshaped) + o_attn_global_fp64 = reference_module_fp64( + s=s_reshaped, + z=z_reshaped, + mask=mask_reshaped, + k_in=k_in_reshaped, + multiplicity=1 if fix_window_batching else multiplicity, + ) + + # reshape in AtomTransformer + o_global_fp64 = o_attn_global_fp64.view((B, NW * W, D)) + + o_global_host_fp64 = o_global_fp64.detach().clone().cpu() + + # Create upstream gradients, apply masks, and run backward pass + d_o_global_fp64 = torch.empty_like(o_global_fp64) # (B, N, D) + init_tensors_uniform([d_o_global_fp64], low=-val_init_range, high=val_init_range) + d_o_global_fp64 = d_o_global_fp64 * mask[:, :, None].repeat_interleave(multiplicity, 0) + d_o_global_host_fp64 = d_o_global_fp64.detach().clone().cpu() + + o_global_fp64.backward(d_o_global_fp64) + + grad_params_fp64_expected_global_host = { + k: v.grad.detach().clone().cpu() for k, v in reference_module_fp64.named_parameters() if v.grad is not None + } + + # Get reference input gradients + d_s_expected_global_host_fp64 = s.grad.detach().clone().cpu() + d_z_expected_global_host_fp64 = z_wb.grad.detach().clone().cpu() + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + multiplicity, + sdpa_with_bias_backend, + c_s, + c_z, + num_heads, + inf, + state_dict_reference, + s_global_host_fp64, + z_global_host_fp64, + mask_global_host_fp64, + pair_mask_global_host_fp64, + o_global_host_fp64, + d_o_global_host_fp64, + d_s_expected_global_host_fp64, + d_z_expected_global_host_fp64, + grad_params_fp64_expected_global_host, + # Version config (last 4 elements, extracted by parallel function) + serial_version, + apply_initial_norm, + compute_pair_bias, + use_model_cache, + ) + + spawn_multiprocessing(assert_attention_pair_bias_for_atom_diffusion, world_size, payload) + + +def parallel_assert_shardwise_attention_pair_bias( + rank: int, + grid_group_sizes: dict[str, int], + device_type: str, + backend: str, + env_map: dict[str, str], + dtype: torch.dtype, + sdpa_with_bias_backend: SDPAWithBiasBackend, + reference_state_dict: dict, + c_s: int, + c_z: int, + num_heads: int, + inf: float, + s_global_host: torch.Tensor, + z_global_host: torch.Tensor, + mask_global_host: torch.Tensor, + k_in_global_host: torch.Tensor, # V2 API: pre-computed k_in + o_global_host: torch.Tensor, + d_o_global_host: torch.Tensor, + d_s_expected_global_host: torch.Tensor, + d_z_expected_global_host: torch.Tensor, + d_k_in_expected_global_host: torch.Tensor, # V2 API: k_in gradient + grad_params_expected_global_host: dict[str, torch.Tensor], + serial_version: str, + apply_initial_norm: bool, + compute_pair_bias: bool, + use_model_cache: bool, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh_subgroups + + seed_by_rank(0, 42) + + B, K, W, D = s_global_host.shape + + # Setup module using the appropriate serial version + if serial_version == "v1": + serial_apb_module = SerialAttentionPairBiasV1( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + initial_norm=apply_initial_norm, + ) + else: + serial_apb_module = SerialAttentionPairBiasV2( + c_s=c_s, + c_z=c_z if compute_pair_bias else None, + num_heads=num_heads, + inf=inf, + compute_pair_bias=compute_pair_bias, + ) + + serial_apb_module.load_state_dict(reference_state_dict) + serial_apb_module = serial_apb_module.to(device=manager.device) + + module = AttentionPairBiasShardwise( + attn_pair_bias=serial_apb_module, + device_mesh=device_mesh, + sdpa_with_bias_backend=sdpa_with_bias_backend, + apply_initial_norm=apply_initial_norm, + compute_pair_bias=compute_pair_bias, + use_model_cache=use_model_cache, + ) + module = module.to(device=device, dtype=dtype) + module = module.train() + + # NOTE: only need single rep placements because the "pair" is just (K=N//W, W=32, H=128), and K is sharded along CP0 + placements = (Shard(0), Shard(1), Replicate()) + + # Shard the inputs + s_dtensor = distribute_tensor(s_global_host.to(dtype=dtype, device=device), device_mesh, placements).requires_grad_( + True + ) + z_dtensor = distribute_tensor(z_global_host.to(dtype=dtype, device=device), device_mesh, placements).requires_grad_( + True + ) + mask_dtensor = distribute_tensor( + mask_global_host.to(dtype=dtype, device=device), device_mesh, placements + ).requires_grad_(False) + d_o_dtensor = distribute_tensor( + d_o_global_host.to(dtype=dtype, device=device), device_mesh, placements + ).requires_grad_(False) + + # Create copies to verify inputs/upstream adjoint aren't modified + s_dtensor_copy = s_dtensor.detach().clone().requires_grad_(True) + z_dtensor_copy = z_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone().requires_grad_(False) + d_o_dtensor_copy = d_o_dtensor.detach().clone().requires_grad_(False) + + if serial_version == "v1": + # V1: use to_keys, mask is query-aligned (B, K, W) + to_keys_dt = partial(convert_single_repr_window_batched_query_to_key, W=W, H=z_global_host.shape[3]) + o_dtensor = module(s_dtensor, z_dtensor, mask_dtensor, to_keys=to_keys_dt) + k_in_dtensor = None + k_in_dtensor_copy = None + else: + # V2: use pre-computed k_in, mask is key-aligned (B, K, H) + k_in_dtensor = distribute_tensor( + k_in_global_host.to(dtype=dtype, device=device), device_mesh, placements + ).requires_grad_(True) + k_in_dtensor_copy = k_in_dtensor.detach().clone().requires_grad_(True) + o_dtensor = module(s_dtensor, z_dtensor, mask_dtensor, k_in=k_in_dtensor) + + # Verify inputs/upstream adjoint weren't modified + assert_tensors_identical(s_dtensor_copy.to_local(), s_dtensor.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(z_dtensor_copy.to_local(), z_dtensor.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical( + mask_dtensor_copy.to_local(), mask_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + if k_in_dtensor is not None: + assert_tensors_identical( + k_in_dtensor_copy.to_local(), k_in_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(d_o_dtensor_copy.to_local(), d_o_dtensor.to_local(), check_grad=False, check_grad_fn=False) + + # Verify forward pass results + assert ( + o_dtensor.stride() == o_global_host.stride() + ), f"Output stride mismatch: {o_dtensor.stride()} != {o_global_host.stride()}" + + o_dtensor_full = o_dtensor.full_tensor() + + torch.testing.assert_close(o_dtensor_full.cpu(), o_global_host.cpu().to(dtype=dtype)) + + # Run backward pass of distributed shardwise module + o_dtensor.backward(d_o_dtensor) + + # Verify upstream input wasn't modified + assert_tensors_identical(s_dtensor_copy.to_local(), s_dtensor.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(z_dtensor_copy.to_local(), z_dtensor.to_local(), check_grad=False, check_grad_fn=False) + if k_in_dtensor is not None: + assert_tensors_identical( + k_in_dtensor_copy.to_local(), k_in_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Verify input gradients + s_inputs_dtensor_grad = s_dtensor.grad.full_tensor() + torch.testing.assert_close( + s_inputs_dtensor_grad.cpu(), + d_s_expected_global_host.to(dtype=dtype), + ) + + z_inputs_dtensor_grad = z_dtensor.grad.full_tensor() + torch.testing.assert_close( + z_inputs_dtensor_grad.cpu(), + d_z_expected_global_host.to(dtype=dtype), + ) + + # V2 API: Verify k_in gradient (only when k_in is used as separate input) + if k_in_dtensor is not None and d_k_in_expected_global_host is not None: + k_in_inputs_dtensor_grad = k_in_dtensor.grad.full_tensor() + torch.testing.assert_close( + k_in_inputs_dtensor_grad.cpu(), + d_k_in_expected_global_host.to(dtype=dtype), + ) + + # Verify parameter gradients + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in grad_params_expected_global_host: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + result_param_grads_dict[name] = param.grad + + # Compare parameter gradients + for name, expected_grad_global_host in grad_params_expected_global_host.items(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + torch.testing.assert_close(result_grad.full_tensor().cpu(), expected_grad_global_host.to(dtype=dtype)) + + # clean up + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "sdpa_with_bias_backend", + [ + SDPAWithBiasBackend.REFERENCE, + SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION, + SDPAWithBiasBackend.TORCH_FLEX_ATTN, + ], + ids=lambda x: x.value, +) +@pytest.mark.parametrize( + "version_config", + [ + # (serial_version, apply_initial_norm, compute_pair_bias, use_model_cache) + ("v1", False, True, False), # V1 DTL: initial_norm=False, compute bias, no cache + ("v2", False, False, False), # V2 DTL: no init norm, pre-computed bias, no cache + ], + ids=lambda x: f"serial:{x[0]}, init_norm:{x[1]}, cpb:{x[2]}, cache:{x[3]}", +) +def test_shardwise_attention_pair_bias( + setup_env, + sdpa_with_bias_backend: SDPAWithBiasBackend, + version_config: tuple[str, bool, bool, bool], +): + """Test shardwise attention with V1 and V2 serial modules (pre-computed k_in).""" + serial_version, apply_initial_norm, compute_pair_bias, use_model_cache = version_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + seed_by_rank(0) + + dtype: torch.dtype = torch.float32 + c_s: int = 32 # c_s // num_heads must be at least 16 and must be multiple of 4 to test the kernels + c_z: int = 32 + num_heads: int = 2 + inf: float = 1e6 + # When compute_pair_bias=False (V2 DTL), z last dim is num_heads (pre-computed bias) + # When compute_pair_bias=True, z last dim is c_z (projected through LayerNorm+Linear) + z_last_dim = c_z if compute_pair_bias else num_heads + # mock inputs and output + B = 2 * grid_group_sizes["dp"] + W = 32 + H = 128 + num_atoms = 128 * 4 # multiple of 128 for window batching + K = num_atoms // W # number of windows = 16 + D = c_s + + # S is reshaped inside atom transformer already into (B, K, W, D) + s = torch.empty( + size=(B, K, W, D), + dtype=dtype, + requires_grad=True, + device=device_type, + ) + z = torch.empty( + size=(B, K, W, H, z_last_dim), + dtype=dtype, + requires_grad=True, + device=device_type, + ) + + # V2 API: mask will be key-aligned (B, K, H) after transformation + mask_query_aligned = torch.randint(0, 2, s.shape[:-1], dtype=torch.float, device=device_type, requires_grad=False) + d_o = torch.empty( + size=(B, K, W, D), + dtype=dtype, + requires_grad=False, + device=device_type, + ) + + val_init_range = 0.2 + init_tensors_uniform([s, z, d_o], low=-val_init_range, high=val_init_range) + + # mask gradients where the inputs are masked out. We have to do this post-sort because distributed version is done upstream + d_o = d_o * mask_query_aligned.unsqueeze(-1) + + # reference module using the appropriate serial version + if serial_version == "v1": + reference_module = SerialAttentionPairBiasV1( + c_s=c_s, + c_z=c_z, + num_heads=num_heads, + inf=inf, + initial_norm=apply_initial_norm, + ) + else: + reference_module = SerialAttentionPairBiasV2( + c_s=c_s, + c_z=c_z if compute_pair_bias else None, + num_heads=num_heads, + inf=inf, + compute_pair_bias=compute_pair_bias, + ) + reference_module = reference_module.to(device_type, dtype=dtype) + reference_module = reference_module.train() + + init_module_params_uniform(reference_module, low=-val_init_range, high=val_init_range) + + reference_state_dict = {k: v.detach().clone().cpu() for k, v in reference_module.state_dict().items()} + + s_reshaped = s.view((B * K, W, -1)) # Q needs to be in this shape + + # Define single device to_keys function + # This to_keys function assumes that input comes in as shape (B * K, W, D), so translate to (B, N, D) + def _serial_to_keys(s: torch.Tensor, B: int, K: int, W: int, H: int, D: int) -> torch.Tensor: + s = s.view(B, K * W, -1) + indexing_matrix = get_indexing_matrix(K, W, H, s.device).to(dtype=s.dtype) + return single_to_keys(s, indexing_matrix, W, H) + + to_keys_new = partial(_serial_to_keys, B=B, K=K, W=W, H=H, D=D) + + def _to_keys_new_reshape(x: torch.Tensor) -> torch.Tensor: + return to_keys_new(x).view(B * K, H, -1) + + # V2 API: Pre-compute k_in and key-aligned mask + # IMPORTANT: In V2 API, k_in is a separate input (not derived from s in the attention module). + # To match distributed behavior where s and k_in are independent inputs, we: + # 1. Compute k_in values from s (for numerical correctness) + # 2. Detach and create a new leaf tensor for k_in (so s.grad only has Q-path gradients) + k_in_values = _to_keys_new_reshape(s_reshaped.detach()) # Compute values without gradient connection + k_in_reshaped = k_in_values.clone().requires_grad_(True) # Create leaf tensor for gradient tracking + mask_reshaped = _to_keys_new_reshape(mask_query_aligned.unsqueeze(-1)).squeeze(-1) + + z_reshaped = z.view((B * K, W, H, -1)) + + # Run serial forward with the appropriate API + if serial_version == "v1": + # V1: pass to_keys, mask is query-aligned (B*K, W); module transforms internally + mask_query = mask_query_aligned.view(B * K, W) + o_serial = reference_module( + s=s_reshaped, + z=z_reshaped, + mask=mask_query, + to_keys=_to_keys_new_reshape, + ) + else: + # V2: pass pre-computed k_in, mask is key-aligned (B*K, H) + o_serial = reference_module( + s=s_reshaped, + z=z_reshaped, + mask=mask_reshaped, + k_in=k_in_reshaped, + ) + + # clone forward pass output and match distributed module shape + o_global_host = o_serial.detach().clone().cpu().view(B, K, W, D) + + d_o = d_o.view(B * K, W, D) + + o_serial.backward(d_o) + + # parameter gradients. The serial version has S in shape B, N, D + d_s_expected_global_host = s.grad.detach().clone().cpu().view(B, K, W, D) + d_z_expected_global_host = z.grad.detach().clone().cpu() + + if serial_version == "v1": + # V1: k_in is computed internally from to_keys(s), no separate k_in gradient + d_k_in_expected_global_host = None + # V1: mask is query-aligned for distributed test (B, K, W) + mask_global_host = mask_query_aligned.detach().clone().cpu() + k_in_global_host = None + else: + # V2: k_in gradient + d_k_in_expected_global_host = k_in_reshaped.grad.detach().clone().cpu().view(B, K, H, D) + # V2: mask is key-aligned for distributed test (B, K, H) + mask_global_host = mask_reshaped.detach().clone().cpu().view(B, K, H) + # V2: k_in for distributed test (B, K, H, D) + k_in_global_host = k_in_reshaped.detach().clone().cpu().view(B, K, H, D) + + s_global_host = s.detach().clone().cpu() + z_global_host = z.detach().clone().cpu() + d_o_global_host = d_o.detach().clone().cpu().view(B, K, W, D) + + grad_params_expected_global_host = { + k: v.grad.detach().clone().cpu() for k, v in reference_module.named_parameters() if v.grad is not None + } + + spawn_multiprocessing( + parallel_assert_shardwise_attention_pair_bias, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + sdpa_with_bias_backend, + reference_state_dict, + c_s, + c_z, + num_heads, + inf, + s_global_host, + z_global_host, + mask_global_host, + k_in_global_host, # V2 API: pre-computed k_in + o_global_host, + d_o_global_host, + d_s_expected_global_host, + d_z_expected_global_host, + d_k_in_expected_global_host, # V2 API: k_in gradient + grad_params_expected_global_host, + serial_version, + apply_initial_norm, + compute_pair_bias, + use_model_cache, + ) diff --git a/tests/distributed/model/layers/test_dtensor_cat.py b/tests/distributed/model/layers/test_dtensor_cat.py new file mode 100644 index 000000000..062b14cc2 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_cat.py @@ -0,0 +1,370 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.cat_and_chunk import shardwise_cat +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def is_slice_of(x: torch.Tensor, chunks: list[torch.Tensor], dim: int) -> list[bool]: + x_chunks = x.split([c.shape[dim] for c in chunks], dim=dim) + return [chunk.is_set_to(x_chunk) for chunk, x_chunk in zip(chunks, x_chunks)] + + +def compute_global_expectation(shape, num_inputs, dim_to_cat, device): + """Compute global expectation using standard PyTorch operations.""" + inputs = [] + for i in range(num_inputs): + # Create slightly different tensors for each input to make the test more robust + x = torch.rand(*shape, device=device, requires_grad=True) + inputs.append(x) + + # Compute on global tensors using standard cat operation + y = torch.cat(inputs, dim=dim_to_cat) + + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensors + y.backward(dy) + + # Collect input gradients + input_grads = [x.grad.detach().clone() for x in inputs] + + # check for backward pass view semantics + is_grad_view_dy = is_slice_of(dy, [x.grad for x in inputs], dim_to_cat) + + return [x.detach().clone() for x in inputs], y.detach().clone(), input_grads, dy.detach().clone(), is_grad_view_dy + + +def compute_dtensor_native( + inputs_global: list[torch.Tensor], + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim_to_cat: int, +) -> tuple[list[DTensor], DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native inputs + inputs_dtensor = [ + distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + for x_global in inputs_global + ] + + # Forward pass with native DTensor cat operation + y_dtensor_result = torch.cat(inputs_dtensor, dim=dim_to_cat) + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + y_dtensor_result.backward(dy_dtensor) + + inputs_grad_dtensor = [x.grad for x in inputs_dtensor] + + # check for backward pass view semantics + is_grad_view_dy = is_slice_of(dy_dtensor.to_local(), [x.to_local() for x in inputs_grad_dtensor], dim_to_cat) + + return inputs_grad_dtensor, y_dtensor_result, is_grad_view_dy + + +def compute_shardwise_cat_with_validation( + inputs_global: list[torch.Tensor], + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim_to_cat: int, + label_test_case: str, +) -> tuple[DTensor, list[DTensor], DTensor, list[bool]]: + """ + Compute shardwise_cat forward and backward pass with input validation checks. + + Returns: + y_dtensor_result: Forward pass result + inputs_dtensor: Input tensors with computed gradients + dy_dtensor: Distributed upstream gradient + is_grad_view_dy_result: View semantics check results + """ + # Create DTensor inputs + inputs_dtensor = [ + distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + for x_global in inputs_global + ] + inputs_dtensor_copy = [x_dtensor.detach().clone().requires_grad_(True) for x_dtensor in inputs_dtensor] + + # Compute on distributed tensors using shardwise_cat + y_dtensor_result = shardwise_cat(inputs_dtensor, dim_to_cat) + + # verify no change to the fwd inputs + for x_dtensor, x_dtensor_copy in zip(inputs_dtensor, inputs_dtensor_copy): + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # verify no change to the bwd input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # verify input gradient placements are consistent with input placements + for i, inp in enumerate(inputs_dtensor): + assert ( + inp.grad.placements == input_placements + ), f"{label_test_case} inconsistent input {i} gradient placements with input placements" + + # check for backward pass view semantics + is_grad_view_dy_result = is_slice_of(dy_dtensor.to_local(), [x.grad.to_local() for x in inputs_dtensor], dim_to_cat) + + return y_dtensor_result, inputs_dtensor, dy_dtensor, is_grad_view_dy_result + + +def parallel_assert_dtensor_cat( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + shape = (3, 5, grid_group_sizes["dp"] * 2, 5, size_ring * 4, 5, 3, 2) + num_inputs = 3 # Number of tensors to concatenate + # Shard the sequence dimension (dim=1) for input tensors + # this emulates the sharded single representation in the Boltz model + input_placements = (Shard(dim=2), Shard(dim=4), Replicate()) + + # Test valid dimensions (not sharded) + valid_dims_to_cat = [0, 1, 3, 5, 6, 7, -1, -2, -3, -5, -7, -8] + # Test invalid dimensions (sharded): dim 2, 4, -4 (equiv to dim 4), -6 (equiv to dim 2) + invalid_dims_to_cat = [2, 4, -4, -6] + + # Test valid concatenation dimensions + for dim_to_cat in valid_dims_to_cat: + label_test_case = f"for dim {dim_to_cat}\n" + + # Compute global expectations + inputs_global, y_expected_global, inputs_grad_expected_global, dy_global, is_grad_view_dy_global = ( + compute_global_expectation(shape, num_inputs, dim_to_cat, manager.device) + ) + + # use DTensor native op as an alternative reference + # NOTE: DTensor native cat's backward pass doesn't guarantee view semantics + # as dim_to_cat == 7 gives a different view semantic result than dim_to_cat == -1, + # the latter should be the same as the former because ndim = 8 + inputs_grad_dtensor_native, y_dtensor_result_native, _ = compute_dtensor_native( + inputs_global, dy_global, manager.device_mesh_subgroups, input_placements, dim_to_cat + ) + + # Compute shardwise_cat forward and backward with validation + y_dtensor_result, inputs_dtensor, dy_dtensor, is_grad_view_dy_result = compute_shardwise_cat_with_validation( + inputs_global, dy_global, manager.device_mesh_subgroups, input_placements, dim_to_cat, label_test_case + ) + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # check metadata against DTensor native + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # compare forward result with native DTensor op + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare global tensors between shardwise_cat and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + # assert input gradients' metadata and values against DTensor native + for i, (inp, inp_grad_native) in enumerate(zip(inputs_dtensor, inputs_grad_dtensor_native)): + assert ( + inp.grad.placements == inp_grad_native.placements + ), f"{label_test_case} input {i} gradient placements mismatch" + assert inp.grad.shape == inp_grad_native.shape, f"{label_test_case} input {i} gradient shape mismatch" + assert ( + inp.grad.stride() == inp_grad_native.stride() + ), f"{label_test_case} input {i} gradient stride mismatch" + + torch.testing.assert_close( + inp.grad.to_local(), + inp_grad_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input {i} gradient mismatch: {m}", + ) + + torch.testing.assert_close( + inp.grad.full_tensor(), + inp_grad_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input {i} gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare forward result with global expectation + torch.testing.assert_close( + y_result_global, + y_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs global expectation: {m}", + ) + + # create distributed tensors from global results for local shard comparison + for i, input_grad_expected in enumerate(inputs_grad_expected_global): + input_grad_expected_dtensor = distribute_tensor( + input_grad_expected, manager.device_mesh_subgroups, input_placements + ) + + # compare local shards with expected + torch.testing.assert_close( + inputs_dtensor[i].grad.to_local(), + input_grad_expected_dtensor.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input {i} gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + inputs_dtensor[i].grad.full_tensor(), + input_grad_expected, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input {i} gradient vs global expectation: {m}", + ) + + # With explicit shape and stride, DTensor.from_local can't guarantee view semantics + # assert ( + # is_grad_view_dy_result[i] == is_grad_view_dy_global[i] + # ), f"{label_test_case} input {i} backward pass view semantics mismatch" + + # Test invalid concatenation dimensions (should raise ValueError) + for dim_to_cat in invalid_dims_to_cat: + label_test_case = f"for invalid dim {dim_to_cat}\n" + + # Compute global expectations (this should work fine) + inputs_global, _, _, _, _ = compute_global_expectation(shape, num_inputs, dim_to_cat, manager.device) + + # Create DTensor inputs + inputs_dtensor = [] + for x_global in inputs_global: + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad = True + inputs_dtensor.append(x_dtensor) + + # This should raise due to sharded dimension + with pytest.raises( + NotImplementedError, match=f"Concatenation along dimension {dim_to_cat} shared by device_mesh axis" + ): + shardwise_cat(inputs_dtensor, dim_to_cat) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_dtensor_cat(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_cat, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_chunk.py b/tests/distributed/model/layers/test_dtensor_chunk.py new file mode 100755 index 000000000..37e590718 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_chunk.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.cat_and_chunk import shardwise_chunk +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def is_slice_of(x: torch.Tensor, chunks: list[torch.Tensor], dim: int) -> list[bool]: + """Check if each chunk is a view/slice of the original tensor x along the specified dimension.""" + x_chunks = x.split([c.shape[dim] for c in chunks], dim=dim) + return [chunk.is_set_to(x_chunk) for chunk, x_chunk in zip(chunks, x_chunks)] + + +def compute_global_expectation(shape, num_chunks, dim_to_chunk, device): + """Compute global expectation using standard PyTorch operations.""" + # Create input tensor + x = torch.rand(*shape, device=device, requires_grad=True) + + # Compute on global tensor using standard chunk operation + y_chunks = x.chunk(num_chunks, dim=dim_to_chunk) + + # Check for forward pass view semantics (are chunks views of the input?) + is_chunk_view_x = is_slice_of(x, list(y_chunks), dim_to_chunk) + + # Create gradient for backward pass - each chunk gets different gradient + dy_chunks = [torch.rand_like(chunk) for chunk in y_chunks] + + # Backward pass on global tensors + # Need to use torch.autograd.backward with multiple tensors + torch.autograd.backward(y_chunks, dy_chunks) + + # Collect input gradient + input_grad = x.grad.detach().clone() + + return ( + x.detach().clone(), + [chunk.detach().clone() for chunk in y_chunks], + input_grad, + dy_chunks, + is_chunk_view_x, + ) + + +def compute_dtensor_native( + input_global: torch.Tensor, + dy_chunks_global: list[torch.Tensor], + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim_to_chunk: int, + num_chunks: int, +) -> tuple[list[DTensor], torch.Tensor, bool]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native input + input_dtensor = distribute_tensor(input_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + + # Forward pass with native DTensor chunk operation + y_chunks_dtensor_result = torch.chunk(input_dtensor, num_chunks, dim=dim_to_chunk) + + # Convert tuple to list for easier handling + y_chunks_dtensor_result = list(y_chunks_dtensor_result) + + # Check for forward pass view semantics (are chunks views of the input?) + is_chunk_view_x = is_slice_of( + input_dtensor.to_local(), [chunk.to_local() for chunk in y_chunks_dtensor_result], dim_to_chunk + ) + + # Backward pass with native DTensor op + dy_chunks_dtensor = [ + distribute_tensor(dy_chunk_global.detach().clone(), device_mesh, chunk.placements) + for dy_chunk_global, chunk in zip(dy_chunks_global, y_chunks_dtensor_result) + ] + + torch.autograd.backward(y_chunks_dtensor_result, dy_chunks_dtensor) + + input_grad_dtensor = input_dtensor.grad + + return y_chunks_dtensor_result, input_grad_dtensor, is_chunk_view_x + + +def compute_shardwise_chunk_with_validation( + input_global: torch.Tensor, + dy_chunks_global: list[torch.Tensor], + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim_to_chunk: int, + num_chunks: int, + label_test_case: str, +) -> tuple[list[DTensor], DTensor, list[DTensor], list[bool]]: + """ + Compute shardwise_chunk forward and backward pass with input validation checks. + + Returns: + y_chunks_result: Forward pass result (list of chunks) + input_dtensor: Input tensor with computed gradient + dy_chunks_dtensor: Distributed upstream gradients + is_grad_view_dy_result: View semantics check results + """ + # Create DTensor input + input_dtensor = distribute_tensor(input_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + input_dtensor_copy = input_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_chunk + y_chunks_result = shardwise_chunk(input_dtensor, num_chunks, dim_to_chunk) + + # Convert tuple to list for easier handling + y_chunks_result = list(y_chunks_result) + + # Check for forward pass view semantics (are chunks views of the input?) + is_chunk_view_x_result = is_slice_of( + input_dtensor.to_local(), [chunk.to_local() for chunk in y_chunks_result], dim_to_chunk + ) + + # Verify no change to the fwd input + assert_tensors_identical( + input_dtensor.to_local(), input_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False + ) + + # Distribute the upstream adjoints for backward pass + dy_chunks_dtensor = [ + distribute_tensor(dy_chunk_global.detach().clone(), device_mesh, chunk.placements) + for dy_chunk_global, chunk in zip(dy_chunks_global, y_chunks_result) + ] + + # Perform backward pass + dy_chunks_dtensor_copy = [dy_chunk.detach().clone() for dy_chunk in dy_chunks_dtensor] + torch.autograd.backward(y_chunks_result, dy_chunks_dtensor) + + # Verify no change to the bwd inputs + for dy_chunk, dy_chunk_copy in zip(dy_chunks_dtensor, dy_chunks_dtensor_copy): + assert_tensors_identical(dy_chunk.to_local(), dy_chunk_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Verify input gradient placements are consistent with input placements + assert ( + input_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + return y_chunks_result, input_dtensor, dy_chunks_dtensor, is_chunk_view_x_result + + +def parallel_assert_dtensor_chunk( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + shape = (3, 5, grid_group_sizes["dp"] * 2, 5, size_ring * 4, 5, 3, 12) # Last dim divisible by 2,3,4 + num_chunks = 3 # Number of chunks to split into + # Shard the sequence dimension (dim=2) and token dimension (dim=4) for input tensor + # this emulates the sharded representation in the Boltz model + input_placements = (Shard(2), Shard(4), Replicate()) + + # Test valid dimensions (not sharded) + valid_dims_to_chunk = [0, 1, 3, 5, 6, 7, -1, -2, -3, -5, -7, -8] + # Test invalid dimensions (sharded): dim 2, 4, -4 (equiv to dim 4), -6 (equiv to dim 2) + invalid_dims_to_chunk = [2, 4, -4, -6] + + # Test valid chunking dimensions + for dim_to_chunk in valid_dims_to_chunk: + label_test_case = f"for dim {dim_to_chunk}\n" + + # Compute global expectations + input_global, y_chunks_expected_global, input_grad_expected_global, dy_chunks_global, is_chunk_view_x_global = ( + compute_global_expectation(shape, num_chunks, dim_to_chunk, manager.device) + ) + + # Use DTensor native op as an alternative reference + y_chunks_dtensor_native, input_grad_dtensor_native, is_chunk_view_x_native = compute_dtensor_native( + input_global, dy_chunks_global, manager.device_mesh_subgroups, input_placements, dim_to_chunk, num_chunks + ) + + # Compute shardwise_chunk forward and backward with validation + y_chunks_result, input_dtensor, dy_chunks_dtensor, is_chunk_view_x_result = ( + compute_shardwise_chunk_with_validation( + input_global, + dy_chunks_global, + manager.device_mesh_subgroups, + input_placements, + dim_to_chunk, + num_chunks, + label_test_case, + ) + ) + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # Check metadata against DTensor native - number of chunks + assert len(y_chunks_result) == len(y_chunks_dtensor_native), f"{label_test_case} number of chunks mismatch" + + # Check each chunk against DTensor native + for i, (chunk_result, chunk_native) in enumerate(zip(y_chunks_result, y_chunks_dtensor_native)): + assert ( + chunk_result.placements == chunk_native.placements + ), f"{label_test_case} chunk {i} placements mismatch" + assert chunk_result.shape == chunk_native.shape, f"{label_test_case} chunk {i} shape mismatch" + # In some of the test cases, the DTensor native result will retain the same stride as the input global, + # which I believe is actually wrong because upon the result.full_tensor(), the would-be padding won't be + # materialized and it shouldn't be materialized + # assert chunk_result.stride() == chunk_native.stride(), f"{label_test_case} chunk {i} stride mismatch" + + # Compare forward result with native DTensor op + torch.testing.assert_close( + chunk_result.to_local(), + chunk_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} chunk {i}: {m}", + ) + + # Compare global tensors between shardwise_chunk and native DTensor results + chunk_result_global = chunk_result.full_tensor() + chunk_result_global_native = chunk_native.full_tensor() + + torch.testing.assert_close( + chunk_result_global, + chunk_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} chunk {i} vs native: {m}", + ) + + assert is_chunk_view_x_result == is_chunk_view_x_native, ( + f"{label_test_case} forward pass view semantics mismatch with DTensor native: " + f"Expected: {is_chunk_view_x_native}, " + f"Actual: {is_chunk_view_x_result}" + ) + + # Assert input gradient metadata and values against DTensor native + assert ( + input_dtensor.grad.placements == input_grad_dtensor_native.placements + ), f"{label_test_case} input gradient placements mismatch" + assert ( + input_dtensor.grad.shape == input_grad_dtensor_native.shape + ), f"{label_test_case} input gradient shape mismatch" + assert ( + input_dtensor.grad.stride() == input_grad_dtensor_native.stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + input_dtensor.grad.to_local(), + input_grad_dtensor_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + torch.testing.assert_close( + input_dtensor.grad.full_tensor(), + input_grad_dtensor_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + + # Compare results with expected local shards + for i, (chunk_result, chunk_expected) in enumerate(zip(y_chunks_result, y_chunks_expected_global)): + chunk_dtensor_expected = distribute_tensor( + chunk_expected, manager.device_mesh_subgroups, chunk_result.placements + ) + + torch.testing.assert_close( + chunk_result.to_local(), + chunk_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} chunk {i}: {m}", + ) + + # Compare forward result with global expectation + torch.testing.assert_close( + chunk_result.full_tensor(), + chunk_expected, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} chunk {i} vs global expectation: {m}", + ) + + # Create distributed tensor from global result for local shard comparison + input_grad_expected_dtensor = distribute_tensor( + input_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # Compare local shards with expected + torch.testing.assert_close( + input_dtensor.grad.to_local(), + input_grad_expected_dtensor.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + input_dtensor.grad.full_tensor(), + input_grad_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + assert ( + is_chunk_view_x_result == is_chunk_view_x_global + ), f"{label_test_case} forward pass view semantics mismatch" + + # Test invalid chunking dimensions (should raise NotImplementedError) + for dim_to_chunk in invalid_dims_to_chunk: + label_test_case = f"for invalid dim {dim_to_chunk}\n" + + # Compute global expectations (this should work fine) + input_global, _, _, _, _ = compute_global_expectation(shape, num_chunks, dim_to_chunk, manager.device) + + # Create DTensor input + input_dtensor = distribute_tensor(input_global, manager.device_mesh_subgroups, input_placements) + input_dtensor.requires_grad = True + + # This should raise due to sharded dimension + with pytest.raises( + NotImplementedError, match=f"Chunking along dimension {dim_to_chunk} shared by device_mesh axis" + ): + shardwise_chunk(input_dtensor, num_chunks, dim_to_chunk) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_dtensor_chunk(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_chunk, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_clip.py b/tests/distributed/model/layers/test_dtensor_clip.py new file mode 100644 index 000000000..39459f795 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_clip.py @@ -0,0 +1,219 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import unittest +from typing import Optional + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.clip import clip +from boltz.testing.utils import assert_tensors_identical, spawn_multiprocessing + + +def serial_clip(tensor: torch.Tensor, min_val: Optional[float] = None, max_val: Optional[float] = None) -> torch.Tensor: + """Serial implementation of clip operation for comparison.""" + return torch.clip(tensor, min=min_val, max=max_val) + + +def parallel_assert_clip( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + min_val, + max_val, + input_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_expected_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_dtensor = distribute_tensor( + input_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_expected_dtensor = distribute_tensor( + d_input_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copy to verify input isn't modified + input_dtensor_copy = input_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = clip(input_dtensor, min_val=min_val, max_val=max_val) + + # Verify input wasn't modified + assert_tensors_identical( + input_dtensor_copy.to_local(), input_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert ( + output_dtensor_result.placements == placements + ), f"placements: {placements}, output_dtensor_result.placements: {output_dtensor_result.placements}" + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output shape mismatch: {output_dtensor_result.shape} != {output_expected_dtensor.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output stride mismatch: {output_dtensor_result.stride()} != {output_expected_dtensor.stride()}" + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient + assert ( + input_dtensor.grad.placements == placements + ), f"placements: {placements}, input_dtensor.grad.placements: {input_dtensor.grad.placements}" + assert ( + input_dtensor.grad.shape == d_input_expected_dtensor.shape + ), f"Input gradient shape mismatch: {input_dtensor.grad.shape} != {d_input_expected_dtensor.shape}" + assert ( + input_dtensor.grad.stride() == d_input_expected_dtensor.stride() + ), f"Input gradient stride mismatch: {input_dtensor.grad.stride()} != {d_input_expected_dtensor.stride()}" + torch.testing.assert_close(input_dtensor.grad.to_local(), d_input_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_global_result_host = input_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_input_global_result_host, d_input_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "clip_params", + [ + (0.0, None), # min_val only (ReLU-like) + (None, 5.0), # max_val only + (-2.0, 1.0), # both min and max + ], + ids=["min_only", "max_only", "min_max"], +) +def test_clip_parallel(setup_env, placements, clip_params): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + min_val, max_val = clip_params + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 # Number of tokens + D = 32 # Hidden dimension + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensor with values that will test clipping + input_global = torch.empty((B, N, N, D), requires_grad=True, device=device_type) + with torch.no_grad(): + input_global.uniform_(-10, 10, generator=rng) # Wide range to ensure clipping occurs + + # Run serial forward pass + input_global_host = input_global.detach().clone().cpu() + output_expected_global = serial_clip(input_global, min_val=min_val, max_val=max_val) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_clip, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + min_val, + max_val, + input_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_global.grad.detach().clone().cpu(), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/model/layers/test_dtensor_dropout.py b/tests/distributed/model/layers/test_dtensor_dropout.py new file mode 100644 index 000000000..1ffa12634 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_dropout.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.dropout import apply_dropout_mask_msa_or_pair +from boltz.model.layers.dropout import get_dropout_mask +from boltz.testing.utils import ( + assert_tensors_identical, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_dropout( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dropout, + training, + columnwise, + samples_dropout_global_host, + seed, + src_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_src_expected_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Input tensors have shape (B, S, N, D) - sharded on dims 0, 1, and 2 (B, S, N) + placements_input = (Shard(0), Shard(1), Shard(2)) + + # Distribute input tensors + src_dtensor = distribute_tensor( + src_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements_input + ).requires_grad_(True) + + if columnwise: + placements_samples_dropout = (Shard(0), Replicate(), Shard(2)) + else: + placements_samples_dropout = (Shard(0), Shard(1), Replicate()) + + samples_dropout_dtensor = distribute_tensor( + samples_dropout_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_samples_dropout, + ) + + # Distribute expected outputs + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + d_src_expected_dtensor = distribute_tensor( + d_src_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + src_dtensor_copy = src_dtensor.detach().clone().requires_grad_(True) + samples_dropout_dtensor_copy = samples_dropout_dtensor.detach().clone() + + torch.cuda.manual_seed_all(seed) + + # Forward pass + output_dtensor_result = apply_dropout_mask_msa_or_pair( + src_dtensor, dropout, training, columnwise, samples_dropout_dtensor + ) + + # just so this runs but unfortunately we have no way to verify the results due to the lack of way to generate + # consistent RNG sequences between the serial and distributed versions + src_dtensor_no_samples_dropout = src_dtensor.detach().clone().requires_grad_(True) + output_dtensor_result_no_samples_dropout = apply_dropout_mask_msa_or_pair( + src_dtensor_no_samples_dropout, dropout, training, columnwise + ) + assert output_dtensor_result_no_samples_dropout.shape == output_dtensor_result.shape + + # Verify inputs weren't modified + assert_tensors_identical(src_dtensor_copy.to_local(), src_dtensor.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical( + samples_dropout_dtensor_copy.to_local(), + samples_dropout_dtensor.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + # Test forward pass results + assert ( + output_dtensor_result.placements == placements_input + ), f"placements_input: {placements_input}, output_dtensor_result.placements: {output_dtensor_result.placements}" + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output shape mismatch: {output_dtensor_result.shape} != {output_expected_dtensor.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output stride mismatch: {output_dtensor_result.stride()} != {output_expected_dtensor.stride()}" + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # again, no way to verify the results due to the lack of way to generate + # consistent RNG sequences between the serial and distributed versions + output_dtensor_result_no_samples_dropout.backward(d_output_expected_dtensor) + assert src_dtensor_no_samples_dropout.grad.shape == src_dtensor.grad.shape + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + assert ( + src_dtensor.grad.placements == placements_input + ), f"placements_input: {placements_input}, src_dtensor.grad.placements: {src_dtensor.grad.placements}" + assert ( + src_dtensor.grad.shape == d_src_expected_dtensor.shape + ), f"Input gradient shape mismatch: {src_dtensor.grad.shape} != {d_src_expected_dtensor.shape}" + assert ( + src_dtensor.grad.stride() == d_src_expected_dtensor.stride() + ), f"Input gradient stride mismatch: {src_dtensor.grad.stride()} != {d_src_expected_dtensor.stride()}" + torch.testing.assert_close(src_dtensor.grad.to_local(), d_src_expected_dtensor.to_local()) + + # Verify that samples_dropout_dtensor has no gradients (should be None) + assert samples_dropout_dtensor.grad is None, "Reference dropout samples_dropout should not have gradients" + + # Test full tensor gathering - verify distributed results match serial results + src_global_result_host = src_dtensor.full_tensor().cpu() + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_src_global_result_host = src_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(src_global_result_host, src_global_host) + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_src_global_result_host, d_src_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize("dropout", [0.0, 0.5], ids=lambda x: f"dropout={x}") +@pytest.mark.parametrize("training", [True, False], ids=lambda x: f"training={x}") +@pytest.mark.parametrize("columnwise", [True, False], ids=lambda x: f"columnwise={x}") +def test_apply_dropout_mask_msa_or_pair_parallel(setup_env, dropout, training, columnwise): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + S = size_ring * 4 # Sequence length + N = size_ring * 4 # Number of tokens/positions + D = 64 # Feature dimension + + seed = 42 + seed_by_rank(0, seed) + + # Create input tensors with proper 4D shape (B, S, N, D) + src_global = torch.rand((B, S, N, D), requires_grad=True, device=device_type) + # requires_grad but it won't be set + z_global = torch.rand((B, S, N, D), requires_grad=True, device=device_type) + + # Run serial reference computation + seed_by_rank(0, seed) # Reset seed for consistent dropout mask + src_global_copy = src_global.detach().clone().requires_grad_(True) + z_global_copy = z_global.detach().clone().requires_grad_(True) + + mask = get_dropout_mask(dropout, z_global_copy, training, columnwise) + output_expected_global = src_global_copy * mask.to(src_global_copy.dtype) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Verify that z_global_copy has no gradients in the serial version + assert z_global_copy.grad is None, "Reference tensor z should not have gradients in serial version" + + # Get expected gradients + d_src_expected_global_host = src_global_copy.grad.detach().clone().cpu() + + # Prepare input data for parallel test + src_global_host = src_global.detach().clone().cpu() + + # emulate the dropout mask creation to testing the distributed version + seed_by_rank(0, seed) + if columnwise: + samples_dropout_global_host = torch.rand((B, 1, N, 1), device=device_type, dtype=src_global.dtype).cpu() + else: + samples_dropout_global_host = torch.rand((B, S, 1, 1), device=device_type, dtype=src_global.dtype).cpu() + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_dropout, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dropout, + training, + columnwise, + samples_dropout_global_host, + seed, + src_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_src_expected_global_host, + ) diff --git a/tests/distributed/model/layers/test_dtensor_elementwise_op.py b/tests/distributed/model/layers/test_dtensor_elementwise_op.py new file mode 100644 index 000000000..4524395a0 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_elementwise_op.py @@ -0,0 +1,709 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.elementwise_op import ( + ElementwiseOp, + elementwise_op, + scalar_tensor_op, + single_tensor_op, +) +from boltz.testing.utils import assert_tensors_identical, spawn_multiprocessing + + +def serial_elementwise_op(a: torch.Tensor, b: torch.Tensor, op: ElementwiseOp) -> torch.Tensor: + """Serial implementation of elementwise operation for comparison.""" + if op == ElementwiseOp.SUM: + return a + b + elif op == ElementwiseOp.SUB: + return a - b + elif op == ElementwiseOp.PROD: + return a * b + elif op == ElementwiseOp.DIV: + return a / b + elif op == ElementwiseOp.EQUAL: + return a & b + elif op == ElementwiseOp.BITAND: + return a & b + else: + raise ValueError(f"Unsupported operation: {op}") + + +def parallel_assert_elementwise_op( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + op, + input_a_global_host, + input_b_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_a_expected_global_host, + d_input_b_expected_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensors + input_a_dtensor = distribute_tensor( + input_a_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + input_b_dtensor = distribute_tensor( + input_b_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + if op != ElementwiseOp.EQUAL and op != ElementwiseOp.BITAND: + input_a_dtensor.requires_grad_(True) + input_b_dtensor.requires_grad_(True) + + # Distribute expected outputs + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + if op == ElementwiseOp.EQUAL or op == ElementwiseOp.BITAND: + d_output_expected_dtensor = None + d_input_a_expected_dtensor = None + d_input_b_expected_dtensor = None + else: + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + d_input_a_expected_dtensor = distribute_tensor( + d_input_a_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_b_expected_dtensor = distribute_tensor( + d_input_b_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + input_a_dtensor_copy = input_a_dtensor.detach().clone() + input_b_dtensor_copy = input_b_dtensor.detach().clone() + if op != ElementwiseOp.EQUAL and op != ElementwiseOp.BITAND: + input_a_dtensor_copy.requires_grad_(True) + input_b_dtensor_copy.requires_grad_(True) + + # Forward pass + output_dtensor_result = elementwise_op(input_a_dtensor, input_b_dtensor, op) + + # Verify inputs weren't modified + assert_tensors_identical( + input_a_dtensor_copy.to_local(), input_a_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_b_dtensor_copy.to_local(), input_b_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Test shape and stride consistency + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output shape mismatch: {output_dtensor_result.shape} != {output_expected_dtensor.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output stride mismatch: {output_dtensor_result.stride()} != {output_expected_dtensor.stride()}" + + # Backward pass + if op == ElementwiseOp.EQUAL or op == ElementwiseOp.BITAND: + with pytest.raises(RuntimeError, match="tensors does not require grad and does not have a grad_fn"): + output_dtensor_result.backward(d_output_expected_dtensor) + else: + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + if op == ElementwiseOp.EQUAL or op == ElementwiseOp.BITAND: + assert ( + d_output_expected_dtensor is None + and d_input_a_expected_dtensor is None + and d_input_b_expected_dtensor is None + ) + assert input_a_dtensor.grad is None and input_b_dtensor.grad is None + else: + torch.testing.assert_close(input_a_dtensor.grad.to_local(), d_input_a_expected_dtensor.to_local()) + torch.testing.assert_close(input_b_dtensor.grad.to_local(), d_input_b_expected_dtensor.to_local()) + + # Test gradient shape and stride consistency + assert ( + input_a_dtensor.grad.shape == d_input_a_expected_dtensor.shape + ), f"Input A gradient shape mismatch: {input_a_dtensor.grad.shape} != {d_input_a_expected_dtensor.shape}" + assert ( + input_a_dtensor.grad.stride() == d_input_a_expected_dtensor.stride() + ), f"Input A gradient stride mismatch: {input_a_dtensor.grad.stride()} != {d_input_a_expected_dtensor.stride()}" + assert ( + input_b_dtensor.grad.shape == d_input_b_expected_dtensor.shape + ), f"Input B gradient shape mismatch: {input_b_dtensor.grad.shape} != {d_input_b_expected_dtensor.shape}" + assert ( + input_b_dtensor.grad.stride() == d_input_b_expected_dtensor.stride() + ), f"Input B gradient stride mismatch: {input_b_dtensor.grad.stride()} != {d_input_b_expected_dtensor.stride()}" + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + if op != ElementwiseOp.EQUAL and op != ElementwiseOp.BITAND: + d_input_a_global_result_host = input_a_dtensor.grad.full_tensor().cpu() + d_input_b_global_result_host = input_b_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close(d_input_a_global_result_host, d_input_a_expected_global_host) + torch.testing.assert_close(d_input_b_global_result_host, d_input_b_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "op", + [ + ElementwiseOp.SUM, + ElementwiseOp.SUB, + ElementwiseOp.PROD, + ElementwiseOp.DIV, + ElementwiseOp.EQUAL, + ElementwiseOp.BITAND, + ], + ids=["sum", "sub", "prod", "div", "equal", "bitand"], +) +def test_elementwise_op_parallel(setup_env, placements, op): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 # Number of tokens + D = 32 # Hidden dimension + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensors with proper shapes + if op == ElementwiseOp.EQUAL or op == ElementwiseOp.BITAND: + input_a_global = torch.randint(0, 2, (B, N, N, D), requires_grad=False, device=device_type, generator=rng) + input_b_global = torch.randint(0, 2, (B, N, N, D), requires_grad=False, device=device_type, generator=rng) + else: + input_a_global = torch.empty((B, N, N, D), requires_grad=True, device=device_type) + input_b_global = torch.empty((B, N, N, D), requires_grad=True, device=device_type) + with torch.no_grad(): + input_a_global.uniform_(-1000, 1000, generator=rng) + input_b_global.uniform_(-1000, 1000, generator=rng) + + # Run serial forward pass + input_a_global_host = input_a_global.detach().clone().cpu() + input_b_global_host = input_b_global.detach().clone().cpu() + output_expected_global = serial_elementwise_op(input_a_global, input_b_global, op) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + if op == ElementwiseOp.EQUAL or op == ElementwiseOp.BITAND: + with pytest.raises(RuntimeError, match="tensors does not require grad and does not have a grad_fn"): + output_expected_global.backward(torch.empty_like(output_expected_global)) + + d_output_expected_global = None + d_output_expected_global_host = None + else: + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_elementwise_op, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + op, + input_a_global_host, + input_b_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_a_global.grad.detach().clone().cpu() + if op != ElementwiseOp.EQUAL and op != ElementwiseOp.BITAND + else None, + input_b_global.grad.detach().clone().cpu() + if op != ElementwiseOp.EQUAL and op != ElementwiseOp.BITAND + else None, + ) + + +def serial_single_tensor_op(x: torch.Tensor, op: ElementwiseOp) -> torch.Tensor: + """Serial implementation of single tensor operation for comparison.""" + if op == ElementwiseOp.COS: + return torch.cos(x) + elif op == ElementwiseOp.RELU: + return torch.relu(x) + elif op == ElementwiseOp.ROUND: + return torch.round(x) + elif op == ElementwiseOp.LOG: + return torch.log(x) + elif op == ElementwiseOp.EXP: + return torch.exp(x) + elif op == ElementwiseOp.ABS: + return torch.abs(x) + elif op == ElementwiseOp.SIGMOID: + return torch.sigmoid(x) + else: + raise ValueError(f"Unsupported single tensor operation: {op}") + + +def parallel_assert_single_tensor_op( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + op, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copy to verify input isn't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = single_tensor_op(input_x_dtensor, op) + + # Verify input wasn't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + if op == ElementwiseOp.ROUND: + with pytest.raises(RuntimeError, match="tensors does not require grad and does not have a grad_fn"): + output_dtensor_result.backward(d_output_expected_dtensor) + else: + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient + if op == ElementwiseOp.ROUND: + assert input_x_dtensor.grad is None + else: + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "op", + [ + ElementwiseOp.COS, + ElementwiseOp.RELU, + ElementwiseOp.ROUND, + ElementwiseOp.LOG, + ElementwiseOp.EXP, + ElementwiseOp.ABS, + ElementwiseOp.SIGMOID, + ], + ids=["cos", "relu", "round", "log", "exp", "abs", "sigmoid"], +) +def test_single_tensor_op_parallel(setup_env, placements, op): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 # Number of tokens + D = 32 # Hidden dimension + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensor with proper shape + input_x_global = torch.empty((B, N, N, D), requires_grad=True, device=device_type) + with torch.no_grad(): + if op == ElementwiseOp.LOG: + # For LOG operation, use positive values only + # On certain GPU architectures, we need to limit the range of the input + # to limit the numerical errors + input_x_global.uniform_(0.1, 50, generator=rng) + elif op == ElementwiseOp.EXP: + # For EXP operation, use a reasonable range to avoid overflow + input_x_global.uniform_(-10, 10, generator=rng) + else: + # For other operations, test a wide range of values including negative values and zeros for ReLU testing + input_x_global.uniform_(-1000, 1000, generator=rng) + # Ensure we have some zeros for ReLU boundary testing + input_x_global.view(-1)[::17] = 0.0 # Set every 17th element to zero + + # Run serial forward pass + input_x_global_host = input_x_global.detach().clone().cpu() + output_expected_global = serial_single_tensor_op(input_x_global, op) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_single_tensor_op, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + op, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_x_global.grad.detach().clone().cpu(), + ) + + +def serial_scalar_tensor_op(scalar: float, tensor: torch.Tensor, op: ElementwiseOp) -> torch.Tensor: + """Serial implementation of scalar-tensor operation for comparison.""" + if op == ElementwiseOp.SUM: + return scalar + tensor + elif op == ElementwiseOp.SUB: + return scalar - tensor + elif op == ElementwiseOp.PROD: + return scalar * tensor + elif op == ElementwiseOp.DIV: + return scalar / tensor + elif op == ElementwiseOp.GT: + return scalar > tensor + elif op == ElementwiseOp.LT: + return scalar < tensor + elif op == ElementwiseOp.EQUAL: + return scalar == tensor + elif op == ElementwiseOp.POW: + return torch.pow(tensor, scalar) + elif op == ElementwiseOp.MAX: + return torch.clamp(tensor, min=scalar) + else: + raise ValueError(f"Unsupported scalar-tensor operation: {op}") + + +def parallel_assert_scalar_tensor_op( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + op, + scalar, + input_tensor_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_tensor_expected_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_tensor_dtensor = distribute_tensor( + input_tensor_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + input_tensor_dtensor.requires_grad_(True) + + # Distribute expected outputs + if op == ElementwiseOp.GT or op == ElementwiseOp.LT or op == ElementwiseOp.EQUAL: + d_output_expected_dtensor = None + d_input_tensor_expected_dtensor = None + else: + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + d_input_tensor_expected_dtensor = distribute_tensor( + d_input_tensor_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copy to verify input isn't modified + input_tensor_dtensor_copy = input_tensor_dtensor.detach().clone().requires_grad_(input_tensor_dtensor.requires_grad) + + # Forward pass + output_dtensor_result = scalar_tensor_op(scalar, input_tensor_dtensor, op) + + # Verify input wasn't modified + assert_tensors_identical( + input_tensor_dtensor_copy.to_local(), input_tensor_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + if op == ElementwiseOp.GT or op == ElementwiseOp.LT or op == ElementwiseOp.EQUAL: + with pytest.raises(RuntimeError, match="tensors does not require grad and does not have a grad_fn"): + output_dtensor_result.backward(d_output_expected_dtensor) + else: + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient for tensor + torch.testing.assert_close(input_tensor_dtensor.grad.to_local(), d_input_tensor_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + + if op != ElementwiseOp.GT and op != ElementwiseOp.LT and op != ElementwiseOp.EQUAL: + d_input_tensor_global_result_host = input_tensor_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close(d_input_tensor_global_result_host, d_input_tensor_expected_global_host) + else: + # For GT, LT, and EQUAL operations, verify that no gradients were computed + assert input_tensor_dtensor.grad is None + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "op", + [ + ElementwiseOp.SUM, + ElementwiseOp.SUB, + ElementwiseOp.PROD, + ElementwiseOp.DIV, + ElementwiseOp.GT, + ElementwiseOp.LT, + ElementwiseOp.EQUAL, + ElementwiseOp.POW, + ElementwiseOp.MAX, + ], + ids=["sum", "sub", "prod", "div", "gt", "lt", "equal", "pow", "max"], +) +def test_scalar_tensor_op_parallel(setup_env, placements, op): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 # Number of tokens + D = 32 # Hidden dimension + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensor + scalar_value = torch.rand(1, device=device_type, generator=rng).item() * 20.0 - 10.0 # [-10, 10] + input_tensor_global = torch.empty( + (B, N, N, D), + requires_grad=(op != ElementwiseOp.GT and op != ElementwiseOp.LT and op != ElementwiseOp.EQUAL), + device=device_type, + ) + with torch.no_grad(): + input_tensor_global.uniform_(-10, 10, generator=rng) + if op == ElementwiseOp.POW: + input_tensor_global.abs_() # fractional power of negative number is complex + + # Run serial forward pass + input_tensor_global_host = input_tensor_global.detach().clone().cpu() + output_expected_global = serial_scalar_tensor_op(scalar_value, input_tensor_global, op) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + if op == ElementwiseOp.GT or op == ElementwiseOp.LT or op == ElementwiseOp.EQUAL: + d_output_expected_global = None + d_output_expected_global_host = None + else: + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_scalar_tensor_op, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + op, + scalar_value, + input_tensor_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_tensor_global.grad.detach().clone().cpu() + if op != ElementwiseOp.GT and op != ElementwiseOp.LT and op != ElementwiseOp.EQUAL + else None, + ) diff --git a/tests/distributed/model/layers/test_dtensor_embedding.py b/tests/distributed/model/layers/test_dtensor_embedding.py new file mode 100644 index 000000000..7156d3f2b --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_embedding.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor parity tests for EmbeddingParamsReplicated. + +Tests the distributed embedding wrapper against serial nn.Embedding. + +Verification checks: + V4a: multi-proc FW input tensor values unchanged by FW + V4b: multi-proc FW input tensor values unchanged after BW + V5: multi-proc BW input tensor values unchanged by BW + V8: multi-proc FW output tensor values close-to single-proc + V10: multi-proc weight gradient values close-to single-proc + V10b: replicated weight gradients identical across all CP ranks +""" + +import pytest +import torch +import torch.nn as nn +from torch.distributed.tensor import DTensor, Shard, distribute_tensor +from torch.testing import assert_close + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.embedding import EmbeddingParamsReplicated +from boltz.testing.utils import ( + assert_all_identical, + assert_tensors_identical, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + +SEED = 42 + + +def _assert_unchanged(actual, expected, *, serial=False): + """Shorthand for assert_tensors_identical with standard immutability kwargs.""" + assert_tensors_identical( + actual, + expected, + check_stride=True, + check_grad=False, + check_grad_fn=False, + check_storage_pointer=False, + check_storage_offset=serial, + ) + + +def _worker_embedding_parity( + rank: int, + input_on_host: torch.Tensor, + output_ref_on_host: torch.Tensor, + grad_output_on_host: torch.Tensor, + weight_grad_ref_on_host: torch.Tensor, + state_dict: dict, + num_embeddings: int, + embedding_dim: int, + padding_idx: int | None, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: dict[str, str] | None = None, +): + """Worker: compare distributed EmbeddingParamsReplicated against serial nn.Embedding.""" + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + dm = DistributedManager() + + try: + serial_emb = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + serial_emb.load_state_dict(state_dict) + serial_emb = serial_emb.to(dm.device) + + dist_emb = EmbeddingParamsReplicated(serial_emb, dm.device_mesh_subgroups) + + # Pair-like placements for 3D input [B, N, N] + pair_placements = (Shard(0), Shard(1), Shard(2)) + x_dt = distribute_tensor(input_on_host.to(dm.device), dm.device_mesh_subgroups, pair_placements) + + # V4a setup + x_dt_clone = x_dt.detach().clone() + + out = dist_emb(x_dt) + + # V4a: FW input unchanged + _assert_unchanged(x_dt.to_local(), x_dt_clone.to_local()) + + # V8: forward parity + assert_close( + out.full_tensor(), + output_ref_on_host.to(dm.device), + atol=0, + rtol=0, + msg=lambda m: f"Rank {rank} forward output mismatch\n{m}", + ) + + # Backward + grad_out_dt = distribute_tensor(grad_output_on_host.to(dm.device), dm.device_mesh_subgroups, pair_placements) + out_clone = out.detach().clone().requires_grad_(out.requires_grad) + grad_out_dt_clone = grad_out_dt.detach().clone().requires_grad_(grad_out_dt.requires_grad) + + out.backward(grad_out_dt) + + # V4b: FW input unchanged after backward + _assert_unchanged(x_dt.to_local(), x_dt_clone.to_local()) + + # V5: BW inputs (values only) unchanged + assert_close(out.to_local(), out_clone.to_local(), atol=0, rtol=0) + assert_close(grad_out_dt.to_local(), grad_out_dt_clone.to_local(), atol=0, rtol=0) + + # V10: weight gradient parity + assert dist_emb.weight.grad is not None, "Weight gradient is None" + weight_grad = dist_emb.weight.grad + assert isinstance(weight_grad, DTensor), f"Weight grad should be DTensor, got {type(weight_grad)}" + weight_grad_full = weight_grad.full_tensor() + assert_close( + weight_grad_full, + weight_grad_ref_on_host.to(dm.device), + atol=1e-5, + rtol=1e-5, + msg=lambda m: f"Rank {rank} weight grad mismatch\n{m}", + ) + + # V10b: replicated weight gradients identical across all CP ranks + assert_all_identical(weight_grad_full.detach(), dm.group["cp"]) + + # Non-vacuous: weight gradient must be non-zero + assert weight_grad_full.abs().sum() > 0, "Weight gradient is all-zero" + + # Non-vacuous: verify sharding is active (local < global on at least one dim) + assert any( + out.to_local().shape[d] < out.shape[d] for d in range(out.ndim) + ), f"Sharding not active: local shape {out.to_local().shape} == global shape {out.shape}" + + finally: + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, padding_idx", + [ + # CPU dp=1 cp=(2,2): basic parity + (((1, (2, 2)), True, "cpu", "ENV"), None), + # CPU dp=2 cp=(2,2): DP + CP with padding_idx + (((2, (2, 2)), True, "cpu", "ENV"), 0), + # CUDA dp=2 cp=(1,1): DP-only, 2-GPU + (((2, (1, 1)), True, "cuda", "ENV"), None), + # CUDA dp=1 cp=(2,2): actual CP, 4-GPU + (((1, (2, 2)), True, "cuda", "ENV"), 0), + ], + indirect=("setup_env",), + ids=["cpu-dp1-cp2x2", "cpu-dp2-cp2x2-pad0", "cuda-dp2-cp1x1", "cuda-dp1-cp2x2-pad0"], +) +def test_dtensor_embedding_forward_backward(setup_env, padding_idx: int | None): + """EmbeddingParamsReplicated: distributed output and weight gradient match serial nn.Embedding.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size(device_type=device_type, world_size=world_size) + + num_embeddings = 32 + embedding_dim = 16 + B = 2 * grid_group_sizes["dp"] + N = 8 * grid_group_sizes["cp"][0] + + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(SEED) + + serial_emb = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.uniform_(serial_emb.weight, -0.5, 0.5) + state_dict = {k: v.cpu().clone() for k, v in serial_emb.state_dict().items()} + + # Pair-like integer input [B, N, N] with diverse indices + x = torch.randint(0, num_embeddings, (B, N, N)) + if padding_idx is not None: + x[:, :2, :2] = padding_idx + + out_ref = serial_emb(x) + grad_out = torch.randn_like(out_ref) + out_ref.backward(grad_out) + + weight_grad_ref = serial_emb.weight.grad.detach().cpu().clone() + + spawn_multiprocessing( + _worker_embedding_parity, + world_size, + x.cpu(), + out_ref.detach().cpu(), + grad_out.detach().cpu(), + weight_grad_ref, + state_dict, + num_embeddings, + embedding_dim, + padding_idx, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/model/layers/test_dtensor_flatten.py b/tests/distributed/model/layers/test_dtensor_flatten.py new file mode 100644 index 000000000..f96611b4c --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_flatten.py @@ -0,0 +1,643 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_flatten, shardwise_flatten_sharded +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def compute_global_expectation(shape, start_dim, end_dim, device): + """Compute global expectation using standard PyTorch operations.""" + # Create tensor for flattening + x = torch.rand(*shape, device=device, requires_grad=True) + + # Compute on global tensor using standard flatten operation + y = torch.flatten(x, start_dim=start_dim, end_dim=end_dim) + + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensor + y.backward(dy) + + # Collect input gradient + input_grad = x.grad.detach().clone() + + return x.detach().clone(), y.detach().clone(), input_grad, dy.detach().clone() + + +def compute_dtensor_native( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + start_dim: int, + end_dim: int, +) -> tuple[DTensor, DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + + # Forward pass with native DTensor flatten operation + y_dtensor_result = torch.flatten(x_dtensor, start_dim=start_dim, end_dim=end_dim) + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + y_dtensor_result.backward(dy_dtensor) + + x_grad_dtensor = x_dtensor.grad + + return x_grad_dtensor, y_dtensor_result + + +def compute_shardwise_flatten_with_validation( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + start_dim: int, + end_dim: int, + label_test_case: str, +) -> tuple[DTensor, DTensor, DTensor]: + """ + Compute shardwise_flatten forward and backward pass with input validation checks. + + Returns: + y_dtensor_result: Forward pass result + x_dtensor: Input tensor with computed gradient + dy_dtensor: Distributed upstream gradient + """ + # Create DTensor input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_flatten + y_dtensor_result = shardwise_flatten(x_dtensor, start_dim=start_dim, end_dim=end_dim) + + # verify no change to the fwd input + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # verify no change to the bwd input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # verify input gradient placements are consistent with input placements + assert ( + x_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + return y_dtensor_result, x_dtensor, dy_dtensor + + +def parallel_assert_dtensor_flatten( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters - 8D tensor for comprehensive testing + shape = (3, 5, grid_group_sizes["dp"] * 2, 5, size_ring * 4, 5, 3, 2) + # Shard the sequence dimension (dim=2) and another dimension (dim=4) for input tensor + # this emulates the sharded single representation in the Boltz model + input_placements = (Shard(dim=2), Shard(dim=4), Replicate()) + + # Test valid flattening dimensions (not sharded) + # Sharded dims are 2 and 4, so valid ranges must not include these dimensions + valid_flatten_params = [ + (0, 1), # flatten dims 0,1 (no sharded dims) + (1, 1), # flatten just dim 1 (no sharded dims) + (3, 3), # flatten just dim 3 (no sharded dims) + (5, 7), # flatten dims 5,6,7 (no sharded dims) + (-2, -1), # flatten dims 6,7 (negative indexing) + (-3, -1), # flatten dims 5,6,7 (negative indexing) + (-8, -7), # flatten dims 0,1 + ] + + # Test invalid flattening dimensions (include sharded dims 2 and/or 4) + invalid_flatten_params = [ + (0, 2), # flatten dims 0,1,2 (includes sharded dim=2) + (0, 3), # flatten dims 0,1,2,3 (includes sharded dim=2) + (2, 3), # flatten dims 2,3 (includes sharded dim=2) + (1, 4), # flatten dims 1,2,3,4 (includes both sharded dims 2,4) + (3, 5), # flatten dims 3,4,5 (includes sharded dim=4) + (4, 5), # flatten dims 4,5 (includes sharded dim=4) + (0, 4), # flatten dims 0,1,2,3,4 (includes both sharded dims) + (-6, -4), # flatten dims 2,3,4 (includes both sharded dims 2,4) + (-4, -1), # flatten dims 4,5,6,7 (includes sharded dim=4) + ] + + # Test valid flattening dimensions + for start_dim, end_dim in valid_flatten_params: + label_test_case = f"for start_dim={start_dim}, end_dim={end_dim}\n" + + # Compute global expectations + x_global, y_expected_global, x_grad_expected_global, dy_global = compute_global_expectation( + shape, start_dim, end_dim, manager.device + ) + + # use DTensor native op as an alternative reference + x_grad_dtensor_native, y_dtensor_result_native = compute_dtensor_native( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, start_dim, end_dim + ) + + # Compute shardwise_flatten forward and backward with validation + y_dtensor_result, x_dtensor, dy_dtensor = compute_shardwise_flatten_with_validation( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, start_dim, end_dim, label_test_case + ) + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # check metadata against DTensor native + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # compare forward result with native DTensor op + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare global tensors between shardwise_flatten and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + # assert input gradient metadata and values against DTensor native + assert ( + x_dtensor.grad.placements == x_grad_dtensor_native.placements + ), f"{label_test_case} input gradient placements mismatch" + assert x_dtensor.grad.shape == x_grad_dtensor_native.shape, f"{label_test_case} input gradient shape mismatch" + assert ( + x_dtensor.grad.stride() == x_grad_dtensor_native.stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_dtensor_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_dtensor_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare forward result with global expectation + torch.testing.assert_close( + y_result_global, + y_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs global expectation: {m}", + ) + + # create distributed tensor from global result for local shard comparison + x_grad_expected_dtensor = distribute_tensor( + x_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # compare local shard with expected + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_expected_dtensor.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + # Test invalid flattening dimensions (should raise NotImplementedError) + for start_dim, end_dim in invalid_flatten_params: + label_test_case = f"for invalid start_dim={start_dim}, end_dim={end_dim}\n" + + # Compute global expectations (this should work fine) + x_global, _, _, _ = compute_global_expectation(shape, start_dim, end_dim, manager.device) + + # Create DTensor input + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad = True + + # This should raise due to sharded dimension in flatten range + with pytest.raises(NotImplementedError, match="Flattening dimension .* sharded by device_mesh axis"): + shardwise_flatten(x_dtensor, start_dim=start_dim, end_dim=end_dim) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_flatten(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_flatten, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def compute_shardwise_flatten_sharded_with_validation( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + start_dim: int, + end_dim: int, + label_test_case: str, +) -> tuple[DTensor, DTensor, DTensor]: + """ + Compute shardwise_flatten_sharded forward and backward pass with input validation checks. + + This function is for testing flatten operations that involve the sharded dimension. + + Returns: + y_dtensor_result: Forward pass result + x_dtensor: Input tensor with computed gradient + dy_dtensor: Distributed upstream gradient + """ + # Create DTensor input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_flatten_sharded + y_dtensor_result = shardwise_flatten_sharded(x_dtensor, start_dim=start_dim, end_dim=end_dim) + + # verify no change to the fwd input + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # verify no change to the bwd input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # verify input gradient placements are consistent with input placements + assert ( + x_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + return y_dtensor_result, x_dtensor, dy_dtensor + + +def parallel_assert_dtensor_flatten_sharded( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """ + Test shardwise_flatten_sharded which flattens dimensions starting from a sharded axis. + + Unlike shardwise_flatten, this function is designed to flatten dimensions that include + the sharded dimension. DTensor native op doesn't support this, so we only compare + against the global serial version as reference. + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_dp = grid_group_sizes["dp"] + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters - 6D tensor + # Shape designed so that sharded dims can be evenly divided + # dim=1 is sharded by dp, dim=3 is sharded by ring (first dim of cp 2D mesh) + shape = (2, size_dp * 4, 3, size_ring * 6, 5, 4) + + # Test cases: flatten starting from a sharded dimension + # Input is sharded on dim=1 (by dp) and dim=3 (by ring) + # shardwise_flatten_sharded requires start_dim to be the sharded dimension + + # Test Case 1: Shard on dim=1, flatten dims 1,2 + # After flatten: shape becomes (2, size_dp*4*3, size_ring*6, 5, 4) + # The flattened dim (1) remains sharded by dp + test_cases_dim1 = [ + # (input_placements, start_dim, end_dim, description) + ((Shard(dim=1), Replicate(), Replicate()), 1, 2, "flatten dims 1,2 sharded on dim=1"), + ((Shard(dim=1), Replicate(), Replicate()), 1, 3, "flatten dims 1,2,3 sharded on dim=1"), + ((Shard(dim=1), Replicate(), Replicate()), 1, -1, "flatten dims 1 to end sharded on dim=1"), + ] + + # Test Case 2: Shard on dim=3, flatten dims 3,4 + test_cases_dim3 = [ + ((Replicate(), Shard(dim=3), Replicate()), 3, 4, "flatten dims 3,4 sharded on dim=3"), + ((Replicate(), Shard(dim=3), Replicate()), 3, 5, "flatten dims 3,4,5 sharded on dim=3"), + ((Replicate(), Shard(dim=3), Replicate()), 3, -1, "flatten dims 3 to end sharded on dim=3"), + ] + + # Test Case 3: Both dim=1 and dim=3 are sharded + # dim=1 sharded by dp (mesh dim 0), dim=3 sharded by ring (mesh dim 1) + # Only test flattening from dim=3 onwards, so dim=1's shard placement is unaffected. + test_cases_both_sharded = [ + # Flatten from dim=3, dim=1 stays sharded at the same position + ((Shard(dim=1), Shard(dim=3), Replicate()), 3, 4, "flatten dims 3,4 with both dim=1,3 sharded"), + ((Shard(dim=1), Shard(dim=3), Replicate()), 3, 5, "flatten dims 3,4,5 with both dim=1,3 sharded"), + ((Shard(dim=1), Shard(dim=3), Replicate()), 3, -1, "flatten dims 3 to end with both dim=1,3 sharded"), + ] + + # Test Case 4: Both dim=1 and dim=3 are sharded, flatten from dim=1. + # Flattening at a lower dim shifts the higher shard's placement index. + # E.g., flatten dims 1,2 removes 1 dim → Shard(dim=3) must become Shard(dim=2). + # Format: (input_placements, start_dim, end_dim, expected_output_placements, description) + test_cases_placement_shift = [ + # flatten dims 1,2: removes 1 dim → Shard(3) shifts to Shard(2) + ( + (Shard(dim=1), Shard(dim=3), Replicate()), + 1, + 2, + (Shard(dim=1), Shard(dim=2), Replicate()), + "flatten dims 1,2 shifting Shard(3)->Shard(2)", + ), + ] + + all_test_cases = [ + (pl, sd, ed, pl, desc) for pl, sd, ed, desc in test_cases_dim1 + test_cases_dim3 + test_cases_both_sharded + ] + test_cases_placement_shift + + for input_placements, start_dim, end_dim, expected_output_placements, description in all_test_cases: + label_test_case = f"{description} (start_dim={start_dim}, end_dim={end_dim})\n" + + # Compute global expectations using standard PyTorch operations + x_global, y_expected_global, x_grad_expected_global, dy_global = compute_global_expectation( + shape, start_dim, end_dim, manager.device + ) + + # NOTE: DTensor native op doesn't support flattening involving sharded dimensions, + # so we skip DTensor native comparison and only use global serial version as reference. + + # Compute shardwise_flatten_sharded forward and backward with validation + y_dtensor_result, x_dtensor, dy_dtensor = compute_shardwise_flatten_sharded_with_validation( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, start_dim, end_dim, label_test_case + ) + + # =================================================================== + # Check output shape and placements + # =================================================================== + # Verify output shape matches expected global shape + assert ( + y_dtensor_result.shape == y_expected_global.shape + ), f"{label_test_case} output shape mismatch: got {y_dtensor_result.shape}, expected {y_expected_global.shape}" + + # Verify output placements: Shard dims beyond end_dim shift down by + # (end_dim - start_dim) because those intermediate dims are merged. + assert y_dtensor_result.placements == expected_output_placements, ( + f"{label_test_case} output placements mismatch: got {y_dtensor_result.placements}, " + f"expected {expected_output_placements}" + ) + + # =================================================================== + # Check against global serial expectation + # =================================================================== + # Distribute expected output to compare local shards + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare forward result local shards + assert_tensors_identical( + y_dtensor_result.to_local().detach(), + y_dtensor_expected.to_local().detach(), + ) + + # Compare forward result global tensor + y_result_global = y_dtensor_result.full_tensor() + assert_tensors_identical( + y_result_global.detach(), + y_expected_global.detach(), + ) + + # =================================================================== + # Check backward pass against global serial expectation + # =================================================================== + # Verify input gradient shape + assert x_dtensor.grad.shape == x_grad_expected_global.shape, ( + f"{label_test_case} input gradient shape mismatch: got {x_dtensor.grad.shape}, " + f"expected {x_grad_expected_global.shape}" + ) + + # Distribute expected input gradient for local shard comparison + x_grad_expected_dtensor = distribute_tensor( + x_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # Compare input gradient local shards + assert_tensors_identical( + x_dtensor.grad.to_local().detach(), + x_grad_expected_dtensor.to_local().detach(), + ) + + # Compare input gradient global tensor + assert_tensors_identical( + x_dtensor.grad.full_tensor().detach(), + x_grad_expected_global.detach(), + ) + + # =================================================================== + # Test invalid cases that should raise ValueError + # =================================================================== + + # Create a test tensor for invalid cases + x_global_invalid = torch.rand(*shape, device=manager.device) + + # Invalid Case 1: start_dim is NOT sharded + # Input is sharded on dim=1, but we try to flatten starting from dim=0 (not sharded) + invalid_not_sharded_cases = [ + # (input_placements, start_dim, end_dim, expected_error_pattern) + ((Shard(dim=1), Replicate(), Replicate()), 0, 1, "input is not sharded along start_dim"), + ((Shard(dim=1), Replicate(), Replicate()), 2, 3, "input is not sharded along start_dim"), + ((Replicate(), Shard(dim=3), Replicate()), 0, 2, "input is not sharded along start_dim"), + ((Replicate(), Shard(dim=3), Replicate()), 4, 5, "input is not sharded along start_dim"), + ] + + for input_placements, start_dim, end_dim, error_pattern in invalid_not_sharded_cases: + x_dtensor = distribute_tensor(x_global_invalid.clone(), manager.device_mesh_subgroups, input_placements) + with pytest.raises(ValueError, match=error_pattern): + shardwise_flatten_sharded(x_dtensor, start_dim=start_dim, end_dim=end_dim) + + # Invalid Case 2: start_dim > end_dim + invalid_dim_order_cases = [ + ((Shard(dim=1), Replicate(), Replicate()), 3, 1, "must be <="), + ((Replicate(), Shard(dim=3), Replicate()), 5, 3, "must be <="), + ] + + for input_placements, start_dim, end_dim, error_pattern in invalid_dim_order_cases: + x_dtensor = distribute_tensor(x_global_invalid.clone(), manager.device_mesh_subgroups, input_placements) + with pytest.raises(ValueError, match=error_pattern): + shardwise_flatten_sharded(x_dtensor, start_dim=start_dim, end_dim=end_dim) + + # Invalid Case 3: Dimension out of range + invalid_out_of_range_cases = [ + ((Shard(dim=1), Replicate(), Replicate()), 10, 11, "out of range"), + ((Shard(dim=1), Replicate(), Replicate()), 1, 10, "out of range"), + ] + + for input_placements, start_dim, end_dim, error_pattern in invalid_out_of_range_cases: + x_dtensor = distribute_tensor(x_global_invalid.clone(), manager.device_mesh_subgroups, input_placements) + with pytest.raises(ValueError, match=error_pattern): + shardwise_flatten_sharded(x_dtensor, start_dim=start_dim, end_dim=end_dim) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_flatten_sharded(setup_env): + """Test shardwise_flatten_sharded for flattening dimensions starting from a sharded axis.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_flatten_sharded, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_gather.py b/tests/distributed/model/layers/test_dtensor_gather.py new file mode 100644 index 000000000..9a855fa0e --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_gather.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +import torch.nn.functional as F +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.gather import distributed_gather +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def parallel_assert_gather(rank, grid_group_sizes, device_type, backend, env_map, dtype): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + + seed_by_rank(0, 42) + + shape_extras = [ + (None,), + (4, None), + (None, 3), + (2, None, 3), + ] + + for shape_extra in shape_extras: + if sum(1 for x in shape_extra if x is None) != 1: + raise ValueError(f"There can be one and only one 'None' element in shape_extra but got {shape_extra}") + axis = shape_extra.index(None) + + for N_per_rank, K_per_rank, W in [(8, 6, 2), (16, 4, 3)]: + for device_mesh in [manager.device_mesh_subgroups, manager.device_mesh]: + # placements: shard along axis (on mesh dim 0) + mesh_ndim = device_mesh.ndim + size_group_shard_axis = None + if axis >= 1 and mesh_ndim >= 2: + # shard leading tensor axes as well as 'axis + placements = (Shard(0), Shard(axis)) + (Replicate(),) * (mesh_ndim - 2) + size_group_shard_axis = device_mesh.size(1) + elif mesh_ndim >= 2: + # axis == 0 only shard 'axis' + placements = (Replicate(),) * (mesh_ndim - 1) + (Shard(axis),) + size_group_shard_axis = device_mesh.size(-1) + else: + # not enough mesh dim to shard other than 'axis' + placements = (Shard(axis),) + (Replicate(),) * (mesh_ndim - 1) + size_group_shard_axis = device_mesh.size(0) + + if size_group_shard_axis is None: + raise ValueError(f"size_group_shard_axis is None for axis {axis} and device_mesh {device_mesh}") + + N = N_per_rank * size_group_shard_axis + K = K_per_rank * size_group_shard_axis + + x_shape = shape_extra[:axis] + (N,) + shape_extra[axis + 1 :] + idx_shape = shape_extra[:axis] + (K, W) + + # Test both without mask (None) and with random mask + for use_mask in [False, True]: + label = f"x_shape:{x_shape}, idx_shape:{idx_shape}, axis:{axis}" + if use_mask: + label += " (masked)" + idx_mask_global = torch.rand(idx_shape, device=device) > 0.5 + else: + idx_mask_global = None + + x_global = torch.randn(x_shape, dtype=dtype, device=device, requires_grad=True) + idx_global = torch.randint(0, N, idx_shape, device=device) + + # Reference using one-hot + einsum + idx_onehot = F.one_hot(idx_global, num_classes=N).to(dtype=dtype) + # Zero out one-hot at invalid positions when mask is provided + if idx_mask_global is not None: + idx_onehot = idx_onehot * idx_mask_global.unsqueeze(-1).to(dtype=dtype) + + x_flat = x_global.reshape( + *x_global.shape[:axis], x_global.shape[axis], x_global.shape[axis + 1 :].numel() + ) + out_ref_flat = torch.einsum("...nd,...kwn->...kwd", x_flat, idx_onehot) + out_ref = out_ref_flat.reshape( + *x_global.shape[:axis], idx_shape[-2], idx_shape[-1], *x_global.shape[axis + 1 :] + ) + grad_out = torch.randn_like(out_ref) + + out_ref.backward(grad_out) + grad_x_ref = x_global.grad + + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, placements).requires_grad_( + True + ) + idx_dtensor = distribute_tensor(idx_global, device_mesh, placements) + idx_mask_dtensor = ( + distribute_tensor(idx_mask_global, device_mesh, placements) + if idx_mask_global is not None + else None + ) + + out_dtensor = distributed_gather( + x_dtensor, idx_dtensor, axis=axis, are_ids_contiguous=True, idx_mask=idx_mask_dtensor + ) + + out_local = out_dtensor.full_tensor().requires_grad_(True) + assert_tensors_identical( + out_local, + out_ref, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} fwd output mismatch:\n {m}", + ) + + grad_out_dtensor = distribute_tensor( + grad_out.detach().clone(), out_dtensor.device_mesh, out_dtensor.placements + ) + + out_dtensor.backward(grad_out_dtensor) + + grad_x_local = x_dtensor.grad.full_tensor() + assert_tensors_identical( + grad_x_local, + grad_x_ref, + check_grad=False, + check_grad_fn=False, + rtol=1e-10, + atol=1e-10, + msg=lambda m: f"{label} bwd input gradient mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_distributed_gather(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_gather, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + torch.float64, + ) diff --git a/tests/distributed/model/layers/test_dtensor_layernorm_nocastbf16.py b/tests/distributed/model/layers/test_dtensor_layernorm_nocastbf16.py new file mode 100644 index 000000000..0b777f9d1 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_layernorm_nocastbf16.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for LayerNormParamsReplicatedNoAutoCastBF16 distributed layer. + +This module tests the distributed LayerNormParamsReplicatedNoAutoCastBF16 +(from boltz.distributed.model.layers.triangular_attention) against the serial +LayerNorm (from boltz.model.layers.triangular_attention.primitives). +""" + +from collections import OrderedDict +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.triangular_attention import LayerNormParamsReplicatedNoAutoCastBF16 +from boltz.model.layers.triangular_attention.primitives import LayerNorm as SerialLayerNormNoAutoCastBF16 +from boltz.testing.utils import ( + assert_no_percentile_upshift, + assert_tensors_identical, + init_module_params_uniform, + init_tensors_uniform, + spawn_multiprocessing, +) + + +def _compute_references( + B: int, seq_len: int, c_in: int, min_val_init: float, max_val_init: float, device: str = "cpu" +) -> dict: + """Compute FP64 and FP32 serial references for layernorm. + + Computes on ``device`` for numerical consistency with the DTensor test + (CUDA tests should pass ``device="cuda"``). Results are moved to CPU + for safe transfer across the mp.spawn boundary. + """ + input_x_fp64 = torch.empty((B, seq_len, c_in), dtype=torch.float64, device=device, requires_grad=True) + init_tensors_uniform([input_x_fp64], low=min_val_init, high=max_val_init) + + ref_module = SerialLayerNormNoAutoCastBF16(c_in) + ref_module = ref_module.to(dtype=torch.float64, device=device).train() + init_module_params_uniform(ref_module, low=min_val_init, high=max_val_init) + state_dict_fp64 = {k: v.detach().clone().cpu() for k, v in ref_module.state_dict().items()} + + output_fp64 = ref_module(input_x_fp64) + d_output_fp64 = torch.rand_like(output_fp64) + output_fp64.backward(d_output_fp64) + + refs = { + "input_x": input_x_fp64.detach().clone().cpu(), + "output": output_fp64.detach().clone().cpu(), + "d_output": d_output_fp64.detach().clone().cpu(), + "d_input_x": input_x_fp64.grad.detach().clone().cpu(), + "grad_params": {name: p.grad.detach().clone().cpu() for name, p in ref_module.named_parameters()}, + "state_dict": state_dict_fp64, + } + + # FP32 serial reference for three-way error histogram comparison + input_x_fp32 = refs["input_x"].to(dtype=torch.float32, device=device).requires_grad_(True) + ref_module_fp32 = SerialLayerNormNoAutoCastBF16(c_in) + ref_module_fp32.load_state_dict(state_dict_fp64) + ref_module_fp32 = ref_module_fp32.to(dtype=torch.float32, device=device).train() + + output_fp32 = ref_module_fp32(input_x_fp32) + output_fp32.backward(refs["d_output"].to(dtype=torch.float32, device=device)) + + refs["output_fp32"] = output_fp32.detach().clone().cpu() + refs["d_input_x_fp32"] = input_x_fp32.grad.detach().clone().cpu() + refs["grad_params_fp32"] = {name: p.grad.detach().clone().cpu() for name, p in ref_module_fp32.named_parameters()} + + return refs + + +def parallel_assert_dtensor_layernorm_nocastbf16( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_per_rank: Optional[Dict[str, str]], + dtype: torch.dtype, + c_in: int, + refs_cpu: dict, + check_error_hist: bool, +): + """Test distributed LayerNormParamsReplicatedNoAutoCastBF16 in a parallel environment. + + Reference data is computed on CPU in the main process and passed to workers + via mp.spawn. Workers move tensors to their device before use. + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Move CPU references to worker's device + refs = { + "input_x": refs_cpu["input_x"].to(device=manager.device), + "output": refs_cpu["output"].to(device=manager.device), + "d_output": refs_cpu["d_output"].to(device=manager.device), + "d_input_x": refs_cpu["d_input_x"].to(device=manager.device), + "grad_params": {k: v.to(device=manager.device) for k, v in refs_cpu["grad_params"].items()}, + "state_dict": {k: v.to(device=manager.device) for k, v in refs_cpu["state_dict"].items()}, + "output_fp32": refs_cpu["output_fp32"].to(device=manager.device), + "d_input_x_fp32": refs_cpu["d_input_x_fp32"].to(device=manager.device), + "grad_params_fp32": {k: v.to(device=manager.device) for k, v in refs_cpu["grad_params_fp32"].items()}, + } + + if torch.finfo(dtype).resolution < torch.finfo(refs["output"].dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {refs['output'].dtype}" + ) + + # --- build distributed module --- + module_serial = SerialLayerNormNoAutoCastBF16(c_in) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + module_serial.load_state_dict(refs["state_dict"]) + + module = LayerNormParamsReplicatedNoAutoCastBF16(module_serial, manager.device_mesh_subgroups) + module = module.to(device=manager.device).train() + + # Input: (B, seq_len, c_in) — shard on batch (dim 0) and seq (dim 1), replicate feature dim + placements_input = (Shard(0), Shard(1), Replicate()) + + # distribute_tensor with default src_data_rank=0 keeps NCCL streams + # synchronised with the broadcasts that happened inside the module ctor. + input_x_dtensor = distribute_tensor( + refs["input_x"].to(dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ).requires_grad_(True) + + d_output_dtensor = distribute_tensor( + refs["d_output"].to(dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + + # Forward and backward pass + output_dtensor_result = module(input_x_dtensor) + output_dtensor_result.backward(d_output_dtensor) + + # Use full_tensor() to get fully reduced parameter gradients. This works + # correctly whether the backward returns Replicate placements (eager + # reduction, no-op) or Partial placements (triggers all-reduce). + grad_params_global: dict[str, torch.Tensor] = {} + for name, param in module.named_parameters(): + if param.grad is not None: + grad_params_global[name] = param.grad.full_tensor() + + if check_error_hist: + output_expected_dtensor = distribute_tensor( + refs["output"].to(dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + d_input_x_expected_dtensor = distribute_tensor( + refs["d_input_x"].to(dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + output_fp32_dtensor = distribute_tensor( + refs["output_fp32"], + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + d_input_x_fp32_dtensor = distribute_tensor( + refs["d_input_x_fp32"], + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + + # --- All collectives done. Only local assertions below. --- + + assert_no_percentile_upshift( + output_dtensor_result.to_local(), + output_expected_dtensor.to_local(), + output_fp32_dtensor.to_local(), + names_input=("output_cp_fp32", "output_serial_fp64", "output_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_x_dtensor.grad.to_local(), + d_input_x_expected_dtensor.to_local(), + d_input_x_fp32_dtensor.to_local(), + names_input=("d_input_x_cp_fp32", "d_input_x_serial_fp64", "d_input_x_serial_fp32"), + ) + + # Parameter gradient tolerance: the distributed sum splits the + # reduction across DP ranks (different FP32 accumulation order). + # For c_in=8 the bias gradient has only 8 elements; each is a sum + # of B*seq_len/dp values (~256). The FP32 ULP at magnitude 256 + # is 256·2^{-23} ≈ 3e-5, so a 1-ULP accumulation-order shift + # produces ~1.5e-5 absolute error — above the default atol=1e-5. + # Use 5e-5 (≈2 ULP at magnitude 256) for comfortable first- + # principles margin. + perc_param_grad = OrderedDict({0.25: (5e-5, 1e-4), 0.5: (5e-5, 1e-4), 0.75: (5e-5, 1e-4), 0.95: (5e-5, 1e-4)}) + + for name, grad_expected in refs["grad_params"].items(): + if name not in grad_params_global: + raise ValueError(f"Parameter {name}'s gradient is not found in the distributed module") + + assert_no_percentile_upshift( + grad_params_global[name], + grad_expected.to(dtype=grad_params_global[name].dtype), + refs["grad_params_fp32"][name], + perc=perc_param_grad, + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + output_expected_dtensor = distribute_tensor( + refs["output"].to(dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + d_input_x_expected_dtensor = distribute_tensor( + refs["d_input_x"].to(dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + + # Gather full tensors (collectives) before any assertions. + output_global_result = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result = input_x_dtensor.grad.full_tensor().cpu() + + # all_gather for assert_all_identical — do the collective now, + # assert on the gathered data later. + gathered_param_grads: dict[str, list[torch.Tensor]] = {} + cp_group = manager.group["cp"] + cp_world_size = torch.distributed.get_world_size(cp_group) + for name in grad_params_global: + grad_on_device = grad_params_global[name].to(device=manager.device) + tensor_list = [torch.empty_like(grad_on_device) for _ in range(cp_world_size)] + torch.distributed.all_gather(tensor_list, grad_on_device, group=cp_group) + gathered_param_grads[name] = tensor_list + + # --- All collectives done. Only local assertions below. --- + + assert output_dtensor_result.shape == output_expected_dtensor.shape + assert output_dtensor_result.stride() == output_expected_dtensor.stride() + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + assert input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + assert input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + torch.testing.assert_close(output_global_result, refs["output"].to(dtype=dtype).cpu()) + torch.testing.assert_close(d_input_x_global_result, refs["d_input_x"].to(dtype=dtype).cpu()) + + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in refs["grad_params"]: + raise ValueError(f"Parameter {name} has a gradient but is not in the reference") + torch.testing.assert_close(grad_params_global[name], refs["grad_params"][name].to(dtype=dtype)) + grad_on_device = grad_params_global[name].to(device=manager.device) + for gathered in gathered_param_grads[name]: + assert_tensors_identical(gathered, grad_on_device) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + ## CUDA tests (2 GPUs) + (((2, (1, 1)), True, "cuda", "ENV"), torch.float32, True), + (((2, (1, 1)), True, "cuda", "ENV"), torch.float64, True), + ## CUDA tests (4 GPUs) + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + ## CUDA tests (8 GPUs) + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + ## CPU tests + (((2, (3, 3)), True, "cpu", "ENV"), torch.float32, True), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +@pytest.mark.parametrize("c_in", [8, 128]) +def test_dtensor_layernorm_nocastbf16( + setup_env: tuple[dict, int, str, str, str, dict[str, str]], + dtype: torch.dtype, + check_error_hist: bool, + c_in: int, +): + """Test distributed LayerNormParamsReplicatedNoAutoCastBF16 across multiple processes. + + When check_error_hist=True, uses the three-way error histogram comparison: + CP FP32 vs serial FP64 ref, compared against serial FP32 vs serial FP64 ref. + Uses default tolerances from assert_no_percentile_upshift (same as triangle attention test). + + When check_error_hist=False, uses exact match against the FP64 reference. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist: + if grid_group_sizes["dp"] > 2: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + # Use larger dimensions for error histogram check to emulate realistic workloads + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + seq_len = size_ring * 128 + min_val_init = -5e-2 + max_val_init = 5e-2 + else: + seq_len = size_ring * 4 + min_val_init = -0.5 + max_val_init = 0.5 + + # Compute serial references on the same backend as the DTensor test for + # numerical consistency. Results are moved to CPU for mp.spawn transfer. + torch.manual_seed(42) + refs_cpu = _compute_references(B, seq_len, c_in, min_val_init, max_val_init, device=device_type) + + spawn_multiprocessing( + parallel_assert_dtensor_layernorm_nocastbf16, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + c_in, + refs_cpu, + check_error_hist, + ) diff --git a/tests/distributed/model/layers/test_dtensor_outer_gather.py b/tests/distributed/model/layers/test_dtensor_outer_gather.py new file mode 100644 index 000000000..811756087 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_outer_gather.py @@ -0,0 +1,676 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import math + +import pytest +import torch +import torch.nn.functional as F +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.outer_gather import ( + OuterGather, + compute_interval_overlap, + distributed_outer_gather, + get_overlap_from_peers, + outer_gather, +) +from boltz.distributed.model.layers.shardwise_op import shardwise_argmax +from boltz.testing.utils import assert_tensors_identical, init_tensors_uniform, seed_by_rank, spawn_multiprocessing + + +@pytest.mark.parametrize("W", [3, 4, 7, 32], ids=lambda x: f"W:{x}") +@pytest.mark.parametrize("H", [5, 8, 16, 128], ids=lambda x: f"H:{x}") +@pytest.mark.parametrize("K", [1, 2, 10, 50], ids=lambda x: f"K:{x}") +@pytest.mark.parametrize("n", [10, 20], ids=lambda x: f"n:{x}") +@pytest.mark.parametrize("m", [7, 11], ids=lambda x: f"m:{x}") +@pytest.mark.parametrize( + "shape_extra", + [ + (None,), + (4, None), + (None, 3), + (2, None, 3), + (2, 3, None), + (None, 2, 3), + (2, 3, None, 4), + (2, None, 3, 4), + (2, 3, None, 4, 5), + ], + ids=lambda x: f"shape_extra:{'x'.join(map(str, x))}", +) +@pytest.mark.parametrize("device", ["cuda"]) +def test_outer_gather(W, H, K, n, m, shape_extra, device): + """ + Test outer_gather against reference einsum. + """ + assert ( + sum(1 for x in shape_extra if x is None) == 1 + ), "There can be one and only one 'None' element in the shape_extra" + axis = shape_extra.index(None) + shape_z = shape_extra[:axis] + (n, m) + shape_extra[axis + 1 :] + shape_idx_q = shape_extra[:axis] + (K, W) + shape_idx_k = shape_extra[:axis] + (K, H) + + torch.manual_seed(42) + + # Test both without mask (None) and with random mask + for use_mask in [False, True]: + # Generate random masks + if use_mask: + idx_q_mask = torch.rand(shape_idx_q, device=device) > 0.5 + idx_k_mask = torch.rand(shape_idx_k, device=device) > 0.5 + else: + idx_q_mask = None + idx_k_mask = None + + # Inputs + z = torch.empty(shape_z, device=device, requires_grad=True) + init_tensors_uniform([z], low=-0.2, high=0.2) + + # Random one-hots + idx_q = torch.randint(0, n, shape_idx_q, device=device).requires_grad_(False) + idx_k = torch.randint(0, m, shape_idx_k, device=device).requires_grad_(False) + + one_hot_q = torch.nn.functional.one_hot(idx_q, num_classes=n).float().requires_grad_(False) + one_hot_k = torch.nn.functional.one_hot(idx_k, num_classes=m).float().requires_grad_(False) + + # Zero out one-hot at invalid positions when mask is provided + if idx_q_mask is not None: + one_hot_q = one_hot_q * idx_q_mask.unsqueeze(-1).float() + if idx_k_mask is not None: + one_hot_k = one_hot_k * idx_k_mask.unsqueeze(-1).float() + + n_axes_trailing_z = len(shape_extra[axis + 1 :]) + assert n_axes_trailing_z <= 26, "There can be at most 26 trailing axes for z" + + symbol_trailing_z = "".join(chr(ord("A") + i) for i in range(n_axes_trailing_z)) + out_expected = torch.einsum( + f"...ij{symbol_trailing_z},...kwi,...khj->...kwh{symbol_trailing_z}", z, one_hot_q, one_hot_k + ) + + # Test Forward + z_result = z.detach().clone().requires_grad_(True) + out_result = outer_gather( + z_result, one_hot_q, one_hot_k, axis=axis, one_hot_q_mask=idx_q_mask, one_hot_k_mask=idx_k_mask + ) + + assert_tensors_identical(out_result, out_expected, check_grad_fn=False, check_stride=False) + + # Test Backward + grad_out = torch.empty_like(out_expected) + init_tensors_uniform([grad_out], low=-0.2, high=0.2) + + out_expected.backward(grad_out) + + out_result.backward(grad_out.detach().clone()) + + # NOTE: scatter_add involves atomics in the CUDA backend, which leads to + # abs. error that scales with number of elements so setting tolerance is needed + torch.testing.assert_close(z.grad, z_result.grad, atol=5e-5, rtol=5e-5) + + +@pytest.mark.parametrize("device", ["cuda"]) +def test_outer_gather_empty_z_all_masked(device): + """Test outer_gather with empty z and all-False masks.""" + # Test case: z has zero elements along gather dimensions + shape_z = (2, 0, 0, 3) # Empty along axis=1 and axis+1=2 + shape_idx_q = (2, 4, 5) # K=4, W=5 + shape_idx_k = (2, 4, 8) # K=4, H=8 + axis = 1 + + z = torch.zeros(shape_z, device=device, requires_grad=True) + idx_q = torch.zeros(shape_idx_q, dtype=torch.long, device=device) + idx_k = torch.zeros(shape_idx_k, dtype=torch.long, device=device) + idx_q_mask = torch.zeros(shape_idx_q, dtype=torch.bool, device=device) # All False + idx_k_mask = torch.zeros(shape_idx_k, dtype=torch.bool, device=device) # All False + + out = OuterGather.apply(z, idx_q, idx_k, axis, idx_q_mask, idx_k_mask) + + # Output should be zeros with shape (2, 4, 5, 8, 3) + expected_shape = (2, 4, 5, 8, 3) + assert out.shape == expected_shape + assert (out == 0).all() + + # Test backward + grad_out = torch.randn_like(out) + out.backward(grad_out) + assert z.grad.shape == shape_z + + +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize( + "q_mask_mode,k_mask_mode,should_raise", + [ + # Both masks all-False: combined mask is all-False, no error + ("all_false", "all_false", False), + # q_mask all-False, k_mask has valid: combined mask is all-False, no error + ("all_false", "has_valid", False), + # q_mask has valid, k_mask all-False: combined mask is all-False, no error + ("has_valid", "all_false", False), + # q_mask all-False, k_mask is None: combined mask is all-False, no error + ("all_false", "none", False), + # q_mask is None, k_mask all-False: combined mask is all-False, no error + ("none", "all_false", False), + # Both masks have valid entries: combined mask has valid entries, should raise + ("has_valid", "has_valid", True), + # q_mask has valid, k_mask is None (implicitly all-True): should raise + ("has_valid", "none", True), + # q_mask is None (implicitly all-True), k_mask has valid: should raise + ("none", "has_valid", True), + # Both masks are None (implicitly all-True): should raise + ("none", "none", True), + ], + ids=lambda x: str(x), +) +def test_outer_gather_empty_z_joint_mask_validation(device, q_mask_mode, k_mask_mode, should_raise): + """Test that OuterGather validates masks jointly (outer-AND) when z is empty. + + The sanity check should only raise an error when the combined mask + (idx_q_mask outer-AND idx_k_mask) has valid entries. Separate validation + would incorrectly raise errors in asymmetric cases where one mask is all-False + but the other has valid entries. + """ + shape_z = (2, 0, 0, 3) # Empty z along axis=1 and axis+1=2 + shape_idx_q = (2, 4, 5) # K=4, W=5 + shape_idx_k = (2, 4, 8) # K=4, H=8 + axis = 1 + + z = torch.zeros(shape_z, device=device, requires_grad=True) + idx_q = torch.zeros(shape_idx_q, dtype=torch.long, device=device) + idx_k = torch.zeros(shape_idx_k, dtype=torch.long, device=device) + + # Configure q_mask based on mode + if q_mask_mode == "none": + idx_q_mask = None + elif q_mask_mode == "all_false": + idx_q_mask = torch.zeros(shape_idx_q, dtype=torch.bool, device=device) + else: # "has_valid" + idx_q_mask = torch.zeros(shape_idx_q, dtype=torch.bool, device=device) + idx_q_mask[0, 0, 0] = True # At least one valid entry + + # Configure k_mask based on mode + if k_mask_mode == "none": + idx_k_mask = None + elif k_mask_mode == "all_false": + idx_k_mask = torch.zeros(shape_idx_k, dtype=torch.bool, device=device) + else: # "has_valid" + idx_k_mask = torch.zeros(shape_idx_k, dtype=torch.bool, device=device) + idx_k_mask[0, 0, 0] = True # At least one valid entry + + if should_raise: + with pytest.raises(ValueError, match="combined mask.*contains valid entries"): + OuterGather.apply(z, idx_q, idx_k, axis, idx_q_mask, idx_k_mask) + else: + # Should not raise - the combined mask is all-False + out = OuterGather.apply(z, idx_q, idx_k, axis, idx_q_mask, idx_k_mask) + expected_shape = (2, 4, 5, 8, 3) + assert out.shape == expected_shape + assert (out == 0).all() + + +@pytest.mark.parametrize("leading_shape", [(), (2,), (2, 3)]) +@pytest.mark.parametrize("n_dim", [1, 2, 3]) +@pytest.mark.parametrize("make_empty", [False, True]) +def test_compute_interval_overlap(leading_shape, n_dim, make_empty): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # intervals_a: base starts at 0, ends at 5 + start_a = torch.zeros(leading_shape + (n_dim,), device=device, dtype=torch.long) + end_a = start_a + 5 + intervals_a = torch.stack([start_a, end_a], dim=-1) # (..., n_dim, 2) + + # intervals_b: optionally shift start beyond end_a in one dim to make empty + start_b = torch.zeros_like(start_a) + if make_empty: + start_b = start_b + 6 # greater than end_a -> empty overlap + end_b = start_b + 2 + intervals_b = torch.stack([start_b, end_b], dim=-1) + + overlap, mask = compute_interval_overlap(intervals_a, intervals_b) + + # Expected by direct max/min + expected_start = torch.maximum(start_a, start_b) + expected_end = torch.minimum(end_a, end_b) + expected_overlap = torch.stack([expected_start, expected_end], dim=-1) + expected_mask = torch.all(expected_end > expected_start, dim=-1) + + assert_tensors_identical(overlap, expected_overlap, check_grad=False, check_grad_fn=False) + assert_tensors_identical(mask, expected_mask, check_grad=False, check_grad_fn=False) + + +@pytest.mark.parametrize("leading_shape", [(), (2,), (2, 3)]) +@pytest.mark.parametrize("n_dim", [1, 2]) +@pytest.mark.parametrize("overlap", [True, False]) +def test_get_overlap_from_peers(leading_shape, n_dim, overlap): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # rank_peers carries unique ids per position + total = math.prod(leading_shape) if leading_shape else 1 + rank_peers = torch.arange(total, device=device).reshape(leading_shape if leading_shape else ()) + + # intervals_a: base start = position value, length = 4 in all dims + base = torch.arange(total, device=device).reshape(leading_shape if leading_shape else ()) + starts = torch.stack([base for _ in range(n_dim)], dim=-1) # (..., n_dim) + ends = starts + 4 + intervals_a = torch.stack([starts, ends], dim=-1) # (..., n_dim, 2) + + # intervals_b: either overlaps or not + if overlap: + # Overlap window length 2 starting at base+1: overlap is [base+1, base+3) + starts_b = torch.stack([base + 1 for _ in range(n_dim)], dim=-1) + ends_b = starts_b + 2 + expected_mask = torch.ones_like(base, dtype=torch.bool) + expected_intervals = torch.stack([starts_b, ends_b], dim=-1) + else: + # Disjoint: start beyond end_a + starts_b = ends + 1 + ends_b = starts_b + 2 + expected_mask = torch.zeros_like(base, dtype=torch.bool) + expected_intervals = torch.stack([starts_b, ends_b], dim=-1) + intervals_b = torch.stack([starts_b, ends_b], dim=-1) # (..., n_dim, 2) + + result = get_overlap_from_peers(rank_peers, intervals_a, intervals_b) + + expected_peers = rank_peers[expected_mask] + expected_intervals = expected_intervals[expected_mask] + + assert len(result) == expected_peers.numel() + for idx, item in enumerate(result): + assert item["peer"] == expected_peers.view(-1)[idx].item() + assert_tensors_identical(item["interval"], expected_intervals.view(-1, n_dim, 2)[idx], check_grad=False) + + +def parallel_assert_outer_gather(rank, grid_group_sizes, device_type, backend, env_map, dtype): + """Run distributed outer gather on each rank.""" + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh_subgroups + device_mesh_flat = manager.device_mesh + + seed_by_rank(0, 42) + + # We test with different configurations of W, H, K, N + # Following test_distributed_window_batch_attention style + + for input_shape_extra in [ + (2 * manager.group["dp"].size(), None, 3), + (4 * manager.group["dp"].size(), None, 128), + ]: # 'None' will become the axis to be sharded in z + if sum(1 for x in input_shape_extra if x is None) != 1: + raise ValueError( + f"There can be one and only one 'None' element in the input_shape but got {input_shape_extra}" + ) + axis = input_shape_extra.index(None) + + for W, H, K_per_rank, N_per_rank in [(4, 8, 1, 8), (8, 16, 2, 16), (32, 128, 3, 100)]: + N = N_per_rank * manager.group["cp_axis_0"].size() + M = N_per_rank * manager.group["cp_axis_1"].size() + + # Construct shapes + z_shape = input_shape_extra[:axis] + (N, M) + input_shape_extra[axis + 1 :] + + # Construct placements + z_placements = [Shard(0), Shard(axis), Shard(axis + 1)] + + for idx_use_flat_device_mesh in [True, False]: + if idx_use_flat_device_mesh: + K = K_per_rank * manager.group["cp"].size() + device_mesh_idx = device_mesh_flat + # equivalent to len(idx_n_shape) - 2 + placements_idx = [Shard(0), Shard(len(input_shape_extra[:axis]))] + else: + K = K_per_rank * manager.group["cp_axis_0"].size() + device_mesh_idx = device_mesh + # equivalent to len(idx_n_shape) - 2 + placements_idx = [Shard(0), Shard(len(input_shape_extra[:axis])), Replicate()] + + idx_n_shape = input_shape_extra[:axis] + (K, W) + idx_m_shape = input_shape_extra[:axis] + (K, H) + + # Test both without mask (None) and with random mask + for use_mask in [False, True]: + label = ( + f"z_shape:{z_shape}, idx_n_shape:{idx_n_shape}, idx_m_shape:{idx_m_shape}, " + f"z_placements:{z_placements}, placements_idx:{placements_idx}, axis:{axis}, " + f"idx_use_flat_device_mesh:{idx_use_flat_device_mesh}" + ) + if use_mask: + label += " (masked)" + idx_n_mask_global = torch.rand(idx_n_shape, device=device) > 0.5 + idx_m_mask_global = torch.rand(idx_m_shape, device=device) > 0.5 + + # Set second half of K dimension to all-False to simulate + # ranks with all-empty data when K is sharded across mesh + idx_n_mask_global[..., idx_n_shape[-2] // 2 :, :] = False + idx_m_mask_global[..., idx_m_shape[-2] // 2 :, :] = False + else: + idx_n_mask_global = None + idx_m_mask_global = None + + # Create global data + z_global = torch.randn(z_shape, dtype=dtype, device=device, requires_grad=True) + + idx_n_global = torch.randint(0, N, idx_n_shape, device=device, requires_grad=False) + + idx_m_global = torch.randint(0, M, idx_m_shape, device=device, requires_grad=False) + + # Reference + out_ref = OuterGather.apply( + z_global, idx_n_global, idx_m_global, axis, idx_n_mask_global, idx_m_mask_global + ) + grad_out = torch.randn_like(out_ref) + + # Use autograd to compute ref grad + # We clone z_global to avoid in-place modification issues if any (though here we just read) + out_ref.backward(grad_out) + grad_z_ref = z_global.grad + + z_dtensor = distribute_tensor(z_global.detach().clone(), device_mesh, z_placements).requires_grad_( + True + ) + idx_n_dtensor = distribute_tensor(idx_n_global, device_mesh_idx, placements_idx) + idx_m_dtensor = distribute_tensor(idx_m_global, device_mesh_idx, placements_idx) + idx_n_mask_dtensor = ( + distribute_tensor(idx_n_mask_global, device_mesh_idx, placements_idx) + if idx_n_mask_global is not None + else None + ) + idx_m_mask_dtensor = ( + distribute_tensor(idx_m_mask_global, device_mesh_idx, placements_idx) + if idx_m_mask_global is not None + else None + ) + + # Forward + # The data in idx_n_global and idx_m_global are not actually contiguous due to randint + # but the code still works despite being inefficient for this test setting + out_dtensor = distributed_outer_gather( + z_dtensor, + idx_n_dtensor, + idx_m_dtensor, + axis=axis, + are_ids_contiguous=True, + idx_n_mask=idx_n_mask_dtensor, + idx_m_mask=idx_m_mask_dtensor, + ) + + # Verify Forward + out_local = out_dtensor.full_tensor().requires_grad_(True) + assert_tensors_identical( + out_local, + out_ref, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} fwd output mismatch:\n {m}", + ) + + # Backward + grad_out_dtensor = distribute_tensor( + grad_out.detach().clone(), out_dtensor.device_mesh, out_dtensor.placements + ) + + out_dtensor.backward(grad_out_dtensor) + + # Verify Backward + grad_z_local = z_dtensor.grad.full_tensor() + assert_tensors_identical( + grad_z_local, + grad_z_ref, + check_grad=False, + check_grad_fn=False, + rtol=1e-10, # FP64 should be very precise but small accumulation errors can happen + atol=1e-10, + msg=lambda m: f"{label} bwd input gradient mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_distributed_outer_gather(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_outer_gather, world_size, grid_group_sizes, device_type, backend, env_per_rank, torch.float64 + ) + + +def parallel_assert_distributed_outer_gather_w_one_hot(rank, grid_group_sizes, device_type, backend, env_map, dtype): + """Run distributed outer gather with one-hot indices on each rank.""" + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh_subgroups + device_mesh_flat = manager.device_mesh + + seed_by_rank(0, 42) + + for input_shape_extra in [ + (2 * manager.group["dp"].size(), None, 3), + (4 * manager.group["dp"].size(), None, 128), + ]: # 'None' will become the axis to be sharded in z + if sum(1 for x in input_shape_extra if x is None) != 1: + raise ValueError( + f"There can be one and only one 'None' element in the input_shape but got {input_shape_extra}" + ) + axis = input_shape_extra.index(None) + + for W, H, K_per_rank, N_per_rank in [(4, 8, 1, 8), (8, 16, 2, 16), (32, 128, 3, 100)]: + N = N_per_rank * manager.group["cp_axis_0"].size() + M = N_per_rank * manager.group["cp_axis_1"].size() + + # Construct shapes + z_shape = input_shape_extra[:axis] + (N, M) + input_shape_extra[axis + 1 :] + + # Construct placements + z_placements = [Shard(0), Shard(axis), Shard(axis + 1)] + + for idx_use_flat_device_mesh in [True, False]: + if idx_use_flat_device_mesh: + K = K_per_rank * manager.group["cp"].size() + device_mesh_idx = device_mesh_flat + placements_idx = [Shard(0), Shard(len(input_shape_extra[:axis]))] + else: + K = K_per_rank * manager.group["cp_axis_0"].size() + device_mesh_idx = device_mesh + placements_idx = [Shard(0), Shard(len(input_shape_extra[:axis])), Replicate()] + + idx_n_shape = input_shape_extra[:axis] + (K, W) + idx_m_shape = input_shape_extra[:axis] + (K, H) + + # Test both without mask (None) and with random mask + for use_mask in [False, True]: + label = ( + f"z_shape:{z_shape}, idx_n_shape:{idx_n_shape}, idx_m_shape:{idx_m_shape}, " + f"z_placements:{z_placements}, placements_idx:{placements_idx}, axis:{axis}, " + f"idx_use_flat_device_mesh:{idx_use_flat_device_mesh}" + ) + if use_mask: + label += " (masked)" + idx_n_mask_global = torch.rand(idx_n_shape, device=device) > 0.5 + idx_m_mask_global = torch.rand(idx_m_shape, device=device) > 0.5 + + # Set second half of K dimension to all-False to simulate + # ranks with all-empty data when K is sharded across mesh + idx_n_mask_global[..., idx_n_shape[-2] // 2 :, :] = False + idx_m_mask_global[..., idx_m_shape[-2] // 2 :, :] = False + else: + idx_n_mask_global = None + idx_m_mask_global = None + + # Create global data + z_global = torch.randn(z_shape, dtype=dtype, device=device, requires_grad=True) + + # Create sorted index tensors then convert to one-hot + idx_n_indices = torch.randint(0, N, idx_n_shape, device=device) + idx_m_indices = torch.randint(0, M, idx_m_shape, device=device) + + one_hot_n = F.one_hot(idx_n_indices, num_classes=N).to(dtype=z_global.dtype) + one_hot_m = F.one_hot(idx_m_indices, num_classes=M).to(dtype=z_global.dtype) + + # Zero out one-hot at invalid positions when mask is provided + if idx_n_mask_global is not None: + one_hot_n = one_hot_n * idx_n_mask_global.unsqueeze(-1).to(dtype=z_global.dtype) + if idx_m_mask_global is not None: + one_hot_m = one_hot_m * idx_m_mask_global.unsqueeze(-1).to(dtype=z_global.dtype) + + # Reference using einsum with one-hot + feature_shape = z_shape[axis + 2 :] + feature_flat = int(math.prod(feature_shape)) if feature_shape else 1 + z_flat = z_global.view(*z_shape[: axis + 2], feature_flat) + ref_flat = torch.einsum("...nmf,...kwn,...khm->...kwhf", z_flat, one_hot_n, one_hot_m) + out_ref = ref_flat.view(*z_shape[:axis], K, W, H, *feature_shape) + + grad_out = torch.randn_like(out_ref) + + # Use autograd to compute ref grad + out_ref.backward(grad_out) + grad_z_ref = z_global.grad + + z_dtensor = distribute_tensor(z_global.detach().clone(), device_mesh, z_placements).requires_grad_( + True + ) + one_hot_n_dtensor = distribute_tensor(one_hot_n, device_mesh_idx, placements_idx) + one_hot_m_dtensor = distribute_tensor(one_hot_m, device_mesh_idx, placements_idx) + + # Convert one-hot DTensors to index DTensors via shardwise_argmax + idx_n_dtensor = shardwise_argmax(one_hot_n_dtensor, dim=-1, keepdim=False) + idx_m_dtensor = shardwise_argmax(one_hot_m_dtensor, dim=-1, keepdim=False) + + idx_n_mask_dtensor = ( + distribute_tensor(idx_n_mask_global, device_mesh_idx, placements_idx) + if idx_n_mask_global is not None + else None + ) + idx_m_mask_dtensor = ( + distribute_tensor(idx_m_mask_global, device_mesh_idx, placements_idx) + if idx_m_mask_global is not None + else None + ) + + # Forward + # The data in idx_n_global and idx_m_global are not actually contiguous due to randint + # but the code still works despite being inefficient for this test setting + out_dtensor = distributed_outer_gather( + z_dtensor, + idx_n_dtensor, + idx_m_dtensor, + axis=axis, + are_ids_contiguous=True, + idx_n_mask=idx_n_mask_dtensor, + idx_m_mask=idx_m_mask_dtensor, + ) + + # Verify Forward + out_local = out_dtensor.full_tensor().requires_grad_(True) + assert_tensors_identical( + out_local, + out_ref, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} fwd output mismatch:\n {m}", + ) + + # Backward + grad_out_dtensor = distribute_tensor( + grad_out.detach().clone(), out_dtensor.device_mesh, out_dtensor.placements + ) + + out_dtensor.backward(grad_out_dtensor) + + # Verify Backward + grad_z_local = z_dtensor.grad.full_tensor() + assert_tensors_identical( + grad_z_local, + grad_z_ref, + check_grad=False, + check_grad_fn=False, + rtol=1e-10, + atol=1e-10, + msg=lambda m: f"{label} bwd input gradient mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_outer_distributed_gather_w_one_hot(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_distributed_outer_gather_w_one_hot, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + torch.float64, + ) diff --git a/tests/distributed/model/layers/test_dtensor_outer_op.py b/tests/distributed/model/layers/test_dtensor_outer_op.py new file mode 100644 index 000000000..cea96c469 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_outer_op.py @@ -0,0 +1,508 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.outer_op import OuterOp, replicate_to_shard_outer_op +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def compute_global_expectation(shape_input, axis, op: OuterOp, device, asymmetric: bool = False): + """Compute global expectation using standard PyTorch operations.""" + # Create input tensor + if op == OuterOp.BITAND: + input_tensor = torch.randint(0, 2, shape_input, device=device, dtype=torch.bool) + else: + input_tensor = torch.rand(*shape_input, device=device, requires_grad=True) + + if asymmetric: + if op == OuterOp.BITAND: + input_t = torch.randint(0, 2, shape_input, device=device, dtype=torch.bool) + else: + input_t = torch.rand(*shape_input, device=device, requires_grad=True) + else: + input_t = input_tensor + + # Compute on global tensors using native PyTorch operations + # Replicate the logic from distributed_outer_op for non-distributed case + input_expanded = input_tensor.unsqueeze(axis + 1) + input_t_expanded = input_t.unsqueeze(axis + 1) + input_t_transposed = input_t_expanded.transpose(axis, axis + 1) + + if op == OuterOp.SUM: + y = input_expanded + input_t_transposed + elif op == OuterOp.SUBTRACT: + y = input_expanded - input_t_transposed + elif op == OuterOp.EQUAL: + y = input_expanded == input_t_transposed + elif op == OuterOp.BITAND: + y = input_expanded & input_t_transposed + elif op == OuterOp.PROD: + y = input_expanded * input_t_transposed + elif op == OuterOp.CDIST: + y = torch.cdist(input_tensor, input_t, p=2) + + if op == OuterOp.EQUAL or op == OuterOp.BITAND: + # Boolean output can't be backpropagated + return ( + input_tensor.detach().clone(), + input_t.detach().clone() if asymmetric else None, + y.detach().clone(), + None, # input_grad + None, # input_t_grad + None, # dy + ) + else: + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensors + y.backward(dy) + + # Collect input gradients + input_grad = input_tensor.grad.detach().clone() + input_t_grad = input_t.grad.detach().clone() if asymmetric else None + + return ( + input_tensor.detach().clone(), + input_t.detach().clone() if asymmetric else None, + y.detach().clone(), + input_grad, + input_t_grad, + dy.detach().clone(), + ) + + +def compute_dtensor_native_outer_op( + input_global: torch.Tensor, + input_t_global: torch.Tensor | None, + dy_global: torch.Tensor | None, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + output_placements: tuple[Placement, ...], + axis: int, + op: OuterOp, +) -> tuple[list[DTensor], DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native inputs + input_dtensor = distribute_tensor(input_global.detach().clone(), device_mesh, input_placements).requires_grad_( + op != OuterOp.BITAND + ) + + if input_t_global is not None: + input_t_dtensor = distribute_tensor( + input_t_global.detach().clone(), device_mesh, input_placements + ).requires_grad_(op != OuterOp.BITAND) + else: + input_t_dtensor = input_dtensor # Symmetric case + + # Forward pass with native DTensor operations (manual outer operation) + input_expanded = input_dtensor.unsqueeze(axis + 1) + input_t_expanded = input_t_dtensor.unsqueeze(axis + 1) + + placements_y = (Shard(0), Shard(1), Shard(2)) + + # it's necessary to redistribute the input_t_expanded to the placements_y + # in order to avoid a runtime error raise from the native DTensor backward: + # redistribute S(1) -> R in backward is not supported + input_t_expanded_transposed = input_t_expanded.transpose(axis, axis + 1).redistribute(placements=placements_y) + + if op == OuterOp.SUM: + y_dtensor_result_native = input_expanded + input_t_expanded_transposed + elif op == OuterOp.SUBTRACT: + y_dtensor_result_native = input_expanded - input_t_expanded_transposed + elif op == OuterOp.EQUAL: + y_dtensor_result_native = input_expanded == input_t_expanded_transposed + elif op == OuterOp.BITAND: + y_dtensor_result_native = input_expanded & input_t_expanded_transposed + elif op == OuterOp.PROD: + y_dtensor_result_native = input_expanded * input_t_expanded_transposed + elif op == OuterOp.CDIST: + y_dtensor_result_with_zeros = torch.sum((input_expanded - input_t_expanded_transposed) ** 2, dim=-1) + # to avoid the diagonal zeros causing backward pass nan + y_dtensor_result_native = ( + y_dtensor_result_with_zeros + torch.finfo(y_dtensor_result_with_zeros.dtype).tiny + ).sqrt() + + if op == OuterOp.EQUAL or op == OuterOp.BITAND or dy_global is None: + # No backward pass for EQUAL/BITAND operation + return [], y_dtensor_result_native + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, output_placements) + y_dtensor_result_native.backward(dy_dtensor) + + # redistribute here to avoid comparing Partial() placements with Replicate() placements, the + # latter of which is in the replicate_to_shard_outer_op due to the explicit all_reduce op + inputs_grad_dtensor = [input_dtensor.grad.redistribute(placements=input_placements)] + if input_t_global is not None: + inputs_grad_dtensor.append(input_t_dtensor.grad.redistribute(placements=input_placements)) + + return inputs_grad_dtensor, y_dtensor_result_native + + +def compute_replicate_to_shard_outer_op_with_validation( + input_global: torch.Tensor, + input_t_global: torch.Tensor | None, + dy_global: torch.Tensor | None, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + output_placements: tuple[Placement, ...], + transpose_comm: TransposeComm, + axis: int, + op: OuterOp, + label_test_case: str, +) -> tuple[DTensor, DTensor, DTensor | None, DTensor | None]: + """ + Compute replicate_to_shard_outer_op forward and backward pass with input validation checks. + + Returns: + y_dtensor_result: Forward pass result + input_dtensor: Input tensor with computed gradient + input_t_dtensor: Second input tensor with computed gradient (if asymmetric) + dy_dtensor: Distributed upstream gradient + """ + # Create DTensor inputs + input_dtensor = distribute_tensor(input_global.detach().clone(), device_mesh, input_placements).requires_grad_( + op != OuterOp.BITAND + ) + + if input_t_global is not None: + input_t_dtensor = distribute_tensor( + input_t_global.detach().clone(), device_mesh, input_placements + ).requires_grad_(op != OuterOp.BITAND) + else: + input_t_dtensor = None + + input_dtensor_copy = input_dtensor.detach().clone().requires_grad_(op != OuterOp.BITAND) + input_t_dtensor_copy = ( + input_t_dtensor.detach().clone().requires_grad_(op != OuterOp.BITAND) if input_t_dtensor is not None else None + ) + + # Compute on distributed tensors using replicate_to_shard_outer_op + y_dtensor_result = replicate_to_shard_outer_op(input_dtensor, op, axis, transpose_comm, input_t_dtensor) + + # Verify no change to the forward inputs + assert_tensors_identical( + input_dtensor.to_local(), input_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False + ) + if input_t_dtensor is not None: + assert_tensors_identical( + input_t_dtensor.to_local(), input_t_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False + ) + + # Verify output placements + assert y_dtensor_result.placements == output_placements, f"{label_test_case} output placements mismatch" + + if op == OuterOp.EQUAL or op == OuterOp.BITAND or dy_global is None: + # No backward pass for EQUAL/BITAND operation + return y_dtensor_result, input_dtensor, input_t_dtensor, None + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, output_placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # Verify no change to the backward input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Verify input gradient placements are consistent with input placements + assert ( + input_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + if input_t_dtensor is not None: + assert ( + input_t_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input_t gradient placements with input placements" + + return y_dtensor_result, input_dtensor, input_t_dtensor, dy_dtensor + + +def parallel_assert_replicate_to_shard_outer_op( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + # Input placements: (Shard(0), Shard(1), Replicate()) + # Output placements: (Shard(0), Shard(1), Shard(2)) + batch_size = 2 * len(manager.group_ranks["dp"]) + seq_len_dim0 = size_ring * 4 # Sharded along dim 0 + embed_dim = 8 + + shape_input = (batch_size, seq_len_dim0, embed_dim) + input_placements = (Shard(0), Shard(1), Replicate()) + output_placements = (Shard(0), Shard(1), Shard(2)) + + # Create transpose communication + layout_map = manager.layout_subgroups["cp"] + transpose_comm = TransposeComm(manager.group["cp"], layout_map) + + # Test all OuterOp cases and both symmetric/asymmetric + for op in [OuterOp.SUM, OuterOp.SUBTRACT, OuterOp.EQUAL, OuterOp.BITAND, OuterOp.PROD]: + for asymmetric in [False, True]: + axis = 1 + label_test_case = ( + f"op={op}, asymmetric={asymmetric}, axis={axis}, " + f"input_placements={input_placements}, output_placements={output_placements}\n" + ) + + # Compute global expectations + ( + input_global, + input_t_global, + y_expected_global, + input_grad_expected_global, + input_t_grad_expected_global, + dy_global, + ) = compute_global_expectation(shape_input, axis, op, manager.device, asymmetric) + + # Use DTensor native op as an alternative reference (always call, even for EQUAL) + ( + inputs_grad_dtensor_native, + y_dtensor_result_native, + ) = compute_dtensor_native_outer_op( + input_global, + input_t_global, + dy_global, # This will be None for EQUAL operations + manager.device_mesh_subgroups, + input_placements, + output_placements, + axis, + op, + ) + + # Compute replicate_to_shard_outer_op forward and backward with validation + y_dtensor_result, input_dtensor, input_t_dtensor, dy_dtensor = ( + compute_replicate_to_shard_outer_op_with_validation( + input_global, + input_t_global, + dy_global, + manager.device_mesh_subgroups, + input_placements, + output_placements, + transpose_comm, + axis, + op, + label_test_case, + ) + ) + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # Check metadata against DTensor native (for all operations) + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # Compare forward result with native DTensor op (for all operations) + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + msg=lambda m: f"{label_test_case} {m}", + ) + + # Compare global tensors between replicate_to_shard_outer_op and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + # Only check gradients for non-EQUAL and non-BITAND operations + if op != OuterOp.EQUAL and op != OuterOp.BITAND: + # Assert input gradients' metadata and values against DTensor native + # Input gradient comparison + assert ( + input_dtensor.grad.placements == inputs_grad_dtensor_native[0].placements + ), f"{label_test_case} input gradient placements mismatch" + assert ( + input_dtensor.grad.shape == inputs_grad_dtensor_native[0].shape + ), f"{label_test_case} input gradient shape mismatch" + assert ( + input_dtensor.grad.stride() == inputs_grad_dtensor_native[0].stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + input_dtensor.grad.to_local(), + inputs_grad_dtensor_native[0].to_local(), + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + torch.testing.assert_close( + input_dtensor.grad.full_tensor(), + inputs_grad_dtensor_native[0].full_tensor(), + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # Input_t gradient comparison (asymmetric case) + if asymmetric and len(inputs_grad_dtensor_native) > 1: + assert ( + input_t_dtensor.grad.placements == inputs_grad_dtensor_native[1].placements + ), f"{label_test_case} input_t gradient placements mismatch" + assert ( + input_t_dtensor.grad.shape == inputs_grad_dtensor_native[1].shape + ), f"{label_test_case} input_t gradient shape mismatch" + assert ( + input_t_dtensor.grad.stride() == inputs_grad_dtensor_native[1].stride() + ), f"{label_test_case} input_t gradient stride mismatch" + + torch.testing.assert_close( + input_t_dtensor.grad.to_local(), + inputs_grad_dtensor_native[1].to_local(), + msg=lambda m: f"{label_test_case} input_t gradient mismatch: {m}", + ) + + torch.testing.assert_close( + input_t_dtensor.grad.full_tensor(), + inputs_grad_dtensor_native[1].full_tensor(), + msg=lambda m: f"{label_test_case} input_t gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + msg=lambda m: f"{label_test_case} forward result: {m}", + ) + + # Compare forward result with global expectation + y_result_global = y_dtensor_result.full_tensor() + torch.testing.assert_close( + y_result_global, + y_expected_global, + msg=lambda m: f"{label_test_case} forward result vs global expectation: {m}", + ) + + if op != OuterOp.EQUAL and op != OuterOp.BITAND: + # Check input gradient + input_grad_expected_dtensor = distribute_tensor( + input_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + torch.testing.assert_close( + input_dtensor.grad.to_local(), + input_grad_expected_dtensor.to_local(), + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + input_dtensor.grad.full_tensor(), + input_grad_expected_global, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + # Check input_t gradient (asymmetric case) + if asymmetric and input_t_grad_expected_global is not None: + input_t_grad_expected_dtensor = distribute_tensor( + input_t_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + torch.testing.assert_close( + input_t_dtensor.grad.to_local(), + input_t_grad_expected_dtensor.to_local(), + msg=lambda m: f"{label_test_case} input_t gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + input_t_dtensor.grad.full_tensor(), + input_t_grad_expected_global, + msg=lambda m: f"{label_test_case} input_t gradient vs global expectation: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_replicate_to_shard_outer_op(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_replicate_to_shard_outer_op, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_outer_product_mean.py b/tests/distributed/model/layers/test_dtensor_outer_product_mean.py new file mode 100644 index 000000000..e5f59c293 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_outer_product_mean.py @@ -0,0 +1,377 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Shard, distribute_tensor + +from boltz.distributed.comm import Ring2DComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.outer_product_mean import OuterProductMean as DistributedOuterProductMean +from boltz.model.layers.outer_product_mean import OuterProductMean as SerialOuterProductMean +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_outer_prod_mean( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + C_in, + C_hidden, + C_out, + layer_state_dict, + input_global_host, + mask_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_expected_global_host, + grad_params_expected_global_host, + output_global_fp32_host: torch.Tensor | None = None, + d_input_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_expected_global_host.dtype}" + ) + + if ((output_global_fp32_host is None) != (d_input_global_fp32_host is None)) or ( + (output_global_fp32_host is not None) != (grad_params_fp32_global_host is not None) + ): + raise ValueError( + "output_global_fp32_host, d_input_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_global_fp32_host is not None + + layout_map = manager.layout_subgroups["cp"] + ring_comm = Ring2DComm(manager.group["cp"], manager.subgroups["cp"][0], layout_map) + + module_serial = SerialOuterProductMean(C_in, C_hidden, C_out).to(dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device) + module = DistributedOuterProductMean(module_serial, manager.device_mesh_subgroups, ring_comm) + module.train() + + placements_input = (Shard(0), Shard(1), Shard(2)) + # Omitting the src_data_rank parameter to distribute_tensor means the data from rank 0 is sharded + input_dtensor = distribute_tensor( + input_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ).requires_grad_(True) + mask_dtensor = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + d_input_expected_dtensor = distribute_tensor( + d_input_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + + input_dtensor_copy = input_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone() + + if check_error_hist: + output_dtensor_result = module(input_dtensor, mask_dtensor) + output_dtensor_result.backward(d_output_expected_dtensor) + + output_fp32_dtensor = distribute_tensor( + output_global_fp32_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + + d_input_fp32_dtensor = distribute_tensor( + d_input_global_fp32_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + + assert_no_percentile_upshift( + output_dtensor_result.to_local(), + output_expected_dtensor.to_local(), + output_fp32_dtensor.to_local(), + names_input=("output_cp_fp32", "output_serial_fp64", "output_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_dtensor.grad.to_local(), + d_input_expected_dtensor.to_local(), + d_input_fp32_dtensor.to_local(), + names_input=("d_input_cp_fp32", "d_input_serial_fp64", "d_input_serial_fp32"), + ) + + for name, grad_param_expected_global in grad_params_expected_global_host.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + output_dtensor_result = module(input_dtensor, mask_dtensor) + + # no modification on the input + assert_tensors_identical( + input_dtensor_copy.to_local(), input_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(mask_dtensor_copy.to_local(), mask_dtensor.to_local()) + + # test for consistent forward results with the single-device + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output shape mismatch: {output_dtensor_result.shape} != {output_expected_dtensor.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output stride mismatch: {output_dtensor_result.stride()} != {output_expected_dtensor.stride()}" + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # check backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # backward pass should not modify the upstream adjoint + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + assert ( + input_dtensor.grad.shape == d_input_expected_dtensor.shape + ), f"Gradient shape mismatch: {input_dtensor.grad.shape} != {d_input_expected_dtensor.shape}" + assert ( + input_dtensor.grad.stride() == d_input_expected_dtensor.stride() + ), f"Gradient stride mismatch: {input_dtensor.grad.stride()} != {d_input_expected_dtensor.stride()}" + torch.testing.assert_close(input_dtensor.grad.to_local(), d_input_expected_dtensor.to_local()) + + # check gradient of the weight + grad_params_result_dtensors = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in grad_params_expected_global_host: + # do an extra check here to make sure the parallel computation don't result in extra gradients + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + grad_params_result_dtensors[name] = param.grad + + for name, grad_param_expected_global_host in grad_params_expected_global_host.items(): + assert name in grad_params_result_dtensors, f"Parameter {name}'s gradient is not found in result gradients" + grad_params_result = grad_params_result_dtensors[name] + assert ( + grad_params_result.shape == grad_param_expected_global_host.shape + ), f"Gradient shape mismatch: {grad_params_result.shape} != {grad_param_expected_global_host.shape}" + assert ( + grad_params_result.stride() == grad_param_expected_global_host.stride() + ), f"Gradient stride mismatch: {grad_params_result.stride()} != {grad_param_expected_global_host.stride()}" + grad_params_result_global = grad_params_result.full_tensor() + torch.testing.assert_close(grad_params_result_global.cpu(), grad_param_expected_global_host.to(dtype=dtype)) + assert_all_identical(grad_params_result_global, manager.group["cp"]) + + # check the results with the full tensor to make sure the module's output and + # and gradients can be gathered into the consistent results with the single-device + input_global_result = input_dtensor.full_tensor() + mask_global_result = mask_dtensor.full_tensor() + output_global_result = output_dtensor_result.full_tensor() + d_input_global_result = input_dtensor.grad.full_tensor() + + torch.testing.assert_close(input_global_result.cpu(), input_global_host.to(dtype=dtype)) + torch.testing.assert_close(mask_global_result.cpu(), mask_global_host.to(dtype=dtype)) + torch.testing.assert_close(output_global_result.cpu(), output_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_global_result.cpu(), d_input_expected_global_host.to(dtype=dtype)) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + ## CUDA tests (2 GPUs) + # (((2, (1, 1)), True, "cuda", "ENV"), torch.float32, True), + # (((2, (1, 1)), True, "cuda", "ENV"), torch.float64, True), + ## CUDA tests (8 GPUs) + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False), + ## CPU tests + (((1, (3, 3)), True, "cuda", "ENV"), torch.float32, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +def test_outer_product_parallel(setup_env, dtype, check_error_hist): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + # dtype is the dtype used by the parallel computation + # check_error_hist determine whether to compare the error histograms between + # (CP_in_FP32, serial_in_FP64) and (serial_in_FP32, serial_in_FP64) + # Typically, check_error_hist will use large input dimensions to emulate + # the real-world use cases. Same with dtype==torch.float64. + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist: + if grid_group_sizes["dp"] > 2: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 128 + S = size_ring * 128 + C_in = 64 + C_hidden = 32 + C_out = 128 + min_val_init = -5e-2 + max_val_init = 5e-2 + else: + N = size_ring * 2 + S = size_ring * 3 + C_in = 3 + C_hidden = 5 + C_out = 3 + min_val_init = -0.5 + max_val_init = 0.5 + + seed = 42 + seed_by_rank(0, seed=seed) + + # compute reference results with FP64 + input_global_fp64 = torch.empty((B, S, N, C_in), dtype=torch.float64, requires_grad=True, device=device_type) + mask_global_fp64 = torch.ones((B, S, N), dtype=torch.float64, requires_grad=False, device=device_type) + mask_global_fp64[0, (S // size_ring) :, :] = 0 + mask_global_fp64[0, :, (N // size_ring) :] = 0 + reference_module = SerialOuterProductMean(C_in, C_hidden, C_out).to(dtype=torch.float64) + # The output activation and gradient of the layer weights typically increase by 2 to 3 orders of magnitude, + # where the ULP would be too large and numerical error distribution becomes very wide, i.e., we would have + # very unpredictable numerical errors. That would make the test results very noisy and not very useful to + # detect logical bugs in the code. To avoid this, we use a smaller range for the input and layer weights. + init_tensors_uniform([input_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + layer_state_dict_fp64 = reference_module.state_dict() + reference_module = reference_module.to(device=device_type).train() + + output_expected_global_fp64 = reference_module(input_global_fp64, mask_global_fp64) + d_output_expected_global_fp64 = torch.rand_like(output_expected_global_fp64) + output_expected_global_fp64.backward(d_output_expected_global_fp64) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module.named_parameters() + } + + if check_error_hist: + input_global_fp32 = input_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(True) + mask_global_fp32 = mask_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(False) + reference_module_fp32 = SerialOuterProductMean(C_in, C_hidden, C_out).to(dtype=torch.float32) + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(device=device_type).train() + output_global_fp32 = reference_module_fp32(input_global_fp32, mask_global_fp32) + d_output_expected_global_fp32 = d_output_expected_global_fp64.to(dtype=torch.float32) + output_global_fp32.backward(d_output_expected_global_fp32) + + output_global_fp32_host = output_global_fp32.detach().clone().cpu() + d_input_global_fp32_host = input_global_fp32.grad.detach().clone().cpu() + grad_params_fp32_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module_fp32.named_parameters() + } + else: + output_global_fp32_host = None + d_input_global_fp32_host = None + grad_params_fp32_global_host = None + + spawn_multiprocessing( + parallel_assert_outer_prod_mean, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + C_in, + C_hidden, + C_out, + layer_state_dict_fp64, + input_global_fp64.detach().clone().cpu(), + mask_global_fp64.detach().clone().cpu(), + output_expected_global_fp64.detach().clone().cpu(), + d_output_expected_global_fp64.detach().clone().cpu(), + input_global_fp64.grad.detach().clone().cpu(), + grad_params_fp64_expected_global_host, + output_global_fp32_host, + d_input_global_fp32_host, + grad_params_fp32_global_host, + ) diff --git a/tests/distributed/model/layers/test_dtensor_pair_weighted_averaging.py b/tests/distributed/model/layers/test_dtensor_pair_weighted_averaging.py new file mode 100644 index 000000000..c9a9f0e7b --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_pair_weighted_averaging.py @@ -0,0 +1,429 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.pair_averaging import PairWeightedAveraging as DistributedPairWeightedAveraging +from boltz.distributed.model.layers.pair_averaging import Ring2DCommPairAveraging +from boltz.model.layers.pair_averaging import PairWeightedAveraging as SerialPairWeightedAveraging +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_pair_weighted_averaging( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + c_m, + c_z, + c_h, + num_heads, + layer_state_dict, + input_m_global_host, + input_z_global_host, + mask_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_m_expected_global_host, + d_input_z_expected_global_host, + grad_params_expected_global_host, + output_global_fp32_host: torch.Tensor | None = None, + d_input_m_global_fp32_host: torch.Tensor | None = None, + d_input_z_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_expected_global_host.dtype}" + ) + + if ( + ((output_global_fp32_host is None) != (d_input_m_global_fp32_host is None)) + or ((output_global_fp32_host is None) != (d_input_z_global_fp32_host is None)) + or ((output_global_fp32_host is not None) != (grad_params_fp32_global_host is not None)) + ): + raise ValueError( + "output_global_fp32_host, d_input_m_global_fp32_host, d_input_z_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_global_fp32_host is not None + + layout_map = manager.layout_subgroups["cp"] + ring_comm = Ring2DCommPairAveraging(manager.group["cp"], manager.subgroups["cp"][0], layout_map) + + dtype2inf = {torch.float32: 1e9, torch.float64: 1e18} + module_serial = SerialPairWeightedAveraging(c_m, c_z, c_h, num_heads, inf=dtype2inf[dtype]).to(dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device) + module = DistributedPairWeightedAveraging(module_serial, manager.device_mesh_subgroups, ring_comm) + module.train() + + # Input tensors have different sharding patterns: + # m: (B, S, N, c_m) - sharded on dims 1 and 2 (S and N) + # z: (B, N, N, c_z) - sharded on dims 1 and 2 (N and N) + # mask: (B, N, N) - sharded on dims 1 and 2 (N and N) + placements_m = (Shard(0), Shard(1), Shard(2)) # For m tensor + placements_z_mask = (Shard(0), Shard(1), Shard(2)) # For z and mask tensors + + # Distribute input tensors + input_m_dtensor = distribute_tensor( + input_m_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m, + ).requires_grad_(True) + + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_mask, + ).requires_grad_(True) + + mask_dtensor = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_mask, + ) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m, + src_data_rank=None, + ) + d_input_m_expected_dtensor = distribute_tensor( + d_input_m_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m, + src_data_rank=None, + ) + d_input_z_expected_dtensor = distribute_tensor( + d_input_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_mask, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + input_m_dtensor_copy = input_m_dtensor.detach().clone().requires_grad_(True) + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone() + + if check_error_hist: + # Forward pass + output_dtensor_result = module(input_m_dtensor, input_z_dtensor, mask_dtensor) + output_dtensor_result.backward(d_output_expected_dtensor) + + # Distribute FP32 comparison data + output_fp32_dtensor = distribute_tensor( + output_global_fp32_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m, + src_data_rank=None, + ) + + d_input_m_fp32_dtensor = distribute_tensor( + d_input_m_global_fp32_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m, + src_data_rank=None, + ) + + d_input_z_fp32_dtensor = distribute_tensor( + d_input_z_global_fp32_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_mask, + src_data_rank=None, + ) + + # Test error histogram for output + assert_no_percentile_upshift( + output_dtensor_result.to_local(), + output_expected_dtensor.to_local(), + output_fp32_dtensor.to_local(), + names_input=("output_cp_fp32", "output_serial_fp64", "output_serial_fp32"), + ) + + # Test error histogram for input gradients + assert_no_percentile_upshift( + input_m_dtensor.grad.to_local(), + d_input_m_expected_dtensor.to_local(), + d_input_m_fp32_dtensor.to_local(), + names_input=("d_input_m_cp_fp32", "d_input_m_serial_fp64", "d_input_m_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_z_dtensor.grad.to_local(), + d_input_z_expected_dtensor.to_local(), + d_input_z_fp32_dtensor.to_local(), + names_input=("d_input_z_cp_fp32", "d_input_z_serial_fp64", "d_input_z_serial_fp32"), + ) + # Test error histogram for parameter gradients + for name, grad_param_expected_global in grad_params_expected_global_host.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + # Forward pass + output_dtensor_result = module(input_m_dtensor, input_z_dtensor, mask_dtensor) + + # Verify inputs weren't modified + assert_tensors_identical( + input_m_dtensor_copy.to_local(), input_m_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(mask_dtensor_copy.to_local(), mask_dtensor.to_local()) + + # Test forward pass results + assert output_dtensor_result.shape == output_expected_dtensor.shape + assert output_dtensor_result.stride() == output_expected_dtensor.stride() + + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + assert input_m_dtensor.grad.shape == d_input_m_expected_dtensor.shape + assert input_m_dtensor.grad.stride() == d_input_m_expected_dtensor.stride() + assert input_z_dtensor.grad.shape == d_input_z_expected_dtensor.shape + assert input_z_dtensor.grad.stride() == d_input_z_expected_dtensor.stride() + + torch.testing.assert_close(input_m_dtensor.grad.to_local(), d_input_m_expected_dtensor.to_local()) + torch.testing.assert_close(input_z_dtensor.grad.to_local(), d_input_z_expected_dtensor.to_local()) + + # Test parameter gradients + grad_params_result_dtensors = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in grad_params_expected_global_host: + # do an extra check here to make sure the parallel computation don't result in extra gradients + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + grad_params_result_dtensors[name] = param.grad + + for name, grad_param_expected_global_host in grad_params_expected_global_host.items(): + assert name in grad_params_result_dtensors, f"Parameter {name}'s gradient is not found in result gradients" + grad_params_result = grad_params_result_dtensors[name] + assert grad_params_result.shape == grad_param_expected_global_host.shape + assert grad_params_result.stride() == grad_param_expected_global_host.stride() + grad_params_result_global = grad_params_result.full_tensor() + torch.testing.assert_close(grad_params_result_global.cpu(), grad_param_expected_global_host.to(dtype=dtype)) + assert_all_identical(grad_params_result_global, manager.group["cp"]) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_m_global_result_host = input_m_dtensor.grad.full_tensor().cpu() + d_input_z_global_result_host = input_z_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_m_global_result_host, d_input_m_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_z_global_result_host, d_input_z_expected_global_host.to(dtype=dtype)) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + ## CUDA tests (2 GPUs) + # (((2, (1, 1)), True, "cuda", "ENV"), torch.float32, True), + # (((2, (1, 1)), True, "cuda", "ENV"), torch.float64, True), + ## CUDA tests (8 GPUs) + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False), + ## CPU tests + (((1, (3, 3)), True, "cuda", "ENV"), torch.float32, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +def test_pair_weighted_averaging_parallel(setup_env, dtype, check_error_hist): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + # dtype is the dtype used by the parallel computation + # check_error_hist determine whether to compare the error histograms between + # (CP_in_FP32, serial_in_FP64) and (serial_in_FP32, serial_in_FP64) + # Typically, check_error_hist will use large input dimensions to emulate + # the real-world use cases. Same with dtype==torch.float64. + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + if test_large_model: + N = size_ring * 128 # Number of tokens + S = size_ring * 128 # Number of sequences + c_m = 64 # Sequence dimension + c_z = 128 # Pairwise dimension + c_h = 32 # Hidden dimension per head + num_heads = 8 + min_val_init = -5e-2 if dtype == torch.float64 else -1e-2 + max_val_init = -min_val_init + else: + N = size_ring * 2 # Number of tokens + S = size_ring * 3 # Number of sequences + c_m = 3 # Sequence dimension + c_z = 5 # Pairwise dimension + c_h = 7 # Hidden dimension per head + num_heads = 3 + min_val_init = -0.5 + max_val_init = 0.5 + + seed = 42 + seed_by_rank(0, seed=seed) + + # compute reference results with FP64 + input_m_global_fp64 = torch.empty((B, S, N, c_m), dtype=torch.float64, requires_grad=True, device=device_type) + input_z_global_fp64 = torch.empty((B, N, N, c_z), dtype=torch.float64, requires_grad=True, device=device_type) + mask_global_fp64 = torch.ones((B, N, N), dtype=torch.float64, requires_grad=False, device=device_type) + mask_global_fp64[0, N // size_ring :, :] = 0 + mask_global_fp64[0, :, N // size_ring :] = 0 + + reference_module = SerialPairWeightedAveraging(c_m, c_z, c_h, num_heads, inf=1e18).to(dtype=torch.float64) + # The output activation and gradient of the layer weights typically increase by 3 to 4 orders of magnitude, + # where the ULP would be too large and numerical error distribution becomes very wide, i.e., we would have + # very unpredictable numerical errors. That would make the test results very noisy and not very useful to + # detect logical bugs in the code. To avoid this, we use a smaller range for the input and layer weights. + init_tensors_uniform([input_m_global_fp64, input_z_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + layer_state_dict_fp64 = reference_module.state_dict() + reference_module = reference_module.to(device=device_type).train() + + output_expected_global_fp64 = reference_module(input_m_global_fp64, input_z_global_fp64, mask_global_fp64) + d_output_expected_global_fp64 = torch.rand_like(output_expected_global_fp64) + output_expected_global_fp64.backward(d_output_expected_global_fp64) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module.named_parameters() + } + + if check_error_hist: + input_m_global_fp32 = input_m_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(True) + input_z_global_fp32 = input_z_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(True) + mask_global_fp32 = mask_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(False) + reference_module_fp32 = SerialPairWeightedAveraging(c_m, c_z, c_h, num_heads, inf=1e9).to(dtype=torch.float32) + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(device=device_type).train() + output_global_fp32 = reference_module_fp32(input_m_global_fp32, input_z_global_fp32, mask_global_fp32) + d_output_expected_global_fp32 = d_output_expected_global_fp64.to(dtype=torch.float32) + output_global_fp32.backward(d_output_expected_global_fp32) + + output_global_fp32_host = output_global_fp32.detach().clone().cpu() + d_input_m_global_fp32_host = input_m_global_fp32.grad.detach().clone().cpu() + d_input_z_global_fp32_host = input_z_global_fp32.grad.detach().clone().cpu() + grad_params_fp32_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module_fp32.named_parameters() + } + else: + output_global_fp32_host = None + d_input_m_global_fp32_host = None + d_input_z_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_pair_weighted_averaging, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + c_m, + c_z, + c_h, + num_heads, + layer_state_dict_fp64, + input_m_global_fp64.detach().clone().cpu(), + input_z_global_fp64.detach().clone().cpu(), + mask_global_fp64.detach().clone().cpu(), + output_expected_global_fp64.detach().clone().cpu(), + d_output_expected_global_fp64.detach().clone().cpu(), + input_m_global_fp64.grad.detach().clone().cpu(), + input_z_global_fp64.grad.detach().clone().cpu(), + grad_params_fp64_expected_global_host, + output_global_fp32_host, + d_input_m_global_fp32_host, + d_input_z_global_fp32_host, + grad_params_fp32_global_host, + ) diff --git a/tests/distributed/model/layers/test_dtensor_pairformer_layer.py b/tests/distributed/model/layers/test_dtensor_pairformer_layer.py new file mode 100644 index 000000000..1a0d0bc46 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_pairformer_layer.py @@ -0,0 +1,504 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.pairformer import PairformerLayer as DistributedPairformerLayer +from boltz.model.layers.pairformer import PairformerLayer as SerialPairformerLayer +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + set_dtype_specific_inf_values, + spawn_multiprocessing, +) + + +def parallel_assert_pairformer_layer( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + token_s, + token_z, + num_heads, + dropout, + pairwise_head_width, + pairwise_num_heads, + post_layer_norm, + layer_state_dict, + input_s_global_host, + input_z_global_host, + mask_global_host, + pair_mask_global_host, + output_s_expected_global_host, + output_z_expected_global_host, + d_output_s_expected_global_host, + d_output_z_expected_global_host, + d_input_s_expected_global_host, + d_input_z_expected_global_host, + expected_param_grads_global_host_dict, + output_s_global_fp32_host: torch.Tensor | None = None, + output_z_global_fp32_host: torch.Tensor | None = None, + d_input_s_global_fp32_host: torch.Tensor | None = None, + d_input_z_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_s_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_s_expected_global_host.dtype}" + ) + + if ( + (output_s_global_fp32_host is None) != (output_z_global_fp32_host is None) + or (output_s_global_fp32_host is None) != (d_input_s_global_fp32_host is None) + or (output_s_global_fp32_host is None) != (d_input_z_global_fp32_host is None) + or (output_s_global_fp32_host is None) != (grad_params_fp32_global_host is None) + ): + raise ValueError( + "output_s_global_fp32_host, output_z_global_fp32_host, d_input_s_global_fp32_host, " + "d_input_z_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_s_global_fp32_host is not None + + # Create serial reference module (Boltz-2: v2=True) + module_serial = SerialPairformerLayer( + token_s=token_s, + token_z=token_z, + num_heads=num_heads, + dropout=dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + post_layer_norm=post_layer_norm, + v2=True, + ) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + module_serial.load_state_dict(layer_state_dict) + set_dtype_specific_inf_values(module_serial, dtype) + + # Create distributed module + module = DistributedPairformerLayer(module_serial, manager) + module.train() + + placements_s_mask = (Shard(0), Shard(1), Replicate()) + placements_z_pair_mask = (Shard(0), Shard(1), Shard(2)) + + input_s_dtensor = distribute_tensor( + input_s_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ).requires_grad_(True) + + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ).requires_grad_(True) + + mask_dtensor = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ) + pair_mask_dtensor = distribute_tensor( + pair_mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + + d_output_s_expected_dtensor = distribute_tensor( + d_output_s_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ) + d_output_z_expected_dtensor = distribute_tensor( + d_output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + output_s_expected_dtensor = distribute_tensor( + output_s_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + output_z_expected_dtensor = distribute_tensor( + output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_s_expected_dtensor = distribute_tensor( + d_input_s_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + d_input_z_expected_dtensor = distribute_tensor( + d_input_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + + input_s_dtensor_copy = input_s_dtensor.detach().clone().requires_grad_(True) + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone() + pair_mask_dtensor_copy = pair_mask_dtensor.detach().clone() + + if check_error_hist: + output_s_dtensor_result, output_z_dtensor_result = module( + s=input_s_dtensor, + z=input_z_dtensor, + mask=mask_dtensor, + pair_mask=pair_mask_dtensor, + ) + torch.autograd.backward( + [output_s_dtensor_result, output_z_dtensor_result], + [d_output_s_expected_dtensor, d_output_z_expected_dtensor], + ) + output_s_fp32_dtensor = distribute_tensor( + output_s_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + output_z_fp32_dtensor = distribute_tensor( + output_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_s_fp32_dtensor = distribute_tensor( + d_input_s_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + d_input_z_fp32_dtensor = distribute_tensor( + d_input_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + assert_no_percentile_upshift( + output_s_dtensor_result.to_local(), + output_s_expected_dtensor.to_local(), + output_s_fp32_dtensor.to_local(), + names_input=("output_s_cp_fp32", "output_s_serial_fp64", "output_s_serial_fp32"), + ) + assert_no_percentile_upshift( + output_z_dtensor_result.to_local(), + output_z_expected_dtensor.to_local(), + output_z_fp32_dtensor.to_local(), + names_input=("output_z_cp_fp32", "output_z_serial_fp64", "output_z_serial_fp32"), + ) + assert_no_percentile_upshift( + input_s_dtensor.grad.to_local(), + d_input_s_expected_dtensor.to_local(), + d_input_s_fp32_dtensor.to_local(), + names_input=("d_input_s_cp_fp32", "d_input_s_serial_fp64", "d_input_s_serial_fp32"), + ) + assert_no_percentile_upshift( + input_z_dtensor.grad.to_local(), + d_input_z_expected_dtensor.to_local(), + d_input_z_fp32_dtensor.to_local(), + names_input=("d_input_z_cp_fp32", "d_input_z_serial_fp64", "d_input_z_serial_fp32"), + ) + for name, grad_param_expected_global in expected_param_grads_global_host_dict.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + output_s_dtensor_result, output_z_dtensor_result = module( + s=input_s_dtensor, + z=input_z_dtensor, + mask=mask_dtensor, + pair_mask=pair_mask_dtensor, + ) + assert_tensors_identical( + input_s_dtensor_copy.to_local(), input_s_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(mask_dtensor_copy.to_local(), mask_dtensor.to_local()) + assert_tensors_identical(pair_mask_dtensor_copy.to_local(), pair_mask_dtensor.to_local()) + torch.testing.assert_close( + output_s_dtensor_result.to_local() * mask_dtensor.to_local().unsqueeze(-1), + output_s_expected_dtensor.to_local() * mask_dtensor.to_local().unsqueeze(-1), + ) + torch.testing.assert_close(output_z_dtensor_result.to_local(), output_z_expected_dtensor.to_local()) + + # Clone upstream gradients so we can verify backward does not modify them (match boltz1x) + d_output_s_expected_dtensor_copy = d_output_s_expected_dtensor.detach().clone() + d_output_z_expected_dtensor_copy = d_output_z_expected_dtensor.detach().clone() + torch.autograd.backward( + [output_s_dtensor_result, output_z_dtensor_result], + [d_output_s_expected_dtensor, d_output_z_expected_dtensor], + ) + # Verify upstream gradients were not modified + assert_tensors_identical(d_output_s_expected_dtensor_copy.to_local(), d_output_s_expected_dtensor.to_local()) + assert_tensors_identical(d_output_z_expected_dtensor_copy.to_local(), d_output_z_expected_dtensor.to_local()) + + torch.testing.assert_close( + input_s_dtensor.grad.to_local() * mask_dtensor.to_local().unsqueeze(-1), + d_input_s_expected_dtensor.to_local() * mask_dtensor.to_local().unsqueeze(-1), + ) + torch.testing.assert_close(input_z_dtensor.grad.to_local(), d_input_z_expected_dtensor.to_local()) + + output_s_global_result_host = output_s_dtensor_result.full_tensor().cpu() + output_z_global_result_host = output_z_dtensor_result.full_tensor().cpu() + d_input_s_global_result_host = input_s_dtensor.grad.full_tensor().cpu() + d_input_z_global_result_host = input_z_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close(output_s_global_result_host, output_s_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(output_z_global_result_host, output_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_s_global_result_host, d_input_s_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_z_global_result_host, d_input_z_expected_global_host.to(dtype=dtype)) + + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in expected_param_grads_global_host_dict: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference") + result_param_grads_dict[name] = param.grad + for name, expected_grad_global_host in expected_param_grads_global_host_dict.items(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + result_grad_global = result_grad.full_tensor() + torch.testing.assert_close(result_grad_global.cpu(), expected_grad_global_host.to(dtype=dtype)) + assert_all_identical(result_grad_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((1, (3, 3)), True, "cuda", "ENV"), torch.float32, False), + (((1, (3, 3)), True, "cpu", "ENV"), torch.float32, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +def test_pairformer_layer_parallel(setup_env, dtype, check_error_hist): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 32 # Number of tokens + token_s = 32 # Token single embedding dimension + token_z = 128 # Token pairwise embedding dimension + num_heads = 16 + pairwise_head_width = 32 + pairwise_num_heads = 4 + min_val_init = -0.08 if dtype == torch.float64 else -5e-4 + max_val_init = -min_val_init + else: + N = size_ring * 2 # Number of tokens + token_s = 8 # Token single embedding dimension + token_z = 12 # Token pairwise embedding dimension + num_heads = 4 + pairwise_head_width = 4 + pairwise_num_heads = 2 + min_val_init = -0.5 + max_val_init = 0.5 + dropout = 0.0 # disable dropout as we have no way to match the random sequences between serial and CP + post_layer_norm = False + + seed = 42 + seed_by_rank(0, seed=seed) + + # Compute reference results with FP64 + input_s_global_fp64 = torch.empty((B, N, token_s), dtype=torch.float64, requires_grad=True, device=device_type) + input_z_global_fp64 = torch.empty((B, N, N, token_z), dtype=torch.float64, requires_grad=True, device=device_type) + mask_global_fp64 = torch.ones((B, N), dtype=torch.float64, requires_grad=False, device=device_type) + mask_global_fp64[0, N // size_ring :] = 0 + pair_mask_global_fp64 = torch.randint(0, 2, (B, N, N), dtype=torch.float64, requires_grad=False, device=device_type) + pair_mask_global_fp64[0, N // size_ring :, :] = 0 + pair_mask_global_fp64[0, :, N // size_ring :] = 0 + + # Create reference serial module + reference_module = SerialPairformerLayer( + token_s=token_s, + token_z=token_z, + num_heads=num_heads, + dropout=dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + post_layer_norm=post_layer_norm, + v2=True, + ) + init_tensors_uniform([input_s_global_fp64, input_z_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + set_dtype_specific_inf_values(reference_module, torch.float64) + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + layer_state_dict_fp64 = reference_module.state_dict() + + output_s_expected_global_fp64, output_z_expected_global_fp64 = reference_module( + s=input_s_global_fp64, + z=input_z_global_fp64, + mask=mask_global_fp64, + pair_mask=pair_mask_global_fp64, + ) + d_output_s_expected_global_fp64 = torch.rand_like(output_s_expected_global_fp64) + d_output_z_expected_global_fp64 = torch.rand_like(output_z_expected_global_fp64) + torch.autograd.backward( + [output_s_expected_global_fp64, output_z_expected_global_fp64], + [d_output_s_expected_global_fp64, d_output_z_expected_global_fp64], + ) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().to(dtype=dtype, device="cpu", copy=True) + for name, param in reference_module.named_parameters() + } + + if check_error_hist: + input_s_global_fp32 = input_s_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + input_z_global_fp32 = input_z_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + mask_global_fp32 = mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + pair_mask_global_fp32 = pair_mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + reference_module_fp32 = SerialPairformerLayer( + token_s=token_s, + token_z=token_z, + num_heads=num_heads, + dropout=dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + post_layer_norm=post_layer_norm, + v2=True, + ) + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + set_dtype_specific_inf_values(reference_module_fp32, torch.float32) + output_s_global_fp32, output_z_global_fp32 = reference_module_fp32( + s=input_s_global_fp32, + z=input_z_global_fp32, + mask=mask_global_fp32, + pair_mask=pair_mask_global_fp32, + ) + d_output_s_expected_global_fp32 = d_output_s_expected_global_fp64.to(dtype=torch.float32) + d_output_z_expected_global_fp32 = d_output_z_expected_global_fp64.to(dtype=torch.float32) + torch.autograd.backward( + [output_s_global_fp32, output_z_global_fp32], + [d_output_s_expected_global_fp32, d_output_z_expected_global_fp32], + ) + output_s_global_fp32_host = output_s_global_fp32.detach().to(device="cpu", copy=True) + output_z_global_fp32_host = output_z_global_fp32.detach().to(device="cpu", copy=True) + d_input_s_global_fp32_host = input_s_global_fp32.grad.detach().to(device="cpu", copy=True) + d_input_z_global_fp32_host = input_z_global_fp32.grad.detach().to(device="cpu", copy=True) + grad_params_fp32_global_host = { + name: param.grad.detach().to(device="cpu", copy=True) + for name, param in reference_module_fp32.named_parameters() + } + else: + output_s_global_fp32_host = None + output_z_global_fp32_host = None + d_input_s_global_fp32_host = None + d_input_z_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_pairformer_layer, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + token_s, + token_z, + num_heads, + dropout, + pairwise_head_width, + pairwise_num_heads, + post_layer_norm, + layer_state_dict_fp64, + input_s_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_z_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + pair_mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + output_s_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_output_s_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_s_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True), + input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True), + grad_params_fp64_expected_global_host, + output_s_global_fp32_host, + output_z_global_fp32_host, + d_input_s_global_fp32_host, + d_input_z_global_fp32_host, + grad_params_fp32_global_host, + ) diff --git a/tests/distributed/model/layers/test_dtensor_pairformer_module.py b/tests/distributed/model/layers/test_dtensor_pairformer_module.py new file mode 100644 index 000000000..f74d3d56f --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_pairformer_module.py @@ -0,0 +1,755 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Any + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.pairformer import PairformerModule as DistributedPairformerModule +from boltz.model.layers.pairformer import PairformerModule as SerialPairformerModule +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + set_dtype_specific_inf_values, + spawn_multiprocessing, +) + + +def parallel_assert_pairformer_module( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params: dict[str, Any], + module_state_dict, + input_s_global_host, + input_z_global_host, + mask_global_host, + pair_mask_global_host, + output_s_expected_global_host, + output_z_expected_global_host, + d_output_s_expected_global_host, + d_output_z_expected_global_host, + d_input_s_expected_global_host, + d_input_z_expected_global_host, + expected_param_grads_global_host_dict, + output_s_global_fp32_host: torch.Tensor | None = None, + output_z_global_fp32_host: torch.Tensor | None = None, + d_input_s_global_fp32_host: torch.Tensor | None = None, + d_input_z_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_s_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_s_expected_global_host.dtype}" + ) + + if ( + (output_s_global_fp32_host is None) != (output_z_global_fp32_host is None) + or (output_s_global_fp32_host is None) != (d_input_s_global_fp32_host is None) + or (output_s_global_fp32_host is None) != (d_input_z_global_fp32_host is None) + or (output_s_global_fp32_host is None) != (grad_params_fp32_global_host is None) + ): + raise ValueError( + "output_s_global_fp32_host, output_z_global_fp32_host, d_input_s_global_fp32_host, " + "d_input_z_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_s_global_fp32_host is not None + + # Create serial reference module + module_serial = SerialPairformerModule(**pairformer_params) + module_serial.load_state_dict(module_state_dict) + set_dtype_specific_inf_values(module_serial, dtype) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + + module = DistributedPairformerModule(module_serial, manager) + module.train() + + placements_s_mask = (Shard(0), Shard(1), Replicate()) + placements_z_pair_mask = (Shard(0), Shard(1), Shard(2)) + + input_s_dtensor = distribute_tensor( + input_s_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ).requires_grad_(True) + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ).requires_grad_(True) + mask_dtensor = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ) + pair_mask_dtensor = distribute_tensor( + pair_mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + + d_output_s_expected_dtensor = distribute_tensor( + d_output_s_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ) + d_output_z_expected_dtensor = distribute_tensor( + d_output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + output_s_expected_dtensor = distribute_tensor( + output_s_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + output_z_expected_dtensor = distribute_tensor( + output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_s_expected_dtensor = distribute_tensor( + d_input_s_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + d_input_z_expected_dtensor = distribute_tensor( + d_input_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + + input_s_dtensor_copy = input_s_dtensor.detach().clone().requires_grad_(True) + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone() + pair_mask_dtensor_copy = pair_mask_dtensor.detach().clone() + + if check_error_hist: + output_s_dtensor_result, output_z_dtensor_result = module( + s=input_s_dtensor, + z=input_z_dtensor, + mask=mask_dtensor, + pair_mask=pair_mask_dtensor, + ) + torch.autograd.backward( + [output_s_dtensor_result, output_z_dtensor_result], + [d_output_s_expected_dtensor, d_output_z_expected_dtensor], + ) + output_s_fp32_dtensor = distribute_tensor( + output_s_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + output_z_fp32_dtensor = distribute_tensor( + output_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_s_fp32_dtensor = distribute_tensor( + d_input_s_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + src_data_rank=None, + ) + d_input_z_fp32_dtensor = distribute_tensor( + d_input_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + assert_no_percentile_upshift( + output_s_dtensor_result.to_local(), + output_s_expected_dtensor.to_local(), + output_s_fp32_dtensor.to_local(), + names_input=("output_s_cp_fp32", "output_s_serial_fp64", "output_s_serial_fp32"), + ) + assert_no_percentile_upshift( + output_z_dtensor_result.to_local(), + output_z_expected_dtensor.to_local(), + output_z_fp32_dtensor.to_local(), + names_input=("output_z_cp_fp32", "output_z_serial_fp64", "output_z_serial_fp32"), + ) + assert_no_percentile_upshift( + input_s_dtensor.grad.to_local(), + d_input_s_expected_dtensor.to_local(), + d_input_s_fp32_dtensor.to_local(), + names_input=("d_input_s_cp_fp32", "d_input_s_serial_fp64", "d_input_s_serial_fp32"), + ) + assert_no_percentile_upshift( + input_z_dtensor.grad.to_local(), + d_input_z_expected_dtensor.to_local(), + d_input_z_fp32_dtensor.to_local(), + names_input=("d_input_z_cp_fp32", "d_input_z_serial_fp64", "d_input_z_serial_fp32"), + ) + for name, grad_param_expected_global in expected_param_grads_global_host_dict.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + output_s_dtensor_result, output_z_dtensor_result = module( + s=input_s_dtensor, + z=input_z_dtensor, + mask=mask_dtensor, + pair_mask=pair_mask_dtensor, + ) + assert_tensors_identical( + input_s_dtensor_copy.to_local(), input_s_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(mask_dtensor_copy.to_local(), mask_dtensor.to_local()) + assert_tensors_identical(pair_mask_dtensor_copy.to_local(), pair_mask_dtensor.to_local()) + torch.testing.assert_close( + output_s_dtensor_result.to_local() * mask_dtensor.to_local().unsqueeze(-1), + output_s_expected_dtensor.to_local() * mask_dtensor.to_local().unsqueeze(-1), + ) + torch.testing.assert_close(output_z_dtensor_result.to_local(), output_z_expected_dtensor.to_local()) + + # Backward pass + d_output_s_expected_dtensor_copy = d_output_s_expected_dtensor.detach().clone() + d_output_z_expected_dtensor_copy = d_output_z_expected_dtensor.detach().clone() + torch.autograd.backward( + [output_s_dtensor_result, output_z_dtensor_result], + [d_output_s_expected_dtensor, d_output_z_expected_dtensor], + ) + + # Verify upstream gradients weren't modified + assert_tensors_identical(d_output_s_expected_dtensor_copy.to_local(), d_output_s_expected_dtensor.to_local()) + assert_tensors_identical(d_output_z_expected_dtensor_copy.to_local(), d_output_z_expected_dtensor.to_local()) + + torch.testing.assert_close( + input_s_dtensor.grad.to_local() * mask_dtensor.to_local().unsqueeze(-1), + d_input_s_expected_dtensor.to_local() * mask_dtensor.to_local().unsqueeze(-1), + ) + torch.testing.assert_close(input_z_dtensor.grad.to_local(), d_input_z_expected_dtensor.to_local()) + + output_s_global_result_host = output_s_dtensor_result.full_tensor().cpu() + output_z_global_result_host = output_z_dtensor_result.full_tensor().cpu() + d_input_s_global_result_host = input_s_dtensor.grad.full_tensor().cpu() + d_input_z_global_result_host = input_z_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close(output_s_global_result_host, output_s_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(output_z_global_result_host, output_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_s_global_result_host, d_input_s_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_z_global_result_host, d_input_z_expected_global_host.to(dtype=dtype)) + + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in expected_param_grads_global_host_dict: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference") + result_param_grads_dict[name] = param.grad + for name, expected_grad_global_host in expected_param_grads_global_host_dict.items(): + assert name in result_param_grads_dict + result_grad = result_param_grads_dict[name] + result_grad_global = result_grad.full_tensor() + torch.testing.assert_close(result_grad_global.cpu(), expected_grad_global_host.to(dtype=dtype)) + assert_all_identical(result_grad_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist, activation_checkpointing", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, True, False), + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, False, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True, False), + (((2, (2, 2)), True, "cpu", "ENV"), torch.float32, False, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, False, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False, True), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}, activation_checkpointing:{x[3]}" + for x in params_test + ], +) +def test_pairformer_module_parallel(setup_env, dtype, check_error_hist, activation_checkpointing): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 64 # Number of tokens + token_s = 32 + token_z = 128 + num_blocks = 4 + num_heads = 16 + pairwise_head_width = 32 + pairwise_num_heads = 4 + min_val_init = -0.03 + max_val_init = 0.03 + else: + N = size_ring * 2 # Number of tokens + token_s = 8 + token_z = 12 + num_blocks = 2 + num_heads = 4 + pairwise_head_width = 4 + pairwise_num_heads = 2 + min_val_init = -0.5 + max_val_init = 0.5 + dropout = 0.0 + post_layer_norm = False + + pairformer_params = { + "token_s": token_s, + "token_z": token_z, + "num_blocks": num_blocks, + "num_heads": num_heads, + "dropout": dropout, + "pairwise_head_width": pairwise_head_width, + "pairwise_num_heads": pairwise_num_heads, + "post_layer_norm": post_layer_norm, + "activation_checkpointing": activation_checkpointing, + "v2": True, + } + + seed = 42 + seed_by_rank(0, seed=seed) + + # Compute reference results with FP64 + input_s_global_fp64 = torch.empty((B, N, token_s), dtype=torch.float64, requires_grad=True, device=device_type) + input_z_global_fp64 = torch.empty((B, N, N, token_z), dtype=torch.float64, requires_grad=True, device=device_type) + mask_global_fp64 = torch.ones((B, N), dtype=torch.float64, requires_grad=False, device=device_type) + mask_global_fp64[0, N // size_ring :] = 0 + pair_mask_global_fp64 = torch.randint(0, 2, (B, N, N), dtype=torch.float64, requires_grad=False, device=device_type) + pair_mask_global_fp64[0, N // size_ring :, :] = 0 + pair_mask_global_fp64[0, :, N // size_ring :] = 0 + + # Create reference serial module + reference_module = SerialPairformerModule(**pairformer_params) + # Initialize parameters to ensure reproducible behavior + init_tensors_uniform([input_s_global_fp64, input_z_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + set_dtype_specific_inf_values(reference_module, torch.float64) + + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + module_state_dict_fp64 = reference_module.state_dict() + + # Run forward pass + output_s_expected_global_fp64, output_z_expected_global_fp64 = reference_module( + s=input_s_global_fp64, + z=input_z_global_fp64, + mask=mask_global_fp64, + pair_mask=pair_mask_global_fp64, + ) + d_output_s_expected_global_fp64 = torch.rand_like(output_s_expected_global_fp64) + d_output_z_expected_global_fp64 = torch.rand_like(output_z_expected_global_fp64) + torch.autograd.backward( + [output_s_expected_global_fp64, output_z_expected_global_fp64], + [d_output_s_expected_global_fp64, d_output_z_expected_global_fp64], + ) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().to(dtype=dtype, device="cpu", copy=True) + for name, param in reference_module.named_parameters() + } + + del reference_module + if device_type == "cuda": + torch.cuda.empty_cache() + + if check_error_hist: + input_s_global_fp32 = input_s_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + input_z_global_fp32 = input_z_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + mask_global_fp32 = mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + pair_mask_global_fp32 = pair_mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + reference_module_fp32 = SerialPairformerModule(**pairformer_params) + reference_module_fp32.load_state_dict(module_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + set_dtype_specific_inf_values(reference_module_fp32, torch.float32) + output_s_global_fp32, output_z_global_fp32 = reference_module_fp32( + s=input_s_global_fp32, + z=input_z_global_fp32, + mask=mask_global_fp32, + pair_mask=pair_mask_global_fp32, + ) + d_output_s_expected_global_fp32 = d_output_s_expected_global_fp64.to(dtype=torch.float32) + d_output_z_expected_global_fp32 = d_output_z_expected_global_fp64.to(dtype=torch.float32) + torch.autograd.backward( + [output_s_global_fp32, output_z_global_fp32], + [d_output_s_expected_global_fp32, d_output_z_expected_global_fp32], + ) + output_s_global_fp32_host = output_s_global_fp32.detach().to(device="cpu", copy=True) + output_z_global_fp32_host = output_z_global_fp32.detach().to(device="cpu", copy=True) + d_input_s_global_fp32_host = input_s_global_fp32.grad.detach().to(device="cpu", copy=True) + d_input_z_global_fp32_host = input_z_global_fp32.grad.detach().to(device="cpu", copy=True) + grad_params_fp32_global_host = { + name: param.grad.detach().to(device="cpu", copy=True) + for name, param in reference_module_fp32.named_parameters() + } + else: + output_s_global_fp32_host = None + output_z_global_fp32_host = None + d_input_s_global_fp32_host = None + d_input_z_global_fp32_host = None + grad_params_fp32_global_host = None + + spawn_multiprocessing( + parallel_assert_pairformer_module, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params, + module_state_dict_fp64, + input_s_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_z_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + pair_mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + output_s_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_output_s_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_s_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True), + input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True), + grad_params_fp64_expected_global_host, + output_s_global_fp32_host, + output_z_global_fp32_host, + d_input_s_global_fp32_host, + d_input_z_global_fp32_host, + grad_params_fp32_global_host, + ) + + +def parallel_assert_pairformer_module_activation_checkpointing( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params: dict[str, Any], + min_val_init: float, + max_val_init: float, + input_s_global: torch.Tensor, + input_z_global: torch.Tensor, + mask_global: torch.Tensor, + pair_mask_global: torch.Tensor, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0, seed=42) + + # Create serial reference module - Activation checkpointing is enabled here + pairformer_params["activation_checkpointing"] = True + module_serial = SerialPairformerModule(**pairformer_params) + init_module_params_uniform(module_serial, low=min_val_init, high=max_val_init) + set_dtype_specific_inf_values(module_serial, dtype) + + # Save per-rank state dict which can be re-used for model with activation checkpointing enabled + module_state_dict_ref = module_serial.state_dict() + module_serial = module_serial.to(dtype=dtype, device=manager.device) + + # Create distributed module + module = DistributedPairformerModule(module_serial, manager) + module.train() + + # Input tensors have sharding patterns: + # s: (B, N, token_s) - sharded on dims 0, 1 + # z: (B, N, N, token_z) - sharded on dims 0, 1, 2 + # mask: (B, N) - sharded on dims 0, 1 + # pair_mask: (B, N, N) - sharded on dims 0, 1, 2 + placements_s_mask = (Shard(0), Shard(1), Replicate()) + placements_z_pair_mask = (Shard(0), Shard(1), Shard(2)) + + input_s_dtensor = distribute_tensor( + input_s_global.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ).requires_grad_(True) + + input_z_dtensor = distribute_tensor( + input_z_global.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ).requires_grad_(True) + + mask_dtensor = distribute_tensor( + mask_global.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_s_mask, + ) + + pair_mask_dtensor = distribute_tensor( + pair_mask_global.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + + # Create copies to verify inputs aren't modified + input_s_dtensor_copy = input_s_dtensor.detach().clone().requires_grad_(True) + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone() + pair_mask_dtensor_copy = pair_mask_dtensor.detach().clone() + + # Forward pass + output_s_dtensor_result, output_z_dtensor_result = module( + s=input_s_dtensor, + z=input_z_dtensor, + mask=mask_dtensor, + pair_mask=pair_mask_dtensor, + ) + + # Verify inputs weren't modified + assert_tensors_identical( + input_s_dtensor_copy.to_local(), input_s_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(mask_dtensor_copy.to_local(), mask_dtensor.to_local()) + assert_tensors_identical(pair_mask_dtensor_copy.to_local(), pair_mask_dtensor.to_local()) + + # Backward pass - create output grad tensors with same shape and placements as outputs + d_output_s_dtensor = torch.distributed.tensor.rand( + output_s_dtensor_result.shape, + requires_grad=False, + dtype=dtype, + device_mesh=manager.device_mesh_subgroups, + placements=output_s_dtensor_result.placements, + ) + d_output_z_dtensor = torch.distributed.tensor.rand( + output_z_dtensor_result.shape, + requires_grad=False, + dtype=dtype, + device_mesh=manager.device_mesh_subgroups, + placements=output_z_dtensor_result.placements, + ) + + d_output_s_dtensor_copy = d_output_s_dtensor.detach().clone() + d_output_z_dtensor_copy = d_output_z_dtensor.detach().clone() + + torch.autograd.backward( + [output_s_dtensor_result, output_z_dtensor_result], [d_output_s_dtensor, d_output_z_dtensor] + ) + + # Verify upstream gradients weren't modified + assert_tensors_identical(d_output_s_dtensor_copy.to_local(), d_output_s_dtensor.to_local()) + assert_tensors_identical(d_output_z_dtensor_copy.to_local(), d_output_z_dtensor.to_local()) + + # Reset seed + seed_by_rank(0, seed=42) + + # Create new model with activation checkpointing enabled + pairformer_params["activation_checkpointing"] = True + module_serial_act_ckpt = SerialPairformerModule(**pairformer_params) + module_serial_act_ckpt.load_state_dict(module_state_dict_ref) + set_dtype_specific_inf_values(module_serial_act_ckpt, dtype) + + module_serial_act_ckpt = module_serial_act_ckpt.to(dtype=dtype, device=manager.device) + module_act_ckpt = DistributedPairformerModule(module_serial_act_ckpt, manager) + module_act_ckpt.train() + + # Forward pass + output_s_dtensor_result_act_ckpt, output_z_dtensor_result_act_ckpt = module_act_ckpt( + s=input_s_dtensor_copy, + z=input_z_dtensor_copy, + mask=mask_dtensor_copy, + pair_mask=pair_mask_dtensor_copy, + ) + + # Verify outputs are the same after activation checkpoint fwd and no actv ckpt fwd + assert_tensors_identical( + output_s_dtensor_result_act_ckpt.to_local(), + output_s_dtensor_result.to_local(), + check_grad=False, + check_grad_fn=False, + ) + assert_tensors_identical( + output_z_dtensor_result_act_ckpt.to_local(), + output_z_dtensor_result.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + # Backward pass + torch.autograd.backward( + [output_s_dtensor_result_act_ckpt, output_z_dtensor_result_act_ckpt], [d_output_s_dtensor, d_output_z_dtensor] + ) + + # Verify that input gradients are identical + assert_tensors_identical(input_s_dtensor.grad.to_local(), input_s_dtensor_copy.grad.to_local()) + assert_tensors_identical(input_z_dtensor.grad.to_local(), input_z_dtensor_copy.grad.to_local()) + + # Compare parameter gradients + reference_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + reference_param_grads_dict[name] = param.grad + + for name, param_act_ckpt in module_act_ckpt.named_parameters(): + assert name in reference_param_grads_dict, f"Parameter {name}'s gradient is not found in reference gradients" + reference_grad = reference_param_grads_dict[name] + assert_tensors_identical(reference_grad.to_local(), param_act_ckpt.grad.to_local()) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype", + ( + params_test := [ + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}" + for x in params_test + ], +) +def test_pairformer_module_parallel_activation_checkpointing(setup_env, dtype): + """ + Test the Pairformer Module with activation checkpointing enabled vs CP without actv ckpt, results should be identical. Test on small model and input size. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + # Small model params - same as main test's small case; dropout non-zero to exercise activation checkpointing + token_s = 8 + token_z = 12 + num_blocks = 2 + num_heads = 4 + pairwise_head_width = 4 + pairwise_num_heads = 2 + dropout = 0.5 + post_layer_norm = False + + pairformer_params = { + "token_s": token_s, + "token_z": token_z, + "num_blocks": num_blocks, + "num_heads": num_heads, + "dropout": dropout, + "pairwise_head_width": pairwise_head_width, + "pairwise_num_heads": pairwise_num_heads, + "post_layer_norm": post_layer_norm, + "activation_checkpointing": True, + "v2": True, + } + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 2 # Number of tokens + min_val_init = -1.0 + max_val_init = 1.0 + + input_s_global = torch.empty((B, N, token_s), dtype=dtype, requires_grad=True, device="cpu") + input_z_global = torch.empty((B, N, N, token_z), dtype=dtype, requires_grad=True, device="cpu") + + mask_global = torch.ones((B, N), dtype=dtype, requires_grad=False, device="cpu") + mask_global[0, N // size_ring :] = 0 + + pair_mask_global = torch.randint(0, 2, (B, N, N), dtype=dtype, requires_grad=False, device="cpu") + pair_mask_global[0, N // size_ring :, :] = 0 + pair_mask_global[0, :, N // size_ring :] = 0 + + init_tensors_uniform([input_s_global, input_z_global], low=min_val_init, high=max_val_init) + + spawn_multiprocessing( + parallel_assert_pairformer_module_activation_checkpointing, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params, + min_val_init, + max_val_init, + input_s_global.detach(), + input_z_global.detach(), + mask_global.detach(), + pair_mask_global.detach(), + ) diff --git a/tests/distributed/model/layers/test_dtensor_pairformer_no_seq_layer.py b/tests/distributed/model/layers/test_dtensor_pairformer_no_seq_layer.py new file mode 100644 index 000000000..e3256c0ff --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_pairformer_no_seq_layer.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.pairformer import ( + PairformerNoSeqLayer as DistributedPairformerNoSeqLayer, +) +from boltz.model.layers.pairformer import ( + PairformerNoSeqLayer as SerialPairformerNoSeqLayer, +) +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + set_dtype_specific_inf_values, + spawn_multiprocessing, +) + + +def parallel_assert_pairformer_noseq_layer( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + token_z, + dropout, + pairwise_head_width, + pairwise_num_heads, + post_layer_norm, + layer_state_dict, + input_z_global_host, + pair_mask_global_host, + output_z_expected_global_host, + d_output_z_expected_global_host, + d_input_z_expected_global_host, + expected_param_grads_global_host_dict, + output_z_global_fp32_host: torch.Tensor | None = None, + d_input_z_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_z_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_z_expected_global_host.dtype}" + ) + + if (output_z_global_fp32_host is None) != (d_input_z_global_fp32_host is None) or ( + output_z_global_fp32_host is None + ) != (grad_params_fp32_global_host is None): + raise ValueError( + "output_z_global_fp32_host, d_input_z_global_fp32_host, and grad_params_fp32_global_host " + "must be either all None or all not None" + ) + + check_error_hist = output_z_global_fp32_host is not None + + # Create serial reference module + module_serial = SerialPairformerNoSeqLayer( + token_z=token_z, + dropout=dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + post_layer_norm=post_layer_norm, + ) + module_serial.load_state_dict(layer_state_dict) + set_dtype_specific_inf_values(module_serial, dtype) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + + # Create distributed module + module = DistributedPairformerNoSeqLayer(module_serial, manager) + module.train() + + placements_z_pair_mask = (Shard(0), Shard(1), Shard(2)) + + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ).requires_grad_(True) + pair_mask_dtensor = distribute_tensor( + pair_mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + + d_output_z_expected_dtensor = distribute_tensor( + d_output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + output_z_expected_dtensor = distribute_tensor( + output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_z_expected_dtensor = distribute_tensor( + d_input_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + pair_mask_dtensor_copy = pair_mask_dtensor.detach().clone() + + if check_error_hist: + output_z_dtensor_result = module(z=input_z_dtensor, pair_mask=pair_mask_dtensor) + output_z_dtensor_result.backward(d_output_z_expected_dtensor) + output_z_fp32_dtensor = distribute_tensor( + output_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_z_fp32_dtensor = distribute_tensor( + d_input_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + assert_no_percentile_upshift( + output_z_dtensor_result.to_local(), + output_z_expected_dtensor.to_local(), + output_z_fp32_dtensor.to_local(), + names_input=("output_z_cp_fp32", "output_z_serial_fp64", "output_z_serial_fp32"), + ) + assert_no_percentile_upshift( + input_z_dtensor.grad.to_local(), + d_input_z_expected_dtensor.to_local(), + d_input_z_fp32_dtensor.to_local(), + names_input=("d_input_z_cp_fp32", "d_input_z_serial_fp64", "d_input_z_serial_fp32"), + ) + for name, grad_param_expected_global in expected_param_grads_global_host_dict.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + output_z_dtensor_result = module(z=input_z_dtensor, pair_mask=pair_mask_dtensor) + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(pair_mask_dtensor_copy.to_local(), pair_mask_dtensor.to_local()) + torch.testing.assert_close(output_z_dtensor_result.to_local(), output_z_expected_dtensor.to_local()) + + # Clone upstream gradients so we can verify backward does not modify them (match layer test) + d_output_z_expected_dtensor_copy = d_output_z_expected_dtensor.detach().clone() + torch.autograd.backward([output_z_dtensor_result], [d_output_z_expected_dtensor]) + # Verify upstream gradients were not modified + assert_tensors_identical(d_output_z_expected_dtensor_copy.to_local(), d_output_z_expected_dtensor.to_local()) + + torch.testing.assert_close(input_z_dtensor.grad.to_local(), d_input_z_expected_dtensor.to_local()) + + output_z_global_result_host = output_z_dtensor_result.full_tensor().cpu() + d_input_z_global_result_host = input_z_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close(output_z_global_result_host, output_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_z_global_result_host, d_input_z_expected_global_host.to(dtype=dtype)) + + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in expected_param_grads_global_host_dict: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference") + result_param_grads_dict[name] = param.grad + for name, expected_grad_global_host in expected_param_grads_global_host_dict.items(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + result_grad_global = result_grad.full_tensor() + torch.testing.assert_close(result_grad_global.cpu(), expected_grad_global_host.to(dtype=dtype)) + assert_all_identical(result_grad_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((1, (3, 3)), True, "cuda", "ENV"), torch.float32, False), + (((1, (3, 3)), True, "cpu", "ENV"), torch.float32, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +def test_pairformer_noseq_layer_parallel(setup_env, dtype, check_error_hist): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist and grid_group_sizes["dp"] > 1: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 128 # Number of tokens (no_seq: pair dimension only) + token_z = 128 # Token pairwise embedding dimension + pairwise_head_width = 32 + pairwise_num_heads = 4 + min_val_init = -0.08 if dtype == torch.float64 else -5e-4 + max_val_init = -min_val_init + else: + N = size_ring * 2 + token_z = 12 # Token pairwise embedding dimension + pairwise_head_width = 4 + pairwise_num_heads = 2 + min_val_init = -0.5 + max_val_init = 0.5 + dropout = 0.0 # disable dropout as we have no way to match the random sequences between serial and CP + post_layer_norm = False + + seed = 42 + seed_by_rank(0, seed=seed) + + # Compute reference results with FP64 + input_z_global_fp64 = torch.empty((B, N, N, token_z), dtype=torch.float64, requires_grad=True, device=device_type) + pair_mask_global_fp64 = torch.randint(0, 2, (B, N, N), dtype=torch.float64, requires_grad=False, device=device_type) + pair_mask_global_fp64[0, N // size_ring :, :] = 0 + pair_mask_global_fp64[0, :, N // size_ring :] = 0 + + # Create reference serial module + reference_module = SerialPairformerNoSeqLayer( + token_z=token_z, + dropout=dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + post_layer_norm=post_layer_norm, + ) + init_tensors_uniform([input_z_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + set_dtype_specific_inf_values(reference_module, torch.float64) + layer_state_dict_fp64 = reference_module.state_dict() + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + + output_z_expected_global_fp64 = reference_module(input_z_global_fp64, pair_mask_global_fp64) + d_output_z_expected_global_fp64 = torch.rand_like(output_z_expected_global_fp64) + output_z_expected_global_fp64.backward(d_output_z_expected_global_fp64) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().to(dtype=dtype, device="cpu", copy=True) + for name, param in reference_module.named_parameters() + } + + if check_error_hist: + input_z_global_fp32 = input_z_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + pair_mask_global_fp32 = pair_mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + reference_module_fp32 = SerialPairformerNoSeqLayer( + token_z=token_z, + dropout=dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + post_layer_norm=post_layer_norm, + ) + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + set_dtype_specific_inf_values(reference_module_fp32, torch.float32) + output_z_global_fp32 = reference_module_fp32(input_z_global_fp32, pair_mask_global_fp32) + d_output_z_fp32 = d_output_z_expected_global_fp64.to(dtype=torch.float32) + output_z_global_fp32.backward(d_output_z_fp32) + output_z_global_fp32_host = output_z_global_fp32.detach().to(device="cpu", copy=True) + d_input_z_global_fp32_host = input_z_global_fp32.grad.detach().to(device="cpu", copy=True) + grad_params_fp32_global_host = { + name: param.grad.detach().to(device="cpu", copy=True) + for name, param in reference_module_fp32.named_parameters() + } + else: + output_z_global_fp32_host = None + d_input_z_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_pairformer_noseq_layer, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + token_z, + dropout, + pairwise_head_width, + pairwise_num_heads, + post_layer_norm, + layer_state_dict_fp64, + input_z_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + pair_mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True), + grad_params_fp64_expected_global_host, + output_z_global_fp32_host, + d_input_z_global_fp32_host, + grad_params_fp32_global_host, + ) diff --git a/tests/distributed/model/layers/test_dtensor_pairformer_no_seq_module.py b/tests/distributed/model/layers/test_dtensor_pairformer_no_seq_module.py new file mode 100644 index 000000000..7ebd5d3f6 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_pairformer_no_seq_module.py @@ -0,0 +1,555 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Any + +import pytest +import torch +from torch.distributed.tensor import Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.pairformer import ( + PairformerNoSeqModule as DistributedPairformerNoSeqModule, +) +from boltz.model.layers.pairformer import ( + PairformerNoSeqModule as SerialPairformerNoSeqModule, +) +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + set_dtype_specific_inf_values, + spawn_multiprocessing, +) + + +def parallel_assert_pairformer_noseq_module( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params: dict[str, Any], + module_state_dict, + input_z_global_host, + pair_mask_global_host, + output_z_expected_global_host, + d_output_z_expected_global_host, + d_input_z_expected_global_host, + expected_param_grads_global_host_dict, + output_z_global_fp32_host: torch.Tensor | None = None, + d_input_z_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_z_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_z_expected_global_host.dtype}" + ) + + if (output_z_global_fp32_host is None) != (d_input_z_global_fp32_host is None) or ( + output_z_global_fp32_host is None + ) != (grad_params_fp32_global_host is None): + raise ValueError( + "output_z_global_fp32_host, d_input_z_global_fp32_host, and grad_params_fp32_global_host " + "must be either all None or all not None" + ) + + check_error_hist = output_z_global_fp32_host is not None + + # Create serial reference module + module_serial = SerialPairformerNoSeqModule(**pairformer_params) + module_serial.load_state_dict(module_state_dict) + set_dtype_specific_inf_values(module_serial, dtype) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + + # Create distributed module + module = DistributedPairformerNoSeqModule(module_serial, manager) + module.train() + + placements_z_pair_mask = (Shard(0), Shard(1), Shard(2)) + + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ).requires_grad_(True) + pair_mask_dtensor = distribute_tensor( + pair_mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + + d_output_z_expected_dtensor = distribute_tensor( + d_output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + output_z_expected_dtensor = distribute_tensor( + output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_z_expected_dtensor = distribute_tensor( + d_input_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + pair_mask_dtensor_copy = pair_mask_dtensor.detach().clone() + + if check_error_hist: + output_z_dtensor_result = module(z=input_z_dtensor, pair_mask=pair_mask_dtensor) + output_z_dtensor_result.backward(d_output_z_expected_dtensor) + output_z_fp32_dtensor = distribute_tensor( + output_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + d_input_z_fp32_dtensor = distribute_tensor( + d_input_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + src_data_rank=None, + ) + assert_no_percentile_upshift( + output_z_dtensor_result.to_local(), + output_z_expected_dtensor.to_local(), + output_z_fp32_dtensor.to_local(), + names_input=("output_z_cp_fp32", "output_z_serial_fp64", "output_z_serial_fp32"), + ) + assert_no_percentile_upshift( + input_z_dtensor.grad.to_local(), + d_input_z_expected_dtensor.to_local(), + d_input_z_fp32_dtensor.to_local(), + names_input=("d_input_z_cp_fp32", "d_input_z_serial_fp64", "d_input_z_serial_fp32"), + ) + for name, grad_param_expected_global in expected_param_grads_global_host_dict.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + output_z_dtensor_result = module(z=input_z_dtensor, pair_mask=pair_mask_dtensor) + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(pair_mask_dtensor_copy.to_local(), pair_mask_dtensor.to_local()) + torch.testing.assert_close(output_z_dtensor_result.to_local(), output_z_expected_dtensor.to_local()) + + # Clone upstream gradients so we can verify backward does not modify them (match module test) + d_output_z_expected_dtensor_copy = d_output_z_expected_dtensor.detach().clone() + torch.autograd.backward([output_z_dtensor_result], [d_output_z_expected_dtensor]) + # Verify upstream gradients were not modified + assert_tensors_identical(d_output_z_expected_dtensor_copy.to_local(), d_output_z_expected_dtensor.to_local()) + + torch.testing.assert_close(input_z_dtensor.grad.to_local(), d_input_z_expected_dtensor.to_local()) + + output_z_global_result_host = output_z_dtensor_result.full_tensor().cpu() + d_input_z_global_result_host = input_z_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close(output_z_global_result_host, output_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_z_global_result_host, d_input_z_expected_global_host.to(dtype=dtype)) + + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in expected_param_grads_global_host_dict: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference") + result_param_grads_dict[name] = param.grad + for name, expected_grad_global_host in expected_param_grads_global_host_dict.items(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + result_grad_global = result_grad.full_tensor() + torch.testing.assert_close(result_grad_global.cpu(), expected_grad_global_host.to(dtype=dtype)) + assert_all_identical(result_grad_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist, activation_checkpointing", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, True, False), + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, False, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True, False), + (((2, (2, 2)), True, "cpu", "ENV"), torch.float32, False, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, False, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False, True), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}, activation_checkpointing:{x[3]}" + for x in params_test + ], +) +def test_pairformer_noseq_module_parallel(setup_env, dtype, check_error_hist, activation_checkpointing): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist and grid_group_sizes["dp"] > 1: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 64 # Number of tokens (no_seq: pair dimension only) + token_z = 128 + num_blocks = 2 + pairwise_head_width = 32 + pairwise_num_heads = 4 + min_val_init = -0.05 + max_val_init = 0.05 + else: + N = size_ring * 2 + token_z = 12 + num_blocks = 2 + pairwise_head_width = 4 + pairwise_num_heads = 2 + min_val_init = -0.5 + max_val_init = 0.5 + dropout = 0.0 + post_layer_norm = False + + pairformer_params = { + "token_z": token_z, + "num_blocks": num_blocks, + "dropout": dropout, + "pairwise_head_width": pairwise_head_width, + "pairwise_num_heads": pairwise_num_heads, + "post_layer_norm": post_layer_norm, + "activation_checkpointing": activation_checkpointing, + } + + seed = 42 + seed_by_rank(0, seed=seed) + + # Compute reference results with FP64 + input_z_global_fp64 = torch.empty((B, N, N, token_z), dtype=torch.float64, requires_grad=True, device=device_type) + pair_mask_global_fp64 = torch.randint(0, 2, (B, N, N), dtype=torch.float64, requires_grad=False, device=device_type) + pair_mask_global_fp64[0, N // size_ring :, :] = 0 + pair_mask_global_fp64[0, :, N // size_ring :] = 0 + + # Create reference serial module + reference_module = SerialPairformerNoSeqModule(**pairformer_params) + # Initialize parameters to ensure reproducible behavior + init_tensors_uniform([input_z_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + set_dtype_specific_inf_values(reference_module, torch.float64) + module_state_dict_fp64 = reference_module.state_dict() + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + + # Run forward pass + output_z_expected_global_fp64 = reference_module(input_z_global_fp64, pair_mask_global_fp64) + d_output_z_expected_global_fp64 = torch.rand_like(output_z_expected_global_fp64) + output_z_expected_global_fp64.backward(d_output_z_expected_global_fp64) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().to(dtype=dtype, device="cpu", copy=True) + for name, param in reference_module.named_parameters() + } + + del reference_module + if device_type == "cuda": + torch.cuda.empty_cache() + + if check_error_hist: + input_z_global_fp32 = input_z_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + pair_mask_global_fp32 = pair_mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + reference_module_fp32 = SerialPairformerNoSeqModule(**pairformer_params) + reference_module_fp32.load_state_dict(module_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + set_dtype_specific_inf_values(reference_module_fp32, torch.float32) + output_z_global_fp32 = reference_module_fp32(input_z_global_fp32, pair_mask_global_fp32) + d_output_z_fp32 = d_output_z_expected_global_fp64.to(dtype=torch.float32) + output_z_global_fp32.backward(d_output_z_fp32) + output_z_global_fp32_host = output_z_global_fp32.detach().to(device="cpu", copy=True) + d_input_z_global_fp32_host = input_z_global_fp32.grad.detach().to(device="cpu", copy=True) + grad_params_fp32_global_host = { + name: param.grad.detach().to(device="cpu", copy=True) + for name, param in reference_module_fp32.named_parameters() + } + else: + output_z_global_fp32_host = None + d_input_z_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_pairformer_noseq_module, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params, + module_state_dict_fp64, + input_z_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + pair_mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True), + grad_params_fp64_expected_global_host, + output_z_global_fp32_host, + d_input_z_global_fp32_host, + grad_params_fp32_global_host, + ) + + +def parallel_assert_pairformer_noseq_module_activation_checkpointing( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params: dict[str, Any], + min_val_init: float, + max_val_init: float, + input_z_global: torch.Tensor, + pair_mask_global: torch.Tensor, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0, seed=42) + + # Create serial reference module - Activation checkpointing is enabled here + pairformer_params["activation_checkpointing"] = True + module_serial = SerialPairformerNoSeqModule(**pairformer_params) + init_module_params_uniform(module_serial, low=min_val_init, high=max_val_init) + set_dtype_specific_inf_values(module_serial, dtype) + + # Save per-rank state dict which can be re-used for model with activation checkpointing enabled + module_state_dict_ref = module_serial.state_dict() + module_serial = module_serial.to(dtype=dtype, device=manager.device) + + # Create distributed module + module = DistributedPairformerNoSeqModule(module_serial, manager) + module.train() + + placements_z_pair_mask = (Shard(0), Shard(1), Shard(2)) + + input_z_dtensor = distribute_tensor( + input_z_global.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ).requires_grad_(True) + + pair_mask_dtensor = distribute_tensor( + pair_mask_global.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_pair_mask, + ) + + # Create copies to verify inputs aren't modified + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + pair_mask_dtensor_copy = pair_mask_dtensor.detach().clone() + + # Forward pass + output_z_dtensor_result = module(z=input_z_dtensor, pair_mask=pair_mask_dtensor) + + # Verify inputs weren't modified + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(pair_mask_dtensor_copy.to_local(), pair_mask_dtensor.to_local()) + + # Backward pass - create output grad tensor with same shape and placements as output + d_output_z_dtensor = torch.distributed.tensor.rand( + output_z_dtensor_result.shape, + requires_grad=False, + dtype=dtype, + device_mesh=manager.device_mesh_subgroups, + placements=output_z_dtensor_result.placements, + ) + + d_output_z_dtensor_copy = d_output_z_dtensor.detach().clone() + + torch.autograd.backward([output_z_dtensor_result], [d_output_z_dtensor]) + + # Verify upstream gradients weren't modified + assert_tensors_identical(d_output_z_dtensor_copy.to_local(), d_output_z_dtensor.to_local()) + + # Reset seed + seed_by_rank(0, seed=42) + + # Create new model with activation checkpointing enabled + pairformer_params["activation_checkpointing"] = True + module_serial_act_ckpt = SerialPairformerNoSeqModule(**pairformer_params) + module_serial_act_ckpt.load_state_dict(module_state_dict_ref) + set_dtype_specific_inf_values(module_serial_act_ckpt, dtype) + + module_serial_act_ckpt = module_serial_act_ckpt.to(dtype=dtype, device=manager.device) + module_act_ckpt = DistributedPairformerNoSeqModule(module_serial_act_ckpt, manager) + module_act_ckpt.train() + + # Forward pass + output_z_dtensor_result_act_ckpt = module_act_ckpt(z=input_z_dtensor_copy, pair_mask=pair_mask_dtensor_copy) + + # Verify outputs are the same after activation checkpoint fwd and no actv ckpt fwd + assert_tensors_identical( + output_z_dtensor_result_act_ckpt.to_local(), + output_z_dtensor_result.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + # Backward pass + torch.autograd.backward([output_z_dtensor_result_act_ckpt], [d_output_z_dtensor]) + + # Verify that input gradients are identical + assert_tensors_identical(input_z_dtensor.grad.to_local(), input_z_dtensor_copy.grad.to_local()) + + # Compare parameter gradients + reference_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + reference_param_grads_dict[name] = param.grad + + for name, param_act_ckpt in module_act_ckpt.named_parameters(): + assert name in reference_param_grads_dict, f"Parameter {name}'s gradient is not found in reference gradients" + reference_grad = reference_param_grads_dict[name] + assert_tensors_identical(reference_grad.to_local(), param_act_ckpt.grad.to_local()) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype", + ( + params_test := [ + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}" + for x in params_test + ], +) +def test_pairformer_noseq_module_parallel_activation_checkpointing(setup_env, dtype): + """ + Test the PairformerNoSeq Module with activation checkpointing enabled vs CP without actv ckpt, results should be identical. Test on small model and input size. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + # Small model params - same as main test's small case; dropout non-zero to exercise activation checkpointing + token_z = 12 + num_blocks = 2 + pairwise_head_width = 4 + pairwise_num_heads = 2 + dropout = 0.5 + post_layer_norm = False + + pairformer_params = { + "token_z": token_z, + "num_blocks": num_blocks, + "dropout": dropout, + "pairwise_head_width": pairwise_head_width, + "pairwise_num_heads": pairwise_num_heads, + "post_layer_norm": post_layer_norm, + "activation_checkpointing": True, + } + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 2 # Number of tokens + min_val_init = -1.0 + max_val_init = 1.0 + + input_z_global = torch.empty((B, N, N, token_z), dtype=dtype, requires_grad=True, device="cpu") + pair_mask_global = torch.randint(0, 2, (B, N, N), dtype=dtype, requires_grad=False, device="cpu") + pair_mask_global[0, N // size_ring :, :] = 0 + pair_mask_global[0, :, N // size_ring :] = 0 + + init_tensors_uniform([input_z_global], low=min_val_init, high=max_val_init) + + spawn_multiprocessing( + parallel_assert_pairformer_noseq_module_activation_checkpointing, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + pairformer_params, + min_val_init, + max_val_init, + input_z_global.detach(), + pair_mask_global.detach(), + ) diff --git a/tests/distributed/model/layers/test_dtensor_redistribute_transpose.py b/tests/distributed/model/layers/test_dtensor_redistribute_transpose.py new file mode 100755 index 000000000..df596bfbd --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_redistribute_transpose.py @@ -0,0 +1,494 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.redistribute_transpose import redistribute_transpose +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def compute_global_expectation(shape, input_placements, output_placements, dim0, dim1, device, device_mesh_shape): + """Compute global expectation using standard PyTorch operations.""" + # Create tensor for operations + x = torch.rand(*shape, device=device, requires_grad=True) + + # Determine the type of operation based on placements and transpose dimensions + has_redistribute = output_placements is not None + has_local_transpose = dim0 is not None and dim1 is not None + is_all_replicate = all(isinstance(p, type(Replicate())) for p in input_placements) + + # Compute on global tensor based on operation semantics: + # 1. redistribute only: no change to global tensor + # 2. local transpose only + # 3. redistribute + local transpose: global tensor transpose + # 4. all-replicate + local transpose: global tensor transpose + if has_redistribute and has_local_transpose: + # Case 3: redistribute + local transpose = global transpose when dim{0, 1} are sharded + assert Shard(dim0) in input_placements and Shard(dim1) in input_placements + y = torch.transpose(x, dim0=dim0, dim1=dim1) + elif not has_redistribute and has_local_transpose and is_all_replicate: + # Case 4: all-replicate + local transpose = global transpose + y = torch.transpose(x, dim0=dim0, dim1=dim1) + elif not has_redistribute and has_local_transpose: + # Case 2: local transpose only -- dim{0, 1} can't be sharded + assert Shard(dim0) not in input_placements and Shard(dim1) not in input_placements + y = torch.transpose(x, dim0=dim0, dim1=dim1) + else: + # Case 1: redistribute only or no-op = identity + y = x + + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensor + y.backward(dy) + + # Collect input gradient + input_grad = x.grad.detach().clone() + + return x.detach().clone(), y.detach().clone(), input_grad, dy.detach().clone() + + +def compute_dtensor_native( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + output_placements: Optional[tuple[Placement, ...]], + dim0: Optional[int], + dim1: Optional[int], +) -> tuple[DTensor, DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + + # Apply redistribute if output_placements is specified + if output_placements is not None: + if output_placements != x_dtensor.placements: + y_dtensor_result = x_dtensor.redistribute(device_mesh, output_placements) + else: + # this can only work if dim{0, 1} are sharded + assert Shard(dim0) in x_dtensor.placements and Shard(dim1) in x_dtensor.placements + # swap the shard placements of dim0 and dim1 + output_placements_ = list(output_placements) + i_axis_mesh_shard_dim0 = output_placements_.index(Shard(dim0)) + i_axis_mesh_shard_dim1 = output_placements_.index(Shard(dim1)) + output_placements_[i_axis_mesh_shard_dim0] = Shard(dim1) + output_placements_[i_axis_mesh_shard_dim1] = Shard(dim0) + y_dtensor_result = x_dtensor.redistribute(device_mesh, tuple(output_placements_)) + else: + y_dtensor_result = x_dtensor + + # Apply local transpose if dim0 and dim1 are specified + if dim0 is not None and dim1 is not None: + y_dtensor_result = torch.transpose(y_dtensor_result, dim0=dim0, dim1=dim1) + # assert view semantics when no redistribute + if output_placements is None: + assert y_dtensor_result.to_local().is_set_to(x_dtensor.to_local().transpose(dim0=dim0, dim1=dim1)) + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + y_dtensor_result.backward(dy_dtensor) + + # assert view semantics of gradients if no redistribute + if output_placements is None: + if dim0 is not None and dim1 is not None: + assert x_dtensor.grad.to_local().is_set_to(dy_dtensor.to_local().transpose(dim0=dim0, dim1=dim1)) + else: + # DTensor gradient is not view of the upstream adjoint for no-op case + assert not x_dtensor.grad.to_local().is_set_to(dy_dtensor.to_local()) + + return x_dtensor.grad, y_dtensor_result + + +def compute_redistribute_transpose_with_validation( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + output_placements: Optional[tuple[Placement, ...]], + transpose_comm: Optional[TransposeComm], + dim0: Optional[int], + dim1: Optional[int], + label_test_case: str, +) -> tuple[DTensor, DTensor, DTensor]: + """ + Compute redistribute_transpose forward and backward pass with input validation checks. + + Returns: + y_dtensor_result: Forward pass result + x_dtensor: Input tensor with computed gradient + dy_dtensor: Distributed upstream gradient + """ + # Create DTensor input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using redistribute_transpose + y_dtensor_result = redistribute_transpose(x_dtensor, transpose_comm, output_placements, dim0, dim1) + + # verify no change to the fwd input + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # assert view semantics of fwd result if no redistribute + if output_placements is None: + if dim0 is not None and dim1 is not None: + assert y_dtensor_result.to_local().is_set_to(x_dtensor.to_local().transpose(dim0=dim0, dim1=dim1)) + else: + assert y_dtensor_result.to_local().is_set_to(x_dtensor.to_local()) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # verify no change to the bwd input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # assert view semantics of bwd result if no redistribute + if output_placements is None: + if dim0 is not None and dim1 is not None: + assert x_dtensor.grad.to_local().is_set_to(dy_dtensor.to_local().transpose(dim0=dim0, dim1=dim1)) + else: + # For some reason, our implementation also follows the native DTensor's semantics + # for the no-op case + assert not x_dtensor.grad.to_local().is_set_to(dy_dtensor.to_local()) + + # verify input gradient placements are consistent with input placements + assert ( + x_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + return y_dtensor_result, x_dtensor, dy_dtensor + + +def parallel_assert_dtensor_redistribute_transpose( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters - 8D tensor for comprehensive testing + shape = (3, 5, grid_group_sizes["dp"] * 2, 5, size_ring * 4, 5, size_ring * 3, 2) + + # Create TransposeComm for tests that need it + layout_group_cp = manager.layout_subgroups["cp"] + transpose_comm = TransposeComm(manager.group["cp"], layout_group_cp) + + invalid_test_cases = [ + ( + (Shard(2), Shard(4), Shard(6)), + (Shard(2), Shard(4), Shard(6)), + 2, + 6, + ValueError, + "Inconsistent device mesh coordinate.*", + manager.subgroups_rank["cp"][0] != manager.group_rank["dp"], # otherwise assertion trivially pass + ), # cross DP - CP mesh transpose will raise if the CP transpose_comm + ( + (Shard(2), Shard(4), Shard(6)), + None, + 4, + 0, + NotImplementedError, + "Local transpose on sharded dimensions.*", + True, + ), # local transpose on sharded dimensions will raise + ( + (Shard(2), Shard(4), Shard(6)), + (Shard(2), Shard(3), Shard(2)), + 4, + 0, + ValueError, + "Simultaneous redistribute and local transpose is only supported.*", + True, + ), # redistribute and local transpose without same output placements will raise + ( + (Shard(2), Shard(4), Shard(6)), + (Shard(2), Shard(4), Shard(6)), + 4, + 0, + ValueError, + "Both dim0 and dim1 must be sharded.*", + True, + ), # redistribute and local transpose on non-sharded dimensions will raise + ( + (Shard(2), Shard(4), Shard(6)), + (Replicate(), Replicate(), Replicate()), + None, + None, + ValueError, + "Input and output placements are not strictly a permutation of each other.*", + True, + ), # redistribute other than device mesh transpose will raise + ] + + # Define test cases based on user specification + # Format: (input_placements, output_placements, dim0, dim1, description) + valid_test_cases = [ + # Global transpose of sharded dtensor along both sharded axes (implies a redistribute) + ( + (Shard(2), Shard(4), Shard(6)), + (Shard(2), Shard(4), Shard(6)), + 4, + 6, + "global transpose along both sharded axes (4,6)", + ), + # Redistribute of sharded single representation but no local transpose + ( + (Shard(2), Shard(4), Replicate()), + (Shard(2), Replicate(), Shard(4)), + None, + None, + "redistribute only S(2),S(4),R -> S(2),R,S(4)", + ), + ( + (Shard(2), Replicate(), Shard(4)), + (Shard(2), Shard(4), Replicate()), + None, + None, + "redistribute only S(2),R,S(4) -> S(2),S(4),R", + ), + # Local transpose only + ((Shard(2), Replicate(), Replicate()), None, 0, 3, "local transpose only (0,3)"), + ((Shard(2), Replicate(), Replicate()), None, 4, 6, "local transpose only (4,6)"), + # No op + ((Shard(2), Replicate(), Replicate()), None, None, None, "no op"), + # Local transpose of all-replicate dtensor implying global transpose + ((Replicate(), Replicate(), Replicate()), None, 0, 3, "local transpose of all-replicate (0,3)"), + ] + + for input_placements, output_placements, dim0, dim1, description in valid_test_cases: + label_test_case = f"for {description}\n" + + # Determine if we need transpose_comm for this test case + needs_transpose_comm = output_placements is not None + current_transpose_comm = transpose_comm if needs_transpose_comm else None + + # Compute global expectations + x_global, y_expected_global, x_grad_expected_global, dy_global = compute_global_expectation( + shape, input_placements, output_placements, dim0, dim1, manager.device, manager.device_mesh_subgroups.shape + ) + + # Use DTensor native op as an alternative reference + x_grad_dtensor_native, y_dtensor_result_native = compute_dtensor_native( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, output_placements, dim0, dim1 + ) + + # Compute redistribute_transpose forward and backward with validation + y_dtensor_result, x_dtensor, dy_dtensor = compute_redistribute_transpose_with_validation( + x_global, + dy_global, + manager.device_mesh_subgroups, + input_placements, + output_placements, + current_transpose_comm, + dim0, + dim1, + label_test_case, + ) + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # check metadata against DTensor native + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # compare forward result with native DTensor op + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare global tensors between redistribute_transpose and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + # assert input gradient metadata and values against DTensor native + assert ( + x_dtensor.grad.placements == x_grad_dtensor_native.placements + ), f"{label_test_case} input gradient placements mismatch" + assert x_dtensor.grad.shape == x_grad_dtensor_native.shape, f"{label_test_case} input gradient shape mismatch" + assert ( + x_dtensor.grad.stride() == x_grad_dtensor_native.stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_dtensor_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_dtensor_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare forward result with global expectation + torch.testing.assert_close( + y_result_global, + y_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs global expectation: {m}", + ) + + # create distributed tensor from global result for local shard comparison + x_grad_expected_dtensor = distribute_tensor( + x_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # compare local shard with expected + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_expected_dtensor.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + # Test invalid cases to assert raise + for input_placements, output_placements, dim0, dim1, error, raise_msg, raise_condition in invalid_test_cases: + label_test_case = f"for {error.__name__} {raise_msg}\n" + + x_global = torch.rand(*shape, device=manager.device, requires_grad=True) + + # Create DTensor input + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad = True + + # Determine if we need transpose_comm for this test case + needs_transpose_comm = output_placements is not None + current_transpose_comm = transpose_comm if needs_transpose_comm else None + + # This should raise due to sharded dimension being unflattened + if raise_condition: + with pytest.raises(error, match=raise_msg): + redistribute_transpose(x_dtensor, current_transpose_comm, output_placements, dim0, dim1) + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (1, 1)), True, "cuda", "ENV"), # 1 GPU (serial equiv) + ((2, (1, 1)), True, "cuda", "ENV"), # 2 GPUs, dp=2, cp=1x1 + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_dtensor_redistribute_transpose(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_redistribute_transpose, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_repeat_interleave.py b/tests/distributed/model/layers/test_dtensor_repeat_interleave.py new file mode 100644 index 000000000..13900e26b --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_repeat_interleave.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.repeat_interleave import shardwise_repeat_interleave +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def compute_global_expectation(shape, repeats, dim_to_repeat, device): + """Compute global expectation using standard PyTorch operations.""" + # Create input tensor + x = torch.rand(*shape, device=device, requires_grad=True) + + # Compute on global tensor using standard repeat_interleave operation + y = torch.repeat_interleave(x, repeats=repeats, dim=dim_to_repeat) + + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensor + y.backward(dy) + + # Collect input gradient + input_grad = x.grad.detach().clone() + + return x.detach().clone(), y.detach().clone(), input_grad, dy.detach().clone() + + +def compute_dtensor_native( + input_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + repeats: int, + dim_to_repeat: int, +) -> tuple[DTensor, DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native input + input_dtensor = distribute_tensor(input_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + + # Forward pass with native DTensor repeat_interleave operation + y_dtensor_result = torch.repeat_interleave(input_dtensor, repeats=repeats, dim=dim_to_repeat) + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + y_dtensor_result.backward(dy_dtensor) + + input_grad_dtensor = input_dtensor.grad + + return input_grad_dtensor, y_dtensor_result + + +def compute_shardwise_repeat_interleave_with_validation( + input_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + repeats: int, + dim_to_repeat: int, + label_test_case: str, +) -> tuple[DTensor, DTensor, DTensor]: + """ + Compute shardwise_repeat_interleave forward and backward pass with input validation checks. + + Returns: + y_dtensor_result: Forward pass result + input_dtensor: Input tensor with computed gradient + dy_dtensor: Distributed upstream gradient + """ + # Create DTensor input + input_dtensor = distribute_tensor(input_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + input_dtensor_copy = input_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_repeat_interleave + y_dtensor_result = shardwise_repeat_interleave(input_dtensor, repeats, dim_to_repeat) + + # verify no change to the fwd input + assert_tensors_identical( + input_dtensor.to_local(), input_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False + ) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # verify no change to the bwd input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # verify input gradient placements are consistent with input placements + assert ( + input_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + return y_dtensor_result, input_dtensor, dy_dtensor + + +def parallel_assert_dtensor_repeat_interleave( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_ring = len(manager.subgroups_ranks["cp"][0]) + + # Set test parameters + shape = (3, grid_group_sizes["dp"] * 2, 5, size_ring * 4, 5) + repeats = 3 # Number of times to repeat each element + # Shard dimensions for input tensor + # this emulates the sharded single representation in the Boltz model + input_placements = (Shard(dim=1), Shard(dim=3), Replicate()) + + # Test all dimensions (all are valid since sharded dimensions are evenly divided) + valid_dims_to_repeat = [0, 1, 2, 3, 4, -1, -2, -3, -4, -5] + + # Test valid repeat_interleave dimensions + for dim_to_repeat in valid_dims_to_repeat: + label_test_case = f"for dim {dim_to_repeat}\n" + + # Compute global expectations + input_global, y_expected_global, input_grad_expected_global, dy_global = compute_global_expectation( + shape, repeats, dim_to_repeat, manager.device + ) + + # use DTensor native op as an alternative reference + input_grad_dtensor_native, y_dtensor_result_native = compute_dtensor_native( + input_global, dy_global, manager.device_mesh_subgroups, input_placements, repeats, dim_to_repeat + ) + + # Compute shardwise_repeat_interleave forward and backward with validation + y_dtensor_result, input_dtensor, dy_dtensor = compute_shardwise_repeat_interleave_with_validation( + input_global, + dy_global, + manager.device_mesh_subgroups, + input_placements, + repeats, + dim_to_repeat, + label_test_case, + ) + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # check metadata against DTensor native + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # compare forward result with native DTensor op + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare global tensors between shardwise_repeat_interleave and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + # assert input gradient metadata and values against DTensor native + assert ( + input_dtensor.grad.placements == input_grad_dtensor_native.placements + ), f"{label_test_case} input gradient placements mismatch" + assert ( + input_dtensor.grad.shape == input_grad_dtensor_native.shape + ), f"{label_test_case} input gradient shape mismatch" + assert ( + input_dtensor.grad.stride() == input_grad_dtensor_native.stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + input_dtensor.grad.to_local(), + input_grad_dtensor_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + torch.testing.assert_close( + input_dtensor.grad.full_tensor(), + input_grad_dtensor_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare forward result with global expectation + torch.testing.assert_close( + y_result_global, + y_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs global expectation: {m}", + ) + + # create distributed tensor from global result for local shard comparison + input_grad_expected_dtensor = distribute_tensor( + input_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # compare local shards with expected + torch.testing.assert_close( + input_dtensor.grad.to_local(), + input_grad_expected_dtensor.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + input_dtensor.grad.full_tensor(), + input_grad_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_repeat_interleave(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_repeat_interleave, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_scatter.py b/tests/distributed/model/layers/test_dtensor_scatter.py new file mode 100644 index 000000000..4cd0f5a5a --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_scatter.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.scatter import distributed_scatter_reduce +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def einsum_scatter_reduce(idx_global, src_global, N_output, axis, reduce, idx_mask_global=None): + """Reference scatter-reduce using one-hot encoding and einsum. + + This computes the same result as: + output = zeros(output_shape) + output.scatter_reduce_(axis, idx_expanded, src, reduce=reduce) + + But uses one-hot encoding and einsum which is more explicit and easier to verify. + + Args: + idx_global: Index tensor (*batch, N_src) with values in [0, N_output) + src_global: Source tensor (*batch, N_src, *features) + N_output: Size of the output's scatter axis + axis: The scatter axis position + reduce: "sum" or "mean" + idx_mask_global: Optional mask tensor (*batch, N_src), True=valid, False=invalid + """ + # Get shapes + batch_shape = idx_global.shape[:axis] + N_src = idx_global.shape[axis] + feature_shape = src_global.shape[axis + 1 :] + dtype = src_global.dtype + + # Flatten batch dimensions for easier processing + B = torch.Size(batch_shape).numel() if batch_shape else 1 + F = torch.Size(feature_shape).numel() if feature_shape else 1 + + # Reshape to (B, N_src) and (B, N_src, F) + idx_flat = idx_global.reshape(B, N_src) + src_flat = src_global.reshape(B, N_src, F) + + # Create one-hot tensor: (B, N_src, N_output) + # one_hot[b, i, j] = 1 if idx_flat[b, i] == j + one_hot = torch.nn.functional.one_hot(idx_flat, num_classes=N_output).to(dtype) + + # Apply mask to one-hot (zeroing out masked entries means they don't contribute to output) + if idx_mask_global is not None: + mask_flat = idx_mask_global.reshape(B, N_src, 1).to(dtype) + one_hot = one_hot * mask_flat + + # Compute output using einsum + # out[b, j, f] = sum_i(one_hot[b, i, j] * src[b, i, f]) + # einsum: "bin,bif->bnf" where i=N_src, n=N_output, f=F + out_flat = torch.einsum("bin,bif->bnf", one_hot, src_flat) + + if reduce == "mean": + # count[b, j] = sum_i(one_hot[b, i, j]) = number of elements scattered to position j + count = one_hot.sum(dim=1, keepdim=True) # (B, 1, N_output) + count = count.permute(0, 2, 1) # (B, N_output, 1) + # Clamp to avoid division by zero; positions with count=0 already have out=0 from einsum + out_flat = out_flat / count.clamp(min=1) + + # Reshape back to original structure + output_shape = batch_shape + (N_output,) + feature_shape + out = out_flat.reshape(output_shape) + + return out + + +def parallel_assert_scatter_reduce(rank, grid_group_sizes, device_type, backend, env_map, dtype, reduce_op): + """Test distributed scatter_reduce against reference implementation.""" + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + + seed_by_rank(0, 42) + + # Test configurations with different shape patterns + # Format: (batch_shape_before_axis, N_per_rank, N_src_per_rank, feature_shape_after_N) + # output: (*batch, N, *features) + # idx: (*batch, N_src) + # src: (*batch, N_src, *features) + shape_configs = [ + # Simple cases: output (N, F), idx (N_src,), src (N_src, F) + ((), 8, 12, (3,)), # output: (N,3), idx: (N_src,), src: (N_src,3) + ((), 16, 8, ()), # output: (N,), idx: (N_src,), src: (N_src,) + # With batch: output (B, N, F), idx (B, N_src), src (B, N_src, F) + ((4,), 8, 12, ()), # output: (4,N), idx: (4,N_src), src: (4,N_src) + ((4,), 8, 12, (3,)), # output: (4,N,3), idx: (4,N_src), src: (4,N_src,3) + ((2,), 16, 8, (5,)), # output: (2,N,5), idx: (2,N_src), src: (2,N_src,5) + # Multi-dim features + ((2,), 8, 10, (3, 4)), # output: (2,N,3,4), idx: (2,N_src), src: (2,N_src,3,4) + # Larger test case + ((2,), 100, 1000, (3,)), # output: (2,N,3), idx: (2,N_src), src: (2,N_src,3) + ] + + for batch_shape, N_per_rank, N_src_per_rank, feature_shape in shape_configs: + for device_mesh in [manager.device_mesh_subgroups, manager.device_mesh]: + mesh_ndim = device_mesh.ndim + axis = len(batch_shape) # axis is right after batch dims + + # Determine placement strategy + size_group_shard_axis = None + if axis >= 1 and mesh_ndim >= 2: + # shard leading tensor axes as well as 'axis' + placements = (Shard(0), Shard(axis)) + (Replicate(),) * (mesh_ndim - 2) + size_group_shard_axis = device_mesh.size(1) + elif mesh_ndim >= 2: + # axis == 0: only shard 'axis' + placements = (Replicate(),) * (mesh_ndim - 1) + (Shard(axis),) + size_group_shard_axis = device_mesh.size(-1) + else: + # 1D mesh: shard on axis + placements = (Shard(axis),) + (Replicate(),) * (mesh_ndim - 1) + size_group_shard_axis = device_mesh.size(0) + + if size_group_shard_axis is None: + raise ValueError(f"size_group_shard_axis is None for axis {axis} and device_mesh {device_mesh}") + + N = N_per_rank * size_group_shard_axis + N_src = N_src_per_rank * size_group_shard_axis + + # Build shapes + output_shape = batch_shape + (N,) + feature_shape + idx_shape = batch_shape + (N_src,) + src_shape = batch_shape + (N_src,) + feature_shape + + # Test both without mask (None) and with random mask + for use_mask in [False, True]: + label = f"output:{output_shape}, idx:{idx_shape}, src:{src_shape}, axis:{axis}, reduce:{reduce_op}" + if use_mask: + label += " (masked)" + idx_mask_global = torch.rand(idx_shape, device=device) > 0.5 + else: + idx_mask_global = None + + # Create test data + idx_global = torch.randint(0, N, idx_shape, device=device) + src_global = torch.randn(src_shape, dtype=dtype, device=device, requires_grad=True) + + # Reference computation using einsum with one-hot encoding + out_ref = einsum_scatter_reduce(idx_global, src_global.detach(), N, axis, reduce_op, idx_mask_global) + + # Create grad_out for backward test + grad_out = torch.randn_like(out_ref) + + # Reference backward using autograd + src_global_for_bwd = src_global.detach().clone().requires_grad_(True) + out_ref_bwd = einsum_scatter_reduce(idx_global, src_global_for_bwd, N, axis, reduce_op, idx_mask_global) + out_ref_bwd.backward(grad_out) + grad_src_ref = src_global_for_bwd.grad + + # Distributed computation + idx_dtensor = distribute_tensor(idx_global, device_mesh, placements) + src_dtensor = distribute_tensor(src_global.detach().clone(), device_mesh, placements).requires_grad_( + True + ) + idx_mask_dtensor = ( + distribute_tensor(idx_mask_global, device_mesh, placements) if idx_mask_global is not None else None + ) + + out_dtensor = distributed_scatter_reduce( + output_size_per_rank=N_per_rank, + axis=axis, + idx=idx_dtensor, + src=src_dtensor, + reduce=reduce_op, + idx_mask=idx_mask_dtensor, + are_ids_contiguous=True, + ) + + # Compare forward output + out_local = out_dtensor.full_tensor() + assert_tensors_identical( + out_local.detach(), + out_ref, + check_stride=False, + check_grad=False, + check_grad_fn=False, + rtol=1e-10, + atol=1e-10, + msg=lambda m: f"{label} fwd output mismatch:\n {m}", + ) + + # Backward pass + grad_out_dtensor = distribute_tensor( + grad_out.detach().clone(), out_dtensor.device_mesh, out_dtensor.placements + ) + + out_dtensor.backward(grad_out_dtensor) + + # Compare gradients + grad_src_local = src_dtensor.grad.full_tensor() + + assert_tensors_identical( + grad_src_local, + grad_src_ref, + check_grad=False, + check_grad_fn=False, + rtol=1e-10, + atol=1e-10, + msg=lambda m: f"{label} bwd src gradient mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize("reduce_op", ["sum", "mean"]) +def test_distributed_scatter_reduce(setup_env, reduce_op): + """Test distributed scatter_reduce operation.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + spawn_multiprocessing( + parallel_assert_scatter_reduce, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + torch.float64, + reduce_op, + ) diff --git a/tests/distributed/model/layers/test_dtensor_sharded_op.py b/tests/distributed/model/layers/test_dtensor_sharded_op.py new file mode 100644 index 000000000..7596fbdf2 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_sharded_op.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.sharded_op import sharded_sum +from boltz.testing.utils import assert_tensors_identical, spawn_multiprocessing + + +def serial_sum(x: torch.Tensor, dims: tuple[int, ...] | int, keepdim: bool = False) -> torch.Tensor: + """Serial implementation of sum operation for comparison.""" + return torch.sum(x, dim=dims, keepdim=keepdim) + + +def parallel_assert_sharded_sum( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + dims, + keepdim, + input_tensor_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_expected_global_host, + output_placements, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + + # Distribute input tensor + input_tensor_dtensor = distribute_tensor( + input_tensor_global_host.to(manager.device), device_mesh=device_mesh, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=device_mesh, + placements=output_placements, + ) + + # Create copy to verify input isn't modified + input_tensor_dtensor_copy = input_tensor_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = sharded_sum(input_tensor_dtensor, dim=dims, keepdim=keepdim) + + # Verify input wasn't modified + assert_tensors_identical( + input_tensor_dtensor_copy.to_local(), input_tensor_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert ( + output_dtensor_result.placements == output_placements + ), f"output_placements: {output_placements}, output_dtensor_result.placements: {output_dtensor_result.placements}" + assert ( + output_dtensor_result.shape == output_expected_global_host.shape + ), f"Output shape mismatch: {output_dtensor_result.shape} != {output_expected_global_host.shape}" + assert ( + output_dtensor_result.stride() == output_expected_global_host.stride() + ), f"Output stride mismatch: {output_dtensor_result.stride()} != {output_expected_global_host.stride()}" + torch.testing.assert_close(output_dtensor_result.full_tensor().cpu(), output_expected_global_host) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient + assert ( + input_tensor_dtensor.grad.placements == placements + ), f"placements: {placements}, input_tensor_dtensor.grad.placements: {input_tensor_dtensor.grad.placements}" + assert ( + input_tensor_dtensor.grad.shape == input_tensor_global_host.shape + ), f"Input gradient shape mismatch: {input_tensor_dtensor.grad.shape} != {input_tensor_global_host.shape}" + torch.testing.assert_close(input_tensor_dtensor.grad.full_tensor().cpu(), d_input_expected_global_host) + + # Verify full tensors match expected results + torch.testing.assert_close(output_dtensor_result.full_tensor().cpu(), output_expected_global_host) + torch.testing.assert_close(input_tensor_dtensor.grad.full_tensor().cpu(), d_input_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements,dims,keepdim,output_placements", + [ + ((Shard(0), Shard(1), Shard(2)), (0, 1, 2), False, (Replicate(), Replicate(), Replicate())), + ((Shard(0), Shard(1), Shard(2)), (0, 2), False, (Replicate(), Shard(0), Replicate())), + ((Shard(0), Shard(1), Shard(2)), (1, 2), False, (Shard(0), Replicate(), Replicate())), + ((Shard(0), Shard(1), Shard(2)), (1, 2), True, (Shard(0), Replicate(), Replicate())), + ((Shard(0), Shard(1), Shard(2)), (1,), False, (Shard(0), Replicate(), Shard(1))), + ((Shard(0), Shard(1), Shard(2)), (1,), True, (Shard(0), Replicate(), Shard(2))), + ((Shard(0), Shard(1), Replicate()), (1,), False, (Shard(0), Replicate(), Replicate())), + ((Shard(0), Shard(1), Replicate()), (1,), True, (Shard(0), Replicate(), Replicate())), + ((Shard(2), Shard(0), Shard(1)), (1,), False, (Shard(1), Shard(0), Replicate())), + ((Shard(2), Shard(0), Shard(1)), (1,), True, (Shard(2), Shard(0), Replicate())), + ], + ids=[ + "pair_dim_0_1_2", + "pair_dim_0_2", + "pair_dim_1_2_keepdim", + "pair_dim_1_2", + "pair_dim_1", + "pair_dim_1_keepdim", + "single_dim_1", + "single_dim_1_keepdim", + "s2_s0_s1_dim_1", + "s2_s0_s1_dim_1_keepdim", + ], +) +def test_sharded_sum_parallel(setup_env, placements, dims, keepdim, output_placements): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 3 # Number of tokens + D = 5 # Hidden dimension + + # Create input tensor with proper shape + input_tensor_global = torch.randn((B, N, N, D), requires_grad=True, device=device_type) + + # Run serial forward pass + input_tensor_global_host = input_tensor_global.detach().clone().cpu() + output_expected_global = serial_sum(input_tensor_global, dims=dims, keepdim=keepdim) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.randn_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Serial sum grad is not set to upstream gradient + assert not input_tensor_global.grad.is_set_to(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_sharded_sum, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + dims, + keepdim, + input_tensor_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_tensor_global.grad.detach().clone().cpu(), + output_placements, + ) diff --git a/tests/distributed/model/layers/test_dtensor_shardwise_op.py b/tests/distributed/model/layers/test_dtensor_shardwise_op.py new file mode 100644 index 000000000..08ff4b0b7 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_shardwise_op.py @@ -0,0 +1,1687 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import unittest +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.shardwise_op import ( + ShardwiseOuterOp, + shardwise_argmax, + shardwise_log_softmax, + shardwise_offset, + shardwise_one_hot, + shardwise_outer_op, + shardwise_softmax, + shardwise_sum, +) +from boltz.testing.utils import assert_tensors_identical, init_tensors_uniform, seed_by_rank, spawn_multiprocessing + + +def serial_shardwise_sum(x: torch.Tensor, dim: int, keepdim: Optional[bool] = None) -> torch.Tensor: + """Serial implementation of shardwise sum operation for comparison.""" + if keepdim is None: + return torch.sum(x, dim=dim) + else: + return torch.sum(x, dim=dim, keepdim=keepdim) + + +def serial_shardwise_one_hot(input: torch.Tensor, num_classes: int = -1) -> torch.Tensor: + """Serial implementation of shardwise one_hot operation for comparison.""" + return F.one_hot(input, num_classes=num_classes) + + +def serial_shardwise_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor: + """Serial implementation of shardwise softmax operation for comparison.""" + return F.softmax(x, dim=dim) + + +def serial_shardwise_log_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor: + """Serial implementation of shardwise log_softmax operation for comparison.""" + return F.log_softmax(x, dim=dim) + + +def serial_shardwise_argmax(x: torch.Tensor, dim: int, keepdim: Optional[bool] = None) -> torch.Tensor: + """Serial implementation of shardwise argmax operation for comparison.""" + if keepdim is None: + return torch.argmax(x, dim=dim) + return torch.argmax(x, dim=dim, keepdim=keepdim) + + +def parallel_assert_shardwise_sum( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + dim: int, + keepdim: Optional[bool] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copy to verify input isn't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = shardwise_sum(input_x_dtensor, dim, keepdim) + + # Verify input wasn't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert output_dtensor_result.shape == output_expected_dtensor.shape + assert output_dtensor_result.stride() == output_expected_dtensor.stride() + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient + assert input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + assert input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +def parallel_assert_shardwise_argmax( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + dim: int, + keepdim: Optional[bool], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + input_x_dtensor_copy = input_x_dtensor.detach().clone() + + output_dtensor_result = shardwise_argmax(input_x_dtensor, dim, keepdim) + + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + assert output_dtensor_result.shape == output_expected_dtensor.shape + assert output_dtensor_result.stride() == output_expected_dtensor.stride() + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + output_global_result_host = output_dtensor_result.full_tensor().cpu() + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + + assert output_dtensor_result.placements == placements + assert output_dtensor_result.dtype == torch.long + + with pytest.raises(RuntimeError, match="does not require grad"): + output_dtensor_result.backward(torch.empty_like(output_dtensor_result)) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "sum_config", + [ + (-1, False), + (-1, True), + (2, False), + (2, True), + ], + ids=lambda x: f"dim={x[0]}, keepdim={x[1]}", +) +def test_shardwise_sum_parallel(setup_env, placements, sum_config): + dim, keepdim = sum_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 # Number of tokens + D = 32 # Hidden dimension + + # Skip test if trying to sum along a sharded dimension + input_shape = (B, N, N, D) + actual_dim = dim if dim >= 0 else len(input_shape) + dim + for placement in placements: + if isinstance(placement, Shard) and placement.dim == actual_dim: + pytest.skip(f"Skipping test: sum along sharded dimension {dim} is not supported") + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensor with proper shape + input_x_global = torch.rand((B, N, N, D), requires_grad=True, device=device_type, generator=rng) + + # Run serial forward pass + input_x_global_host = input_x_global.detach().clone().cpu() + output_expected_global = serial_shardwise_sum(input_x_global, dim, keepdim) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_shardwise_sum, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_x_global.grad.detach().clone().cpu(), + dim, + keepdim, + ) + + +def assert_error_case(rank, grid_group_sizes, device_type, backend, env_per_rank): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0) + + B = 2 * grid_group_sizes["dp"] + N = grid_group_sizes["cp"][0] * 100 # Number of tokens + D = 32 # Hidden dimension + + # Test case 1: Sum along sharded dimension should raise ValueError + input_tensor = torch.randn((B, N, D), device=manager.device, requires_grad=True) + sharded_dtensor = distribute_tensor( + input_tensor, device_mesh=manager.device_mesh_subgroups, placements=(Shard(0), Shard(1), Replicate()) + ) + + # This should raise an error because we're trying to sum along dimension 0, which is sharded + with pytest.raises(ValueError, match="Sum along sharded dimension 0 is not supported"): + shardwise_sum(sharded_dtensor, dim=0) + shardwise_sum(sharded_dtensor, dim=1) + + # Test case 2: Invalid input type + with pytest.raises(TypeError, match="Expected DTensor"): + shardwise_sum(input_tensor, dim=1) # Regular tensor instead of DTensor + + # Test case 3: Invalid dim type + with pytest.raises(TypeError, match="Expected int for dim"): + shardwise_sum(sharded_dtensor, dim=1.5) # Float instead of int + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_shardwise_sum_error_cases(setup_env): + """Test error cases for shardwise_sum function.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_error_case, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def parallel_assert_shardwise_one_hot( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_indices_global_host, + output_expected_global_host, + num_classes: int = -1, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor (indices should be long type, no gradients needed) + input_indices_dtensor = distribute_tensor( + input_indices_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + + # Distribute expected output + # For one_hot, the output has the same placements as input for the original dimensions, + # and the new one-hot dimension will be handled by the shardwise_one_hot implementation + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copy to verify input isn't modified + input_indices_dtensor_copy = input_indices_dtensor.detach().clone() + + # Forward pass + output_dtensor_result = shardwise_one_hot(input_indices_dtensor, num_classes) + + # Verify input wasn't modified + assert_tensors_identical( + input_indices_dtensor_copy.to_local(), input_indices_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output shape mismatch: expected {output_expected_dtensor.shape}, got {output_dtensor_result.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output stride mismatch: expected {output_expected_dtensor.stride()}, got {output_dtensor_result.stride()}" + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + + # Verify placements are correct (same as input placements) + assert output_dtensor_result.placements == placements + + # Ensure no backward possible + with pytest.raises(RuntimeError, match="tensors does not require grad and does not have a grad_fn"): + output_dtensor_result.backward(torch.empty_like(output_dtensor_result)) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Replicate()), (Shard(0), Shard(1), Shard(2))], ids=["single", "pair"] +) +@pytest.mark.parametrize( + "num_classes_config", + [ + 3, # explicit num_classes + -1, # inferred num_classes + ], + ids=lambda x: f"num_classes={x}", +) +def test_shardwise_one_hot_parallel(setup_env, placements, num_classes_config): + num_classes = num_classes_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 10 # Number of tokens (smaller for indices) + D = 8 # Small dimension for indices + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensor with integer indices for one-hot encoding (3D to match placements) + # Use modest range to ensure valid indices + max_index = 2 if num_classes == -1 else num_classes - 1 + input_indices_global = torch.randint( + 0, max_index + 1, (B, N, N, D), device=device_type, generator=rng, dtype=torch.long + ) + + # Run serial forward pass + input_indices_global_host = input_indices_global.detach().clone().cpu() + output_expected_global = serial_shardwise_one_hot(input_indices_global, num_classes) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Ensure no backward possible + with pytest.raises(RuntimeError, match="tensors does not require grad and does not have a grad_fn"): + output_expected_global.backward(torch.empty_like(output_expected_global)) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_shardwise_one_hot, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_indices_global_host, + output_expected_global_host, + num_classes, + ) + + +def parallel_assert_shardwise_softmax( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + dim: int = -1, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + # Create copy to verify input isn't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = shardwise_softmax(input_x_dtensor, dim) + + # Verify input wasn't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert output_dtensor_result.shape == output_expected_global_host.shape + assert output_dtensor_result.stride() == output_expected_global_host.stride() + torch.testing.assert_close(output_dtensor_result.full_tensor().cpu(), output_expected_global_host.cpu()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient + assert input_x_dtensor.grad.shape == d_input_x_expected_global_host.shape + assert input_x_dtensor.grad.stride() == d_input_x_expected_global_host.stride() + torch.testing.assert_close(input_x_dtensor.grad.full_tensor().cpu(), d_input_x_expected_global_host.cpu()) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype", + ( + params_test := [ + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32), + (((2, (2, 2)), True, "cuda", "ENV"), torch.bfloat16), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp={x[0][0][0]}, cp={x[0][0][1]}, specify_method={x[0][1]}, device_type={x[0][2]}, method_init={x[0][3]}, dtype={x[1]}" + for x in params_test + ], +) +@pytest.mark.parametrize( + "dim_config", + [ + -1, # last dimension + 2, # third dimension + ], + ids=lambda x: f"dim={x}", +) +def test_shardwise_softmax_parallel(setup_env, dtype, dim_config): + dim = dim_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + placements = (Shard(0), Shard(1), Replicate()) + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 10 # Number of tokens/atoms/etc for single repr input + num_bins = 50 # Number of bins for softmax used in plddt_logits, see ConfidenceHeads class + + min_init_val = -0.2 + max_init_val = 0.2 + + # Skip test if trying to apply softmax along a sharded dimension + input_shape = (B, N, num_bins) + + seed = 42 + seed_by_rank(0, seed=seed) + + # Create input tensor with proper shape and dtype + input_x_global = torch.empty(input_shape, requires_grad=True, device=device_type, dtype=dtype) + init_tensors_uniform([input_x_global], low=min_init_val, high=max_init_val) + + # Run serial forward pass + input_x_global_host = input_x_global.detach().clone().cpu() + output_expected_global = serial_shardwise_softmax(input_x_global, dim) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.empty_like(output_expected_global, dtype=dtype) + init_tensors_uniform([d_output_expected_global], low=min_init_val, high=max_init_val) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_shardwise_softmax, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_x_global.grad.detach().clone().cpu(), + dim, + ) + + +def parallel_assert_shardwise_softmax_error_cases( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Expect ValueError when softmax is along a sharded dimension - pick last one for simplicity + with pytest.raises(ValueError, match="Softmax along sharded dimension"): + shardwise_softmax(input_x_dtensor, dim=-1) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + ( + params_test := [ + ((1, (1, 1)), True, "cuda", "ENV"), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}" + for x in params_test + ], +) +def test_shardwise_softmax_parallel_error_cases(setup_env): + """Test that shardwise_softmax raises ValueError when softmax is along a sharded dimension.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + placements = (Shard(0), Shard(1), Shard(2)) + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 10 + num_bins = 50 + + min_init_val = -0.2 + max_init_val = 0.2 + + input_shape = (B, N, num_bins) + + seed = 42 + seed_by_rank(0, seed=seed) + + # Create input tensor with proper shape and dtype + input_x_global = torch.empty(input_shape, requires_grad=True, device=device_type) + init_tensors_uniform([input_x_global], low=min_init_val, high=max_init_val) + + input_x_global_host = input_x_global.detach().clone().cpu() + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_shardwise_softmax_error_cases, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + ) + + +def parallel_assert_shardwise_log_softmax( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + dim: int = -1, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copy to verify input isn't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = shardwise_log_softmax(input_x_dtensor, dim) + + # Verify input wasn't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert output_dtensor_result.shape == output_expected_dtensor.shape + assert output_dtensor_result.stride() == output_expected_dtensor.stride() + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient + assert input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + assert input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "dim_config", + [ + -1, # last dimension + 2, # third dimension + ], + ids=lambda x: f"dim={x}", +) +def test_shardwise_log_softmax_parallel(setup_env, placements, dim_config): + dim = dim_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 10 # Number of tokens + D = 8 # Hidden dimension + min_init_val = -0.2 + max_init_val = 0.2 + + # Skip test if trying to apply log_softmax along a sharded dimension + input_shape = (B, N, N, D) + actual_dim = dim if dim >= 0 else len(input_shape) + dim + for placement in placements: + if isinstance(placement, Shard) and placement.dim == actual_dim: + pytest.skip(f"Skipping test: log_softmax along sharded dimension {dim} is not supported") + + seed = 42 + seed_by_rank(0, seed=seed) + + # Create input tensor with proper shape + input_x_global = torch.empty((B, N, N, D), requires_grad=True, device=device_type) + init_tensors_uniform([input_x_global], low=min_init_val, high=max_init_val) + + # Run serial forward pass + input_x_global_host = input_x_global.detach().clone().cpu() + output_expected_global = serial_shardwise_log_softmax(input_x_global, dim) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.empty_like(output_expected_global) + init_tensors_uniform([d_output_expected_global], low=min_init_val, high=max_init_val) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_shardwise_log_softmax, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_x_global.grad.detach().clone().cpu(), + dim, + ) + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "argmax_config", + [ + (-1, False), + (-1, True), + (2, False), + (2, True), + ], + ids=lambda x: f"dim={x[0]}, keepdim={x[1]}", +) +def test_shardwise_argmax_parallel(setup_env, placements, argmax_config): + dim, keepdim = argmax_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 + D = 16 + + input_shape = (B, N, N, D) + actual_dim = dim if dim >= 0 else len(input_shape) + dim + for placement in placements: + if isinstance(placement, Shard) and placement.dim == actual_dim: + pytest.skip(f"Skipping test: argmax along sharded dimension {dim} is not supported") + + seed = 123 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + input_x_global = torch.rand(input_shape, device=device_type, generator=rng) + + input_x_global_host = input_x_global.detach().clone().cpu() + output_expected_global = serial_shardwise_argmax(input_x_global, dim, keepdim) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + spawn_multiprocessing( + parallel_assert_shardwise_argmax, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + dim, + keepdim, + ) + + +def serial_shardwise_offset(x: torch.Tensor, dim: int, offset_per_rank: float, num_shards: int) -> torch.Tensor: + """Serial implementation of shardwise offset operation for comparison. + + Simulates what shardwise_offset does across distributed ranks by applying + rank-dependent offsets to each shard of the tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + dim : int + The dimension that would be sharded. + offset_per_rank : float + The offset value per rank. + num_shards : int + Number of shards (ranks) along the specified dimension. + + Returns + ------- + torch.Tensor + Tensor with rank-dependent offsets applied to each shard. + """ + actual_dim = dim if dim >= 0 else x.ndim + dim + dim_size = x.shape[actual_dim] + shard_size = dim_size // num_shards + + output = x.clone() + for rank in range(num_shards): + # Create slice for this rank's shard + slices = [slice(None)] * x.ndim + slices[actual_dim] = slice(rank * shard_size, (rank + 1) * shard_size) + + # Add offset for this rank + output[tuple(slices)] = output[tuple(slices)] + rank * offset_per_rank + + return output + + +def parallel_assert_shardwise_offset( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + dim: int, + offset_per_rank: float, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensor + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copy to verify input isn't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = shardwise_offset(input_x_dtensor, dim, offset_per_rank) + + # Verify input wasn't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert output_dtensor_result.shape == output_expected_dtensor.shape + assert output_dtensor_result.stride() == output_expected_dtensor.stride() + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradient + assert input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + assert input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "offset_config", + [ + (1, 100.0), # dim 1 with offset 100.0 + (2, 50.0), # dim 2 with offset 50.0 + ], + ids=lambda x: f"dim={x[0]}, offset={x[1]}", +) +def test_shardwise_offset_parallel(setup_env, placements, offset_config): + dim, offset_per_rank = offset_config + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 # Number of tokens + D = 32 # Hidden dimension + + # Skip test if the specified dimension is NOT sharded (offset requires sharded dim) + input_shape = (B, N, N, D) + actual_dim = dim if dim >= 0 else len(input_shape) + dim + dim_is_sharded = False + mesh_axis_for_dim = None + for i, placement in enumerate(placements): + if isinstance(placement, Shard) and placement.dim == actual_dim: + dim_is_sharded = True + mesh_axis_for_dim = i + break + + if not dim_is_sharded: + pytest.skip(f"Skipping test: dimension {dim} is not sharded, but shardwise_offset requires it to be sharded") + + # Get the number of shards for this dimension from the mesh + # mesh shape is (dp, cp[0], cp[1]) which corresponds to (mesh_dim_0, mesh_dim_1, mesh_dim_2) + mesh_shape = (grid_group_sizes["dp"], grid_group_sizes["cp"][0], grid_group_sizes["cp"][1]) + num_shards = mesh_shape[mesh_axis_for_dim] + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensor with proper shape + input_x_global = torch.rand((B, N, N, D), requires_grad=True, device=device_type, generator=rng) + + # Run serial forward pass + input_x_global_host = input_x_global.detach().clone().cpu() + output_expected_global = serial_shardwise_offset(input_x_global, dim, offset_per_rank, num_shards) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + # For offset, backward is identity (gradient passes through) + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Since offset is x + constant, grad_input = grad_output + # The serial implementation clones input, so grad flows through clone to input + # grad_input should equal grad_output + d_input_x_expected_global_host = d_output_expected_global_host.clone() + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_shardwise_offset, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + dim, + offset_per_rank, + ) + + +def assert_offset_error_case(rank, grid_group_sizes, device_type, backend, env_per_rank): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0) + + B = 2 * grid_group_sizes["dp"] + N = grid_group_sizes["cp"][0] * 100 # Number of tokens + D = 32 # Hidden dimension + + # Test case 1: Offset along non-sharded dimension should raise ValueError + input_tensor = torch.randn((B, N, D), device=manager.device, requires_grad=True) + # Shard dim 0 and dim 1, but dim 2 is Replicate + sharded_dtensor = distribute_tensor( + input_tensor, device_mesh=manager.device_mesh_subgroups, placements=(Shard(0), Shard(1), Replicate()) + ) + + # This should raise an error because dim 2 is not sharded + with pytest.raises(ValueError, match="Dimension 2 must be sharded for shardwise_offset"): + shardwise_offset(sharded_dtensor, dim=2, offset_per_rank=100.0) + + # Test case 2: Invalid input type + with pytest.raises(TypeError, match="Expected DTensor"): + shardwise_offset(input_tensor, dim=1, offset_per_rank=100.0) # Regular tensor instead of DTensor + + # Test case 3: Invalid dim type + with pytest.raises(TypeError, match="Expected int for dim"): + shardwise_offset(sharded_dtensor, dim=1.5, offset_per_rank=100.0) # Float instead of int + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_shardwise_offset_error_cases(setup_env): + """Test error cases for shardwise_offset function.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_offset_error_case, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def serial_shardwise_outer_op(x: torch.Tensor, y: torch.Tensor, axis: int, op: ShardwiseOuterOp) -> torch.Tensor: + """Serial implementation of shardwise outer op for comparison. + + Takes tensors without singletons and computes the outer operation at the specified axis. + x: (..., L, ...) at axis + y: (..., R, ...) at axis + output: (..., L, R, ...) with one more dimension + """ + # Unsqueeze to create broadcast-compatible shapes + x_expanded = x.unsqueeze(axis + 1) # (..., L, 1, ...) + y_expanded = y.unsqueeze(axis) # (..., 1, R, ...) + + if op == ShardwiseOuterOp.SUBTRACT: + return x_expanded - y_expanded + elif op == ShardwiseOuterOp.ADD: + return x_expanded + y_expanded + elif op == ShardwiseOuterOp.LOGICAL_AND: + return x_expanded & y_expanded + elif op == ShardwiseOuterOp.EQUAL: + return x_expanded == y_expanded + else: + raise ValueError(f"Unsupported operation: {op}") + + +def parallel_assert_shardwise_outer_op( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + axis, + op, + is_differentiable, + input_x_global_host, + input_y_global_host, + output_expected_global_host, + output_placements_expected, + d_output_expected_global_host=None, + d_input_x_expected_global_host=None, + d_input_y_expected_global_host=None, +): + """Parallel assertion function for shardwise_outer_op with all ShardwiseOuterOp types.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensors + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + input_y_dtensor = distribute_tensor( + input_y_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + + if is_differentiable: + input_x_dtensor = input_x_dtensor.requires_grad_(True) + input_y_dtensor = input_y_dtensor.requires_grad_(True) + + # Distribute expected output + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=output_placements_expected, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + # Note: detach().clone() drops requires_grad, so we restore it for proper comparison + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(is_differentiable) + input_y_dtensor_copy = input_y_dtensor.detach().clone().requires_grad_(is_differentiable) + + # Forward pass + output_dtensor_result = shardwise_outer_op(input_x_dtensor, input_y_dtensor, axis, op) + + # Verify inputs weren't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_y_dtensor_copy.to_local(), input_y_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Shape mismatch: {output_dtensor_result.shape} != {output_expected_dtensor.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Stride mismatch: {output_dtensor_result.stride()} != {output_expected_dtensor.stride()}" + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + + # Verify placements are correct + assert output_dtensor_result.placements == output_placements_expected + + if is_differentiable: + # Distribute expected gradients + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=output_placements_expected, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_y_expected_dtensor = distribute_tensor( + d_input_y_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + assert input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + assert input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + assert input_y_dtensor.grad.shape == d_input_y_expected_dtensor.shape + assert input_y_dtensor.grad.stride() == d_input_y_expected_dtensor.stride() + torch.testing.assert_close(input_y_dtensor.grad.to_local(), d_input_y_expected_dtensor.to_local()) + + # Verify full gradient tensors match expected results + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + d_input_y_global_result_host = input_y_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host) + torch.testing.assert_close(d_input_y_global_result_host, d_input_y_expected_global_host) + else: + # Ensure no backward possible (non-differentiable) + with pytest.raises(RuntimeError, match="does not require grad"): + output_dtensor_result.backward(torch.empty_like(output_dtensor_result)) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + # Note: Cannot use Shard on the axis dimension (outer op axis must be local) + # Input shape is (B, K, L, D) or (B, K, R, D) with axis=2 + # So axis dim (2) must use Replicate + "placements", + [(Shard(0), Shard(1), Replicate()), (Replicate(), Shard(1), Replicate())], + ids=["shard_batch_and_K", "shard_K_only"], +) +@pytest.mark.parametrize( + "op", + [ShardwiseOuterOp.SUBTRACT, ShardwiseOuterOp.ADD, ShardwiseOuterOp.LOGICAL_AND, ShardwiseOuterOp.EQUAL], + ids=["SUBTRACT", "ADD", "LOGICAL_AND", "EQUAL"], +) +def test_shardwise_outer_op_parallel(setup_env, placements, op): + """Test shardwise_outer_op for all ShardwiseOuterOp types.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + K = size_ring * 4 # Number of windows + L = 8 # Size at axis for x (e.g., queries) + R = 16 # Size at axis for y (e.g., keys) + D = 3 # Feature dimension + axis = 2 # The axis at which to perform outer operation + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + is_differentiable = op in (ShardwiseOuterOp.SUBTRACT, ShardwiseOuterOp.ADD) + + # Create input tensors WITHOUT singletons + # x: (B, K, L, D) and y: (B, K, R, D) + if op == ShardwiseOuterOp.SUBTRACT: + input_x_global = torch.rand((B, K, L, D), requires_grad=True, device=device_type, generator=rng) + input_y_global = torch.rand((B, K, R, D), requires_grad=True, device=device_type, generator=rng) + elif op == ShardwiseOuterOp.ADD: + input_x_global = torch.rand((B, K, L, D), requires_grad=True, device=device_type, generator=rng) + input_y_global = torch.rand((B, K, R, D), requires_grad=True, device=device_type, generator=rng) + elif op == ShardwiseOuterOp.LOGICAL_AND: + input_x_global = torch.randint(0, 2, (B, K, L, D), device=device_type, generator=rng).bool() + input_y_global = torch.randint(0, 2, (B, K, R, D), device=device_type, generator=rng).bool() + elif op == ShardwiseOuterOp.EQUAL: + num_unique_values = 5 + input_x_global = torch.randint(0, num_unique_values, (B, K, L, D), device=device_type, generator=rng) + input_y_global = torch.randint(0, num_unique_values, (B, K, R, D), device=device_type, generator=rng) + else: + raise ValueError(f"Unknown op: {op}") + + # Run serial forward pass + input_x_global_host = input_x_global.detach().clone().cpu() + input_y_global_host = input_y_global.detach().clone().cpu() + output_expected_global = serial_shardwise_outer_op(input_x_global, input_y_global, axis, op) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Compute expected output placements (axis+1 is inserted, so Shard dims > axis shift) + output_placements_expected = list(placements) + for i, p in enumerate(placements): + if isinstance(p, Shard) and p.dim > axis: + output_placements_expected[i] = Shard(p.dim + 1) + output_placements_expected = tuple(output_placements_expected) + + # Prepare gradient data for differentiable ops + d_output_expected_global_host = None + d_input_x_expected_global_host = None + d_input_y_expected_global_host = None + + if is_differentiable: + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + d_input_x_expected_global_host = input_x_global.grad.detach().clone().cpu() + d_input_y_expected_global_host = input_y_global.grad.detach().clone().cpu() + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_shardwise_outer_op, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + axis, + op, + is_differentiable, + input_x_global_host, + input_y_global_host, + output_expected_global_host, + output_placements_expected, + d_output_expected_global_host, + d_input_x_expected_global_host, + d_input_y_expected_global_host, + ) + + +def assert_outer_op_error_cases(rank, grid_group_sizes, device_type, backend, env_per_rank): + """Test error cases for shardwise_outer_op.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0) + + B = 2 * grid_group_sizes["dp"] + K = grid_group_sizes["cp"][0] * 4 + L = 8 + R = 16 + D = 3 + axis = 2 + + placements = (Shard(0), Shard(1), Replicate()) + + # Test case 1: Invalid input type (regular tensor instead of DTensor) + regular_tensor = torch.randn((B, K, L, D), device=manager.device) + dtensor = distribute_tensor( + torch.randn((B, K, R, D), device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + with pytest.raises(TypeError, match="Expected DTensor for lhs"): + shardwise_outer_op(regular_tensor, dtensor, axis, ShardwiseOuterOp.SUBTRACT) + + with pytest.raises(TypeError, match="Expected DTensor for rhs"): + shardwise_outer_op(dtensor, regular_tensor, axis, ShardwiseOuterOp.SUBTRACT) + + # Test case 2: Mismatched placements + placements2 = (Shard(0), Replicate(), Shard(2)) + dtensor1 = distribute_tensor( + torch.randn((B, K, L, D), device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + dtensor2 = distribute_tensor( + torch.randn((B, K, R, D), device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements2, + ) + + with pytest.raises(ValueError, match="must have the same placements"): + shardwise_outer_op(dtensor1, dtensor2, axis, ShardwiseOuterOp.SUBTRACT) + + # Test case 3: Trying to shard the axis dimension (outer op must be local) + placements_axis_shard = (Shard(0), Shard(1), Shard(2)) # Try to shard axis dim (2) + dtensor3 = distribute_tensor( + torch.randn((B, K, L, D), device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_axis_shard, + ) + dtensor4 = distribute_tensor( + torch.randn((B, K, R, D), device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_axis_shard, + ) + + with pytest.raises(ValueError, match="Cannot shard dimension.*outer operation axis"): + shardwise_outer_op(dtensor3, dtensor4, axis, ShardwiseOuterOp.SUBTRACT) + + # Test case 4: Invalid axis type + dtensor5 = distribute_tensor( + torch.randn((B, K, L, D), device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + dtensor6 = distribute_tensor( + torch.randn((B, K, R, D), device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + with pytest.raises(TypeError, match="Expected int for axis"): + shardwise_outer_op(dtensor5, dtensor6, "invalid", ShardwiseOuterOp.SUBTRACT) + + # Test case 5: axis out of bounds + with pytest.raises(ValueError, match="axis.*out of bounds"): + shardwise_outer_op(dtensor5, dtensor6, 10, ShardwiseOuterOp.SUBTRACT) + + # Test case 6: Mismatched number of dimensions + dtensor7 = distribute_tensor( + torch.randn((B, K, L), device=manager.device), # 3D instead of 4D + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + with pytest.raises(ValueError, match="must have the same number of dimensions"): + shardwise_outer_op(dtensor5, dtensor7, axis, ShardwiseOuterOp.SUBTRACT) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_shardwise_outer_op_error_cases(setup_env): + """Test error cases for shardwise_outer_op.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_outer_op_error_cases, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/model/layers/test_dtensor_sigmoid_gate.py b/tests/distributed/model/layers/test_dtensor_sigmoid_gate.py new file mode 100644 index 000000000..532379565 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_sigmoid_gate.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import itertools + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.sigmoid_gate import sigmoid_gate +from boltz.testing.utils import assert_tensors_identical, spawn_multiprocessing + + +def serial_sigmoid_gate(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + """Serial implementation of sigmoid gate for comparison.""" + return x * g.sigmoid() + + +def parallel_assert_sigmoid_gate( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + input_g_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + d_input_g_expected_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensors + input_x_dtensor = distribute_tensor( + input_x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + input_g_dtensor = distribute_tensor( + input_g_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_g_expected_dtensor = distribute_tensor( + d_input_g_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + input_g_dtensor_copy = input_g_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = sigmoid_gate(input_x_dtensor, input_g_dtensor) + + # Verify inputs weren't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_g_dtensor_copy.to_local(), input_g_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + torch.testing.assert_close(input_g_dtensor.grad.to_local(), d_input_g_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + d_input_g_global_result_host = input_g_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host) + torch.testing.assert_close(d_input_g_global_result_host, d_input_g_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(1, (2, 2)), (2, (2, 2))], [True], ["cpu", "cuda"], ["ENV"]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +def test_sigmoid_gate_parallel(setup_env, placements): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 2 # Number of tokens + D = 8 # Hidden dimension + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensors with proper shapes + input_x_global = torch.rand((B, N, N, D), generator=rng, requires_grad=True, device=device_type) + input_g_global = torch.rand((B, N, N, D), generator=rng, requires_grad=True, device=device_type) + + # Run serial forward pass + input_x_global_host = input_x_global.detach().clone().cpu() + input_g_global_host = input_g_global.detach().clone().cpu() + output_expected_global = serial_sigmoid_gate(input_x_global, input_g_global) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_sigmoid_gate, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_x_global_host, + input_g_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_x_global.grad.detach().clone().cpu(), + input_g_global.grad.detach().clone().cpu(), + ) diff --git a/tests/distributed/model/layers/test_dtensor_squeeze.py b/tests/distributed/model/layers/test_dtensor_squeeze.py new file mode 100644 index 000000000..108d3c1a7 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_squeeze.py @@ -0,0 +1,611 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import unittest +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.squeeze import shardwise_squeeze, shardwise_unsqueeze +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def compute_global_expectation_unsqueeze(batch_size, seq_len, feature_dim, dim_to_unsqueeze, device): + x = torch.rand(batch_size, seq_len, feature_dim, device=device, requires_grad=True) + + # Compute on global tensor using standard unsqueeze operation + y = x.unsqueeze(dim_to_unsqueeze) + + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensor + y.backward(dy) + + return x.detach().clone(), y.detach().clone(), x.grad.detach().clone(), dy.detach().clone() + + +def compute_dtensor_native_unsqueeze( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim_to_unsqueeze: int, +) -> tuple[DTensor, DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + + # Forward pass with native DTensor unsqueeze operation + y_dtensor_result = x_dtensor.unsqueeze(dim_to_unsqueeze) + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + y_dtensor_result.backward(dy_dtensor) + + x_grad_dtensor = x_dtensor.grad + + # do the view check on the DTensor native op + assert ( + y_dtensor_result.to_local().squeeze(dim_to_unsqueeze).is_set_to(x_dtensor.to_local()) + ), f"for dim {dim_to_unsqueeze} output local shard is not a view of the input shard for native DTensor op" + # do the view check on the DTensor native op for backward pass + + assert x_grad_dtensor.to_local().is_set_to( + dy_dtensor.to_local().squeeze(dim_to_unsqueeze) + ), f"for dim {dim_to_unsqueeze} input grad is not a view of the upstream adjoint for native DTensor op" + + return x_grad_dtensor, y_dtensor_result + + +def parallel_assert_dtensor_unsqueeze( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + batch_size = 2 + seq_len_per_rank = 4 + seq_len_global = size_ring * seq_len_per_rank + feature_dim = 3 + dims_to_unsqueeze = [0, 1, 2, 3, -1, -2, -3, -4] # Test various dimensions including negative indexing + + for dim_to_unsqueeze in dims_to_unsqueeze: + label_test_case = f"for dim {dim_to_unsqueeze}\n" + # Compute global expectations + x_global, y_expected_global, x_grad_expected_global, dy_global = compute_global_expectation_unsqueeze( + batch_size, seq_len_global, feature_dim, dim_to_unsqueeze, manager.device + ) + + # Create distributed tensors + # Shard the sequence dimension (dim=1) for input tensor + # this emulates the sharded single representation in the Boltz model + input_placements = (Shard(dim=0), Shard(dim=1), Replicate()) + + # use DTensor native op as an alternative reference + x_grad_dtensor_native, y_dtensor_result_native = compute_dtensor_native_unsqueeze( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, dim_to_unsqueeze + ) + + # Create DTensor input + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements).requires_grad_(True) + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_unsqueeze + y_dtensor_result = shardwise_unsqueeze(x_dtensor, dim_to_unsqueeze) + + # check if the output local shard is a view of the input. We know + # that squeeze() and unsqueeze() guarantees a view of the input + # so we can use them here to do is_set_to() check, which otherwise + # wouldn't work because is_set_to() also checks the strides, which + # are different between the pre-squeeze/unsqueeze and post-squeeze/unsqueeze + assert ( + y_dtensor_result.to_local().squeeze(dim_to_unsqueeze).is_set_to(x_dtensor.to_local()) + ), f"{label_test_case} output local shard is not a view of the input shard" + + # verify no change to the fwd input + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global, manager.device_mesh_subgroups, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # check if the input grad is a view of the upstream adjoint + assert x_dtensor.grad.to_local().is_set_to( + dy_dtensor.to_local().squeeze(dim_to_unsqueeze) + ), f"{label_test_case} input grad is not a view of the upstream adjoint" + + # verify no change to the bwd input + assert_tensors_identical( + dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False + ) + + # verify input gradient placements are consistent with input placements + assert ( + x_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # check metadata against DTensor native + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # compare forward result with native DTensor op + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # assert input gradients' metadata and values against DTensor native + assert ( + x_dtensor.grad.placements == x_grad_dtensor_native.placements + ), f"{label_test_case} input gradient placements mismatch" + assert x_dtensor.grad.shape == x_grad_dtensor_native.shape, f"{label_test_case} input gradient shape mismatch" + assert ( + x_dtensor.grad.stride() == x_grad_dtensor_native.stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_dtensor_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # compare global tensors between shardwise_unsqueeze and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_dtensor_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare forward result with global expectation + torch.testing.assert_close( + y_result_global, + y_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs global expectation: {m}", + ) + + # create distributed tensor from global results for local shard comparison + x_grad_dtensor_expected = distribute_tensor( + x_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # compare local shards with expected + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + # compare global gradients with serial expectation + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_dtensor_unsqueeze(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_unsqueeze, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def compute_global_expectation_squeeze(batch_size, seq_len, dim_to_squeeze, device): + # Create tensor with singleton dimensions that can be squeezed + # Shape will be (batch_size, seq_len) for 2D tensor matching device mesh + x = torch.rand(batch_size, seq_len, device=device) + x = x.unsqueeze(dim_to_squeeze) + x.requires_grad_(True) + + # Compute on global tensor using standard squeeze operation + y = x.squeeze(dim_to_squeeze) + + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensor + y.backward(dy) + + return x.detach().clone(), y.detach().clone(), x.grad.detach().clone(), dy.detach().clone() + + +def compute_dtensor_native_squeeze( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim_to_squeeze: int, +) -> tuple[DTensor, DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + + # Forward pass with native DTensor squeeze operation + y_dtensor_result = x_dtensor.squeeze(dim_to_squeeze) + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + y_dtensor_result.backward(dy_dtensor) + + x_grad_dtensor = x_dtensor.grad + + # do the view check on the DTensor native op + assert ( + y_dtensor_result.to_local().unsqueeze(dim_to_squeeze).is_set_to(x_dtensor.to_local()) + ), f"for dim {dim_to_squeeze} output local shard is not a view of the input shard for native DTensor op" + # do the view check on the DTensor native op for backward pass + + assert x_grad_dtensor.to_local().is_set_to( + dy_dtensor.to_local().unsqueeze(dim_to_squeeze) + ), f"for dim {dim_to_squeeze} input grad is not a view of the upstream adjoint for native DTensor op" + + return x_grad_dtensor, y_dtensor_result + + +def parallel_assert_dtensor_squeeze( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + batch_size = 2 + seq_len_per_rank = 4 + seq_len_global = size_ring * seq_len_per_rank + dims_to_squeeze = [0, 1, 2, -1, -2, -3] + + for dim_to_squeeze in dims_to_squeeze: + label_test_case = f"for dim {dim_to_squeeze}\n" + # Compute global expectations + x_global, y_expected_global, x_grad_expected_global, dy_global = compute_global_expectation_squeeze( + batch_size, seq_len_global, dim_to_squeeze, manager.device + ) + + # Create distributed tensors + # Shard the batch and sequence dimensions for input tensor + # Input shape is (batch_size, seq_len, 1) + input_placements = (Shard(dim=0), Shard(dim=1), Replicate()) + + # use DTensor native op as an alternative reference + wrapped_dim_to_squeeze = dim_to_squeeze if dim_to_squeeze >= 0 else dim_to_squeeze + x_global.ndim + dim_is_sharded = any( + placement + for placement in input_placements + if isinstance(placement, Shard) and placement.dim == wrapped_dim_to_squeeze + ) + + if not dim_is_sharded: + x_grad_dtensor_native, y_dtensor_result_native = compute_dtensor_native_squeeze( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, dim_to_squeeze + ) + + # Create DTensor input + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements).requires_grad_(True) + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_squeeze + if dim_is_sharded: # short circuit if squeeze on sharded dimensions + with pytest.raises(ValueError, match=r"Cannot squeeze dimension .* as it is sharded"): + y_dtensor_result = shardwise_squeeze(x_dtensor, dim_to_squeeze) + continue + + y_dtensor_result = shardwise_squeeze(x_dtensor, dim_to_squeeze) + + # check if the output local shard is a view of the input. We know + # that squeeze() and unsqueeze() guarantees a view of the input + # so we can use them here to do is_set_to() check, which otherwise + # wouldn't work because is_set_to() also checks the strides, which + # are different between the pre-squeeze/unsqueeze and post-squeeze/unsqueeze + assert ( + y_dtensor_result.to_local().unsqueeze(dim_to_squeeze).is_set_to(x_dtensor.to_local()) + ), f"{label_test_case} output local shard is not a view of the input shard" + + # verify no change to the fwd input + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global, manager.device_mesh_subgroups, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # check if the input grad is a view of the upstream adjoint + assert x_dtensor.grad.to_local().is_set_to( + dy_dtensor.to_local().unsqueeze(dim_to_squeeze) + ), f"{label_test_case} input grad is not a view of the upstream adjoint" + + # verify no change to the bwd input + assert_tensors_identical( + dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False + ) + + # verify input gradient placements are consistent with input placements + assert ( + x_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # check metadata against DTensor native + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # compare forward result with native DTensor op + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # assert input gradients' metadata and values against DTensor native + assert ( + x_dtensor.grad.placements == x_grad_dtensor_native.placements + ), f"{label_test_case} input gradient placements mismatch" + assert x_dtensor.grad.shape == x_grad_dtensor_native.shape, f"{label_test_case} input gradient shape mismatch" + assert ( + x_dtensor.grad.stride() == x_grad_dtensor_native.stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_dtensor_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # compare global tensors between shardwise_squeeze and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_dtensor_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare forward result with global expectation + torch.testing.assert_close( + y_result_global, + y_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs global expectation: {m}", + ) + + # create distributed tensor from global results for local shard comparison + x_grad_dtensor_expected = distribute_tensor( + x_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # compare local shards with expected + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + # compare global gradients with serial expectation + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_dtensor_squeeze(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_squeeze, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/model/layers/test_dtensor_swiglu.py b/tests/distributed/model/layers/test_dtensor_swiglu.py new file mode 100755 index 000000000..29db3b282 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_swiglu.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Tests a single instance of SwiGLU with DTensor + +Verification requirements + + V1: single-proc FW input tensor values unchanged by FW and BW + V2: single-proc BW input tensor values unchanged by BW + V3: single-proc FW input tensor grads are zero at padded locations (virtual atoms) + - for input tensors that require grads + + V4: multi-proc version of V1 + V5: multi-proc version of V2 + V6: multi-proc version of V3: implied by V3 and V9 + + V7: multi-proc FW input tensor values and meta match single-proc inputs + V8: multi-proc FW output tensor values close-to single-proc + V9: multi-proc FW input gradient values close-to single-proc + V10: multi-proc parameter gradient values close-to single-proc + V11: multi-proc parameter gradient values identical across proc's + +Implementation status + V1: implemented + V2: implemented + V3: NA (no padding in SwiGLU) + V4: implemented + V5: implemented + V6: NA (no padding in SwiGLU) + V7: same data + V8: implemented + V9: implemented + V10: NA (no parameters in SwiGLU) + V11: NA (no parameters in SwiGLU) + +Assertion threshold defaults for pytorch +""" + +import itertools + +import pytest +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.testing import assert_close + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.swiglu import SwiGLU as SwiGLUWithDTensor +from boltz.model.modules.utils import SwiGLU as SwiGluBoltz +from boltz.testing.utils import ( + assert_tensors_identical, + seed_by_rank, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + +SEED = 42 + + +def assert_swiglu_with_dtensor_fw_bw( + rank: int, # noqa + input_example: Tensor, + output_ref: Tensor, + output_grad_example: Tensor, + input_grads_ref: Tensor, + grid_group_sizes: tuple[int, ...], + device_type: str, + backend: str, + env_per_rank: dict[str, str], # noqa +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # ------------------------------------------------------------- + # Move inputs and ref outputs to device + # -------------------------------------------------------------- + input_example_device = input_example.detach().to(manager.device, copy=True) + output_ref_device = output_ref.detach().to(manager.device, copy=True) + input_grads_ref_device = input_grads_ref.detach().to(manager.device, copy=True) + output_grad_example = output_grad_example.detach().to(manager.device, copy=True) + + # ------------------------------------------------------------- + # Create module to test + # - do not need to load state_dict + # -------------------------------------------------------------- + multi_proc_module = SwiGLUWithDTensor() + multi_proc_module = multi_proc_module.train() + multi_proc_module = multi_proc_module.to(manager.device) + + # ----------------------------------------------------- + # Create input DTensors + # ---------------------------------------------------- + placements_for_single_rep_nonparam = (Shard(0), Shard(1), Replicate()) + + input_example_as_dtensor = distribute_tensor( + input_example_device, manager.device_mesh_subgroups, placements_for_single_rep_nonparam + ).requires_grad_(True) + + output_grad_example_as_dtensor = distribute_tensor( + output_grad_example, + manager.device_mesh_subgroups, + placements_for_single_rep_nonparam, + ).requires_grad_(False) + + # ------------------------------------------------- + # Run FW + # ------------------------------------------------- + input_example_clone_as_dtensor = input_example_as_dtensor.detach().clone().requires_grad_(True) + output_actual_as_dtensor: DTensor = multi_proc_module(input_example_as_dtensor) + + # ------------------------------------------------------- + # V4a: multi-proc FW input tensor values unchanged by FW and BW + # ------------------------------------------------------ + assert_tensors_identical( + input_example_clone_as_dtensor.full_tensor(), + input_example_as_dtensor.full_tensor(), + check_grad=False, + check_grad_fn=False, + ) + + # ------------------------------------------------- + # Run BW + # ------------------------------------------------- + output_grad_example_clone_as_dtensor = ( + output_grad_example_as_dtensor.detach().clone().requires_grad_(output_grad_example_as_dtensor.requires_grad) + ) + output_actual_as_dtensor.backward(output_grad_example_clone_as_dtensor) + + # ------------------------------------------------------- + # V8: multi-proc FW output tensor values close-to single-proc + # ------------------------------------------------------ + assert_close(output_actual_as_dtensor.full_tensor(), output_ref_device) + + # ------------------------------------------------------- + # V4b: multi-proc FW input tensor values unchanged by FW and BW + # - check again that input is unchanged by BW + # ------------------------------------------------------ + assert_tensors_identical( + input_example_clone_as_dtensor.full_tensor(), + input_example_as_dtensor.full_tensor(), + check_grad=False, + check_grad_fn=False, + ) + # ------------------------------------------------------- + # V5: multi-proc BW input tensor values unchanged by BW + # ------------------------------------------------------ + assert_tensors_identical( + output_grad_example_clone_as_dtensor.full_tensor(), + output_grad_example_as_dtensor.full_tensor(), + check_grad=False, + check_grad_fn=False, + ) + # ------------------------------------------------------- + # V9: multi-proc FW input gradient values close-to single-proc + # ------------------------------------------------------ + assert_close( + input_example_as_dtensor.grad.full_tensor(), + input_grads_ref_device, + ) + + # cleanup + DistributedManager.cleanup() + monkeypatch.undo() + + +def get_example_input_and_reference_output( + bs: int, + N_tokens: int, + dim: int, + seed: int = SEED, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Generate example input and reference output for SwiGLU testing. + + Parameters + ---------- + bs : int + Batch size. + N_tokens : int + Number of tokens. + dim : int + Input dimension (must be even for SwiGLU). + seed : int, optional + Random seed, by default SEED. + + Returns + ------- + tuple[Tensor, Tensor, Tensor, Tensor] + (input_example, output_ref, output_grad_example, input_grads_ref) + """ + # ---------------------------------------- + # Set random seed + # ---------------------------------------- + seed_by_rank(seed) + + # ---------------------------------------- + # Create input tensor + # ---------------------------------------- + input_example = torch.randn(bs, N_tokens, dim, requires_grad=True) + input_example_copy = input_example.detach().clone().requires_grad_(input_example.requires_grad) + + # ------------------------------------------- + # Create single-proc module and run serial FW + # ------------------------------------------- + single_proc_module = SwiGluBoltz() + single_proc_module = single_proc_module.train() + output_ref = single_proc_module(input_example) + + # ---------------------------------------------------------------- + # V1a: single-proc FW input tensor values unchanged by FW and BW + # ----------------------------------------------------------------- + assert_tensors_identical( + input_example, + input_example_copy, + check_grad=False, + check_grad_fn=False, + check_storage_pointer=False, + ) + + # ---------------------------------------- + # Create output gradient + # ---------------------------------------- + output_grad_example = torch.randn_like(output_ref, requires_grad=False) + output_grad_example_copy = output_grad_example.detach().clone().requires_grad_(output_grad_example.requires_grad) + + # ---------------------------------------- + # Serial BW Compute reference gradients + # ---------------------------------------- + torch.autograd.backward(output_ref, output_grad_example) + + # ---------------------------------------------------------------- + # V1b: single-proc FW input tensor values unchanged by FW and BW + # ----------------------------------------------------------------- + assert_tensors_identical( + input_example, + input_example_copy, + check_grad=False, + check_grad_fn=False, + check_storage_pointer=False, + ) + # ----------------------------------------------------------------------- + # V2: single-proc BW input tensor values unchanged by BW + # ----------------------------------------------------------------------- + assert_tensors_identical( + output_grad_example, + output_grad_example_copy, + check_grad=False, + check_grad_fn=False, + check_storage_pointer=False, + ) + + # Get input gradients + input_grads_ref = input_example.grad.clone() + + return input_example, output_ref.detach(), output_grad_example, input_grads_ref + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(1, (2, 2)), (2, (2, 2))], [True], ["cpu", "cuda"], ["ENV"]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_swiglu_with_dtensor( + setup_env: dict[str, int], + bs: int = 2, + N_tokens: int = 4**2, + dim: int = 2, + seed: int = SEED, +): + """Test SwiGLU with DTensor for various configurations. + + Parameters + ---------- + setup_env : dict[str, int] + Environment setup for distributed testing. + bs : int, optional + Batch size, by default 2. + N_tokens : int, optional + Number of tokens, by default 8**2. + dim : int, optional + Input dimension, by default 384. + seed : int, optional + Random seed, by default SEED. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + # (0) Check use-case requirements + if dim % 2 != 0: + raise ValueError(f"Dimension must be even for SwiGLU. Got dim={dim}") + + # (1) Get example input and reference output + input_example, output_ref, output_grad_example, input_grads_ref = get_example_input_and_reference_output( + bs=bs, + N_tokens=N_tokens, + dim=dim, + seed=seed, + ) + + spawn_multiprocessing( + assert_swiglu_with_dtensor_fw_bw, + world_size, + input_example, + output_ref, + output_grad_example, + input_grads_ref, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(1, (1, 1))], [True], ["cpu", "cuda"], ["ENV"]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_swiglu_with_dtensor_for_metadata_checks( + setup_env: dict[str, int], + bs: int = 2, + N_tokens: int = 4**2, + dim: int = 2, + seed: int = SEED, +): + """Test SwiGLU with DTensor for various configurations. + + Parameters + ---------- + setup_env : dict[str, int] + Environment setup for distributed testing. + bs : int, optional + Batch size, by default 2. + N_tokens : int, optional + Number of tokens, by default 8**2. + dim : int, optional + Input dimension, by default 384. + seed : int, optional + Random seed, by default SEED. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + # (0) Check use-case requirements + if dim % 2 != 0: + raise ValueError(f"Dimension must be even for SwiGLU. Got dim={dim}") + + # (1) Get example input and reference output + input_example, output_ref, output_grad_example, input_grads_ref = get_example_input_and_reference_output( + bs=bs, + N_tokens=N_tokens, + dim=dim, + seed=seed, + ) + + spawn_multiprocessing( + assert_swiglu_with_dtensor_fw_bw, + world_size, + input_example, + output_ref, + output_grad_example, + input_grads_ref, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/distributed/model/layers/test_dtensor_transition.py b/tests/distributed/model/layers/test_dtensor_transition.py new file mode 100644 index 000000000..9b9e35d53 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_transition.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.transition import Transition as DistributedTransition +from boltz.model.layers.transition import Transition +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_transition( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + dim, + hidden, + out_dim, + layer_state_dict, + input_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_expected_global_host, + grad_params_expected_global_host, + output_global_fp32_host: torch.Tensor | None = None, + d_input_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_expected_global_host.dtype}" + ) + + if ((output_global_fp32_host is None) != (d_input_global_fp32_host is None)) or ( + (output_global_fp32_host is not None) != (grad_params_fp32_global_host is not None) + ): + raise ValueError( + "output_global_fp32_host, d_input_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_global_fp32_host is not None + + # Create serial reference module + module_serial = Transition(dim, hidden, out_dim) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + module_serial.train() + + # Create distributed module + module = DistributedTransition(module_serial, manager.device_mesh_subgroups) + module.train() + + # Input tensor has shape (B, S, D) - sharded on dims 0 and 1 (B and S) + placements_input = (Shard(0), Shard(1), Replicate()) + + # Distribute input tensor + input_dtensor = distribute_tensor( + input_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ).requires_grad_(True) + + # Distribute expected outputs + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + ) + d_input_expected_dtensor = distribute_tensor( + d_input_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + input_dtensor_copy = input_dtensor.detach().clone().requires_grad_(True) + + if check_error_hist: + # Forward and backward pass for error histogram checking + output_dtensor_result = module(input_dtensor) + output_dtensor_result.backward(d_output_expected_dtensor) + + output_fp32_dtensor = distribute_tensor( + output_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + + d_input_fp32_dtensor = distribute_tensor( + d_input_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_input, + src_data_rank=None, + ) + + assert_no_percentile_upshift( + output_dtensor_result.to_local(), + output_expected_dtensor.to_local(), + output_fp32_dtensor.to_local(), + names_input=("output_cp_fp32", "output_serial_fp64", "output_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_dtensor.grad.to_local(), + d_input_expected_dtensor.to_local(), + d_input_fp32_dtensor.to_local(), + names_input=("d_input_cp_fp32", "d_input_serial_fp64", "d_input_serial_fp32"), + ) + + # Check parameter gradients error histograms + for name, grad_param_expected_global in grad_params_expected_global_host.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + # Forward pass + output_dtensor_result = module(input_dtensor) + + # Verify inputs weren't modified + assert_tensors_identical( + input_dtensor_copy.to_local(), input_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # Test forward pass results + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + torch.testing.assert_close(input_dtensor.grad.to_local(), d_input_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + input_global_result_host = input_dtensor.full_tensor().cpu() + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_global_result_host = input_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(input_global_result_host, input_global_host.to(dtype=dtype)) + torch.testing.assert_close(output_global_result_host, output_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_global_result_host, d_input_expected_global_host.to(dtype=dtype)) + + # Test parameter gradients + grad_params_result_dtensors = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in grad_params_expected_global_host: + # do an extra check here to make sure the parallel computation don't result in extra gradients + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + grad_params_result_dtensors[name] = param.grad + + for name, grad_param_expected_global in grad_params_expected_global_host.items(): + assert name in grad_params_result_dtensors, f"Parameter {name}'s gradient is not found in result gradients" + grad_params_result = grad_params_result_dtensors[name] + # Test parameter gradients with full tensor gathering + param_grad_result = grad_params_result.full_tensor() + torch.testing.assert_close(param_grad_result.cpu(), grad_param_expected_global.to(dtype=dtype)) + assert_all_identical(param_grad_result, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "dtype_and_check_error_hist", + [ + (torch.float32, False), + (torch.float32, True), + (torch.float64, False), + ], + ids=lambda x: f"dtype={x[0]}, check_error_hist={x[1]}", +) +def test_transition_parallel(setup_env, dtype_and_check_error_hist): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + # dtype is the dtype used by the parallel computation + # check_error_hist determine whether to compare the error histograms between + # (CP_in_FP32, serial_in_FP64) and (serial_in_FP32, serial_in_FP64) + # Typically, check_error_hist will use large input dimensions to emulate + # the real-world use cases. Same with dtype==torch.float64. + dtype, check_error_hist = dtype_and_check_error_hist + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist: + if grid_group_sizes["dp"] > 1: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if check_error_hist or dtype == torch.float64: + S = size_ring * 128 # Sequence length + dim = 128 # Input dimension + hidden = 512 # Hidden dimension + out_dim = 128 # Output dimension (same as input) + else: + S = size_ring * 4 # Sequence length + dim = 16 # Input dimension + hidden = 64 # Hidden dimension + out_dim = 16 # Output dimension (same as input) + + seed = 42 + seed_by_rank(0, seed=seed) + + # compute reference results with FP64 + input_global_fp64 = torch.empty((B, S, dim), dtype=torch.float64, requires_grad=True, device=device_type) + + # Create reference serial module + reference_module = Transition(dim, hidden, out_dim) + + # Initialize parameters to ensure reproducible behavior + with torch.no_grad(): + input_global_fp64.uniform_(-5e-2, 5e-2) + for name, param in reference_module.named_parameters(): + param.uniform_(-5e-2, 5e-2) + + layer_state_dict_fp64 = reference_module.state_dict() + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + + # Run forward pass + output_expected_global_fp64 = reference_module(input_global_fp64) + d_output_expected_global_fp64 = torch.rand_like(output_expected_global_fp64) + output_expected_global_fp64.backward(d_output_expected_global_fp64) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module.named_parameters() + } + + if check_error_hist: + input_global_fp32 = input_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(True) + reference_module_fp32 = Transition(dim, hidden, out_dim) + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + + output_global_fp32 = reference_module_fp32(input_global_fp32) + d_output_expected_global_fp32 = d_output_expected_global_fp64.to(dtype=torch.float32) + output_global_fp32.backward(d_output_expected_global_fp32) + + output_global_fp32_host = output_global_fp32.detach().clone().cpu() + d_input_global_fp32_host = input_global_fp32.grad.detach().clone().cpu() + grad_params_fp32_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module_fp32.named_parameters() + } + else: + output_global_fp32_host = None + d_input_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_transition, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + dim, + hidden, + out_dim, + layer_state_dict_fp64, + input_global_fp64.detach().clone().cpu(), + output_expected_global_fp64.detach().clone().cpu(), + d_output_expected_global_fp64.detach().clone().cpu(), + input_global_fp64.grad.detach().clone().cpu(), + grad_params_fp64_expected_global_host, + output_global_fp32_host, + d_input_global_fp32_host, + grad_params_fp32_global_host, + ) diff --git a/tests/distributed/model/layers/test_dtensor_triangle_attention.py b/tests/distributed/model/layers/test_dtensor_triangle_attention.py new file mode 100644 index 000000000..fea21580e --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_triangle_attention.py @@ -0,0 +1,913 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import warnings + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import Ring2DCommTriAttn +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.triangular_attention import ( + TriangleAttentionEndingNode as DistributedTriangleAttentionEndingNode, +) +from boltz.distributed.model.layers.triangular_attention import ( + TriangleAttentionStartingNode as DistributedTriangleAttentionStartingNode, +) +from boltz.distributed.model.layers.triangular_attention import ( + _Mode, + _RingMultiHeadTriangleAttentionImpl, + can_run_cueq_triattn_sm100f, + cueq_is_installed, +) +from boltz.distributed.model.modules.utils import TriAttnBackend +from boltz.model.layers.triangular_attention.attention import ( + TriangleAttentionEndingNode, + TriangleAttentionStartingNode, +) +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_triangle_attention( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + c_in, + c_hidden, + no_heads, + mode, + layer_state_dict, + input_x_global_host, + mask_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + grad_params_expected_global_host, + output_global_fp32_host: torch.Tensor | None = None, + d_input_x_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_expected_global_host.dtype}" + ) + + if ((output_global_fp32_host is None) != (d_input_x_global_fp32_host is None)) or ( + (output_global_fp32_host is not None) != (grad_params_fp32_global_host is not None) + ): + raise ValueError( + "output_global_fp32_host, d_input_x_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_global_fp32_host is not None + + layout_map = manager.layout_subgroups["cp"] + + # Set up communication based on mode + if mode == _Mode.Starting: + axis_cp = 1 + elif mode == _Mode.Ending: + axis_cp = 0 + else: + raise ValueError(f"Invalid mode {mode}") + + ring_comm = Ring2DCommTriAttn(manager.group["cp"], layout_map, axis_cp) + + # Create reference serial module + dtype_to_inf = {torch.float32: 1e9, torch.float64: 1e18} + if mode == _Mode.Starting: + module_serial = TriangleAttentionStartingNode(c_in, c_hidden, no_heads, inf=dtype_to_inf[dtype]) + elif mode == _Mode.Ending: + module_serial = TriangleAttentionEndingNode(c_in, c_hidden, no_heads, inf=dtype_to_inf[dtype]) + else: + raise ValueError(f"Invalid mode {mode}") + + module_serial = module_serial.to(dtype=dtype, device=manager.device) + module_serial.load_state_dict(layer_state_dict) + + # Create distributed module + if mode == _Mode.Starting: + module = DistributedTriangleAttentionStartingNode(module_serial, manager.device_mesh_subgroups, ring_comm) + elif mode == _Mode.Ending: + module = DistributedTriangleAttentionEndingNode(module_serial, manager.device_mesh_subgroups, ring_comm) + else: + raise ValueError(f"Invalid mode {mode}") + + module = module.to(device=manager.device).train() + + # Input tensors have the same sharding pattern: + # x: (B, N, N, D) - sharded on dims 1 and 2 (N and N) + # mask: (B, N, N) - sharded on dims 1 and 2 (N and N) + placements = (Shard(0), Shard(1), Shard(2)) + + # Distribute input tensors + input_x_dtensor = distribute_tensor( + input_x_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ).requires_grad_(True) + + mask_dtensor = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone() + + if check_error_hist: + # Forward and backward pass for error histogram checking + output_dtensor_result = module(input_x_dtensor, mask_dtensor) + output_dtensor_result.backward(d_output_expected_dtensor) + + output_fp32_dtensor = distribute_tensor( + output_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + d_input_x_fp32_dtensor = distribute_tensor( + d_input_x_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + assert_no_percentile_upshift( + output_dtensor_result.to_local(), + output_expected_dtensor.to_local(), + output_fp32_dtensor.to_local(), + names_input=("output_cp_fp32", "output_serial_fp64", "output_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_x_dtensor.grad.to_local(), + d_input_x_expected_dtensor.to_local(), + d_input_x_fp32_dtensor.to_local(), + names_input=("d_input_x_cp_fp32", "d_input_x_serial_fp64", "d_input_x_serial_fp32"), + ) + + # Check parameter gradients error histograms + for name, grad_param_expected_global in grad_params_expected_global_host.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + # Forward pass + output_dtensor_result = module(input_x_dtensor, mask_dtensor) + + # Verify inputs weren't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(mask_dtensor_copy.to_local(), mask_dtensor.to_local()) + + # Test forward pass results + assert output_dtensor_result.shape == output_expected_dtensor.shape + assert output_dtensor_result.stride() == output_expected_dtensor.stride() + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + assert input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + assert input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host.to(dtype=dtype)) + + # Test parameter gradients + grad_params_result_dtensors = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in grad_params_expected_global_host: + # do an extra check here to make sure the parallel computation don't result in extra gradients + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + grad_params_result_dtensors[name] = param.grad + + for name, grad_param_expected_global_host in grad_params_expected_global_host.items(): + assert name in grad_params_result_dtensors, f"Parameter {name}'s gradient is not found in result gradients" + grad_params_result = grad_params_result_dtensors[name] + assert grad_params_result.shape == grad_param_expected_global_host.shape + assert grad_params_result.stride() == grad_param_expected_global_host.stride() + grad_params_result_global = grad_params_result.full_tensor() + torch.testing.assert_close(grad_params_result_global.cpu(), grad_param_expected_global_host.to(dtype=dtype)) + assert_all_identical(grad_params_result_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, False), + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False), + (((1, (3, 3)), True, "cuda", "ENV"), torch.float32, False), + (((1, (3, 3)), True, "cpu", "ENV"), torch.float32, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +@pytest.mark.parametrize("mode", [_Mode.Starting, _Mode.Ending]) +def test_triangle_attention_parallel(setup_env, dtype, check_error_hist, mode): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + # dtype is the dtype used by the parallel computation + # check_error_hist determine whether to compare the error histograms between + # (CP_in_FP32, serial_in_FP64) and (serial_in_FP32, serial_in_FP64) + # Typically, check_error_hist will use large input dimensions to emulate + # the real-world use cases. Same with dtype==torch.float64. + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist: + if grid_group_sizes["dp"] > 1: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 128 # Number of tokens + c_in = 128 # Input dimension + c_hidden = 32 # Hidden dimension + no_heads = 4 # Number of attention heads + min_val_init = -5e-2 if dtype == torch.float64 else -5e-2 + max_val_init = -min_val_init + else: + N = size_ring * 4 # Number of tokens + c_in = 8 # Input dimension + c_hidden = 16 # Hidden dimension + no_heads = 2 # Number of attention heads + min_val_init = -0.5 + max_val_init = 0.5 + + seed = 42 + seed_by_rank(0, seed=seed) + + # compute reference results with FP64 + input_x_global_fp64 = torch.empty((B, N, N, c_in), dtype=torch.float64, requires_grad=True, device=device_type) + mask_global_fp64 = torch.randint(0, 2, (B, N, N), dtype=torch.float64, requires_grad=False, device=device_type) + # create pure padding chunk in the mask + mask_global_fp64[0, N // size_ring :] = 0 + mask_global_fp64[0, :, N // size_ring :] = 0 + + # Create reference serial module + if mode == _Mode.Starting: + reference_module = TriangleAttentionStartingNode(c_in, c_hidden, no_heads, inf=1e18) + elif mode == _Mode.Ending: + reference_module = TriangleAttentionEndingNode(c_in, c_hidden, no_heads, inf=1e18) + else: + raise ValueError(f"Invalid mode {mode}") + + # Initialize parameters to ensure reproducible behavior + # The output activation and gradient of the layer weights typically increase by 3 to 4 orders of magnitude, + # where the ULP would be too large and numerical error distribution becomes very wide, i.e., we would have + # very unpredictable numerical errors. That would make the test results very noisy and not very useful to + # detect logical bugs in the code. To avoid this, we use a smaller range for the input and layer weights. + init_tensors_uniform([input_x_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + layer_state_dict_fp64 = reference_module.state_dict() + + # Run forward pass + output_expected_global_fp64 = reference_module(input_x_global_fp64, mask_global_fp64) + d_output_expected_global_fp64 = torch.rand_like(output_expected_global_fp64) + output_expected_global_fp64.backward(d_output_expected_global_fp64) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module.named_parameters() + } + + if check_error_hist: + input_x_global_fp32 = input_x_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(True) + mask_global_fp32 = mask_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(False) + + if mode == _Mode.Starting: + reference_module_fp32 = TriangleAttentionStartingNode(c_in, c_hidden, no_heads) + elif mode == _Mode.Ending: + reference_module_fp32 = TriangleAttentionEndingNode(c_in, c_hidden, no_heads) + else: + raise ValueError(f"Invalid mode {mode}") + + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + + output_global_fp32 = reference_module_fp32(input_x_global_fp32, mask_global_fp32) + d_output_expected_global_fp32 = d_output_expected_global_fp64.to(dtype=torch.float32) + output_global_fp32.backward(d_output_expected_global_fp32) + + output_global_fp32_host = output_global_fp32.detach().clone().cpu() + d_input_x_global_fp32_host = input_x_global_fp32.grad.detach().clone().cpu() + grad_params_fp32_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module_fp32.named_parameters() + } + else: + output_global_fp32_host = None + d_input_x_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_triangle_attention, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + c_in, + c_hidden, + no_heads, + mode, + layer_state_dict_fp64, + input_x_global_fp64.detach().clone().cpu(), + mask_global_fp64.detach().clone().cpu(), + output_expected_global_fp64.detach().clone().cpu(), + d_output_expected_global_fp64.detach().clone().cpu(), + input_x_global_fp64.grad.detach().clone().cpu(), + grad_params_fp64_expected_global_host, + output_global_fp32_host, + d_input_x_global_fp32_host, + grad_params_fp32_global_host, + ) + + +# --------------------------------------------------------------------------- +# test_cueq_triattn_sm100f_util +# --------------------------------------------------------------------------- + +try: + from cuequivariance_ops_torch.triangle_attention import _can_run_sm100f + + _cueq_ops_has_can_run_sm100f = True +except ImportError: + _cueq_ops_has_can_run_sm100f = False + + +@pytest.mark.parametrize( + "dim_token, dim_hidden, is_fwd, dtype", + [ + (8, 8, True, torch.bfloat16), + (8, 8, False, torch.bfloat16), + (7, 8, True, torch.bfloat16), + (7, 8, False, torch.bfloat16), + (8, 7, True, torch.bfloat16), + (8, 7, False, torch.bfloat16), + (8, 128, True, torch.bfloat16), + (8, 128, False, torch.bfloat16), + (8, 129, True, torch.bfloat16), + (8, 129, False, torch.bfloat16), + (16, 32, True, torch.bfloat16), + (16, 32, False, torch.bfloat16), + (16, 32, True, torch.float16), + (16, 32, False, torch.float16), + (16, 32, True, torch.float32), + (16, 32, False, torch.float32), + ], + ids=lambda v: str(v), +) +def test_cueq_triattn_sm100f_util(dim_token, dim_hidden, is_fwd, dtype): + """Verify can_run_cueq_triattn_sm100f matches cuEq's private _can_run_sm100f.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if not _cueq_ops_has_can_run_sm100f: + pytest.skip("cuequivariance_ops_torch._can_run_sm100f not available") + + device = torch.device("cuda", 0) + q = torch.empty(1, 1, 1, dim_token, dim_hidden, device=device, dtype=dtype) + k = torch.empty(1, 1, 1, dim_token, dim_hidden, device=device, dtype=dtype) + + cueq_result = _can_run_sm100f(q, k, training=not is_fwd) + # _can_run_sm100f returns (can_run: bool, device_cc: list) + expected = cueq_result[0] + result = can_run_cueq_triattn_sm100f(device, dtype, dim_token, dim_hidden, is_fwd) + assert result == expected, ( + f"can_run_cueq_triattn_sm100f({device}, {dtype}, {dim_token}, {dim_hidden}, {is_fwd}) = {result}, " + f"but _can_run_sm100f(q, k, training={not is_fwd}) = {cueq_result}" + ) + + +# --------------------------------------------------------------------------- +# test_triangle_attention_parallel_sm100f +# --------------------------------------------------------------------------- + +SM100F_BWD_WARNING_SUBSTR = "SM100f kernel expects bias to be of the same dtype as q" + + +def parallel_assert_sm100f_bwd_warning( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + c_in, + c_hidden, + no_heads, + mode, + layer_state_dict, + input_x_global_host, + mask_global_host, + d_output_global_host, + expect_warning, + mock_util_always_false, +): + """Worker: run distributed forward+backward and check SM100f warning.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + if mock_util_always_false: + import sys + + triattn_mod = sys.modules["boltz.distributed.model.layers.triangular_attention"] + monkeypatch.setattr(triattn_mod, "can_run_cueq_triattn_sm100f", lambda *_args, **_kw: False) + + dtype = torch.bfloat16 + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + layout_map = manager.layout_subgroups["cp"] + axis_cp = 1 if mode == _Mode.Starting else 0 + ring_comm = Ring2DCommTriAttn(manager.group["cp"], layout_map, axis_cp) + + if mode == _Mode.Starting: + module_serial = TriangleAttentionStartingNode(c_in, c_hidden, no_heads, inf=1e9) + else: + module_serial = TriangleAttentionEndingNode(c_in, c_hidden, no_heads, inf=1e9) + module_serial = module_serial.to(dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device) + + if mode == _Mode.Starting: + module = DistributedTriangleAttentionStartingNode(module_serial, manager.device_mesh_subgroups, ring_comm) + else: + module = DistributedTriangleAttentionEndingNode(module_serial, manager.device_mesh_subgroups, ring_comm) + module = module.to(device=manager.device).train() + + placements = (Shard(0), Shard(1), Shard(2)) + input_x = distribute_tensor( + input_x_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ).requires_grad_(True) + mask = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + d_output = distribute_tensor( + d_output_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + output = module(input_x, mask, triattn_backend=TriAttnBackend.CUEQ) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + output.backward(d_output) + + sm100f_msgs = [w for w in caught if SM100F_BWD_WARNING_SUBSTR in str(w.message)] + if expect_warning: + assert sm100f_msgs, f"Rank {rank}: expected SM100f bwd warning but none was emitted" + else: + assert not sm100f_msgs, f"Rank {rank}: SM100f bwd warning(s) emitted ({len(sm100f_msgs)}): " + "; ".join( + str(w.message) for w in sm100f_msgs + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=["dp:2-cp:2x2-cuda"], +) +@pytest.mark.parametrize( + "use_util_to_condition_fp32_cast", + [True, False], + ids=["util_active", "util_mocked_false"], +) +def test_triangle_attention_parallel_sm100f(setup_env, use_util_to_condition_fp32_cast): + """Assert SM100f backward warning fires when util returns True, absent when mocked False.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + if not cueq_is_installed: + pytest.skip("cuequivariance_torch is not installed") + device_cc = torch.cuda.get_device_capability() + if device_cc not in ((10, 0), (10, 3)): + pytest.skip(f"GPU compute capability {device_cc} is not SM100/SM103") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 8 + c_in = 8 + c_hidden = 8 + no_heads = 2 + mode = _Mode.Starting + seed = 42 + seed_by_rank(0, seed=seed) + + input_x_global = torch.empty((B, N, N, c_in), dtype=torch.float64, device="cuda") + mask_global = torch.ones((B, N, N), dtype=torch.float64, device="cuda") + init_tensors_uniform([input_x_global], low=-0.5, high=0.5) + + reference_module = TriangleAttentionStartingNode(c_in, c_hidden, no_heads, inf=1e9) + init_module_params_uniform(reference_module, low=-0.5, high=0.5) + reference_module = reference_module.to(dtype=torch.float64, device="cuda") + layer_state_dict = reference_module.state_dict() + + d_output_global = torch.rand((B, N, N, c_in), dtype=torch.float64, device="cuda") + + # True: real util pre-casts bias to q.dtype -> cuEq sees correct dtype -> no warning. + # False: util mocked -> bias cast to fp32 -> cuEq internally detects SM100f and warns. + mock_util_always_false = not use_util_to_condition_fp32_cast + expect_warning = mock_util_always_false + + spawn_multiprocessing( + parallel_assert_sm100f_bwd_warning, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + c_in, + c_hidden, + no_heads, + mode, + layer_state_dict, + input_x_global.detach().clone().cpu(), + mask_global.detach().clone().cpu(), + d_output_global.detach().clone().cpu(), + expect_warning, + mock_util_always_false, + ) + + +# --------------------------------------------------------------------------- +# test_triangle_attention_bf16_autocast — dtype-preservation regression +# --------------------------------------------------------------------------- + + +_real_logsumexp = torch.logsumexp + + +def _logsumexp_fp32_promotion(*args, **kwargs): + """Wrapper that reproduces CUDA autocast's logsumexp FP32 promotion on CPU.""" + result = _real_logsumexp(*args, **kwargs) + if result.dtype in (torch.bfloat16, torch.float16): + result = result.to(dtype=torch.float32) + return result + + +def _unfixed_tiled_softmax_update(o_chunk, lse_m_chunk, amax_chunk, o=None, lse_m=None, amax=None): + """tiled_softmax_attention_update WITHOUT the .to(dtype=lse_m_chunk.dtype) fix. + + When torch.logsumexp is separately monkeypatched to promote BF16 → FP32 + (as CUDA autocast does), the missing cast-back causes the FP32 cascade: + lse_m(FP32) → delta_lse → sigmoid → o(FP32). + """ + has_amax = amax_chunk is not None + if o is None: + return o_chunk, lse_m_chunk, amax_chunk + + if has_amax: + d_lse_m = lse_m - lse_m_chunk + amax_next = torch.maximum(amax_chunk, amax) + delta_lse = amax_chunk - amax - d_lse_m + o = o - torch.sigmoid(delta_lse) * (o - o_chunk) + lse_m = lse_m_chunk + torch.logsumexp( + torch.cat([(amax - amax_next) + d_lse_m, amax_chunk - amax_next], dim=-1), + dim=-1, + keepdim=True, + ) + amax = amax_next + else: + d_lse_m = lse_m - lse_m_chunk + delta_lse = -d_lse_m + o = o - torch.sigmoid(delta_lse) * (o - o_chunk) + lse_m = lse_m - torch.nn.functional.logsigmoid(d_lse_m) + amax = None + + return o, lse_m, amax + + +def parallel_assert_bf16_autocast_dtype( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + c_in, + c_hidden, + no_heads, + weight_q_global_host, + weight_k_global_host, + weight_v_global_host, + input_x_global_host, + mask_global_host, + triangle_bias_global_host, + d_output_global_host, + use_bf16_logsoftmax_cast, +): + """Worker: call _RingMultiHeadTriangleAttentionImpl.apply() and assert + output / gradient dtypes. + + torch.logsumexp is always monkeypatched to promote half → FP32 (simulating + CUDA autocast). The ``use_bf16_logsoftmax_cast`` flag controls whether the + fixed tiled_softmax_attention_update (with .to(dtype=...) cast-back) or the + unfixed version is used. + """ + import sys + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + monkeypatch.setattr(torch, "logsumexp", _logsumexp_fp32_promotion) + + if not use_bf16_logsoftmax_cast: + triattn_mod = sys.modules["boltz.distributed.model.layers.triangular_attention"] + monkeypatch.setattr(triattn_mod, "tiled_softmax_attention_update", _unfixed_tiled_softmax_update) + + dtype = torch.bfloat16 + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + layout_map = manager.layout_subgroups["cp"] + axis_cp = 1 + ring_comm = Ring2DCommTriAttn(manager.group["cp"], layout_map, axis_cp) + + shard_placements = (Shard(0), Shard(1), Shard(2)) + replicate_placements = tuple(Replicate() for _ in range(manager.device_mesh_subgroups.ndim)) + + q_x = distribute_tensor( + input_x_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=shard_placements, + ).requires_grad_(True) + kv_x = distribute_tensor( + input_x_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=shard_placements, + ).requires_grad_(True) + mask = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=shard_placements, + ) + triangle_bias = distribute_tensor( + triangle_bias_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=shard_placements, + ).requires_grad_(True) + weight_q = distribute_tensor( + weight_q_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=replicate_placements, + ).requires_grad_(True) + weight_k = distribute_tensor( + weight_k_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=replicate_placements, + ).requires_grad_(True) + weight_v = distribute_tensor( + weight_v_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=replicate_placements, + ).requires_grad_(True) + + inf_val = 1e9 + with torch.amp.autocast("cpu", dtype=dtype): + output = _RingMultiHeadTriangleAttentionImpl.apply( + q_x, + kv_x, + mask, + triangle_bias, + weight_q, + weight_k, + weight_v, + no_heads, + c_hidden, + ring_comm, + inf_val, + TriAttnBackend.REFERENCE, + ) + + if use_bf16_logsoftmax_cast: + assert output.dtype == dtype, f"Rank {rank}: fwd output dtype {output.dtype}, expected {dtype} (fix active)" + else: + assert ( + output.dtype == torch.float32 + ), f"Rank {rank}: fwd output dtype {output.dtype}, expected float32 (bug should manifest)" + + d_output = distribute_tensor( + d_output_global_host.to(dtype=output.dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=shard_placements, + ) + + if use_bf16_logsoftmax_cast: + output.backward(d_output) + for name, tensor in [("q_x", q_x), ("kv_x", kv_x), ("triangle_bias", triangle_bias)]: + assert ( + tensor.grad.dtype == dtype + ), f"Rank {rank}: bwd grad {name} dtype {tensor.grad.dtype}, expected {dtype}" + else: + # With the bug, the FP32 output produces FP32 do_local in backward, + # which mixes with BF16 saved tensors (q, kT, v). CPU matmul rejects + # this mixed-dtype operand pair, confirming the bug propagates into + # backward. On CUDA the promotion would silently succeed but produce + # FP32 gradients. + with pytest.raises(RuntimeError, match="expected scalar type"): + output.backward(d_output) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [((1, (3, 3)), True, "cpu", "ENV")], + indirect=["setup_env"], + ids=["dp:1-cp:3x3-cpu"], +) +@pytest.mark.parametrize( + "use_bf16_logsoftmax_cast", + [True, False], + ids=["fixed", "buggy"], +) +def test_triangle_attention_bf16_autocast(setup_env, use_bf16_logsoftmax_cast): + """Regression: BF16 dtype preservation through ring attention under autocast. + + Under CUDA autocast, torch.logsumexp promotes BF16 → FP32. Without the + fix (.to(dtype=lse_m_chunk.dtype) in tiled_softmax_attention_update), the + FP32 cascades from lse_m → delta_lse → sigmoid → o after ≥3 ring steps, + making the forward output and all backward gradients FP32. + + Both cases monkeypatch torch.logsumexp to promote half → FP32 (simulating + CUDA autocast on CPU, since CPU logsumexp preserves dtype natively). + + use_bf16_logsoftmax_cast=True : fixed tiled_softmax_attention_update (with + .to(dtype=...) cast-back) — logsumexp still promotes but the fix casts + back, so all outputs/grads stay BF16. + use_bf16_logsoftmax_cast=False: unfixed tiled_softmax_attention_update — + FP32 from logsumexp cascades into o, confirming the bug path. + + Calls _RingMultiHeadTriangleAttentionImpl.apply() directly (not through the + module wrapper) so that downstream linears cannot mask the FP32 promotion. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + size_ring = grid_group_sizes["cp"][0] + B = 2 + N = size_ring * 4 + c_in = 8 + c_hidden = 16 + no_heads = 2 + seed = 42 + seed_by_rank(0, seed=seed) + + input_x_global = torch.randn(B, N, N, c_in, dtype=torch.float64) * 0.5 + mask_global = torch.ones(B, N, N, dtype=torch.float64) + triangle_bias_global = torch.randn(B, N, N, no_heads, dtype=torch.float64) * 0.1 + weight_q_global = torch.randn(no_heads * c_hidden, c_in, dtype=torch.float64) * 0.1 + weight_k_global = torch.randn(no_heads * c_hidden, c_in, dtype=torch.float64) * 0.1 + weight_v_global = torch.randn(no_heads * c_hidden, c_in, dtype=torch.float64) * 0.1 + d_output_global = torch.randn(B, N, N, no_heads * c_hidden, dtype=torch.float64) * 0.1 + + spawn_multiprocessing( + parallel_assert_bf16_autocast_dtype, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + c_in, + c_hidden, + no_heads, + weight_q_global, + weight_k_global, + weight_v_global, + input_x_global, + mask_global, + triangle_bias_global, + d_output_global, + use_bf16_logsoftmax_cast, + ) diff --git a/tests/distributed/model/layers/test_dtensor_triangular_mult.py b/tests/distributed/model/layers/test_dtensor_triangular_mult.py new file mode 100644 index 000000000..19156c140 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_triangular_mult.py @@ -0,0 +1,429 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Shard, distribute_tensor + +from boltz.distributed.comm import Ring2DComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.triangular_mult import ( + TriangleMultiplicationIncoming as DistributedTriangleMultiplicationIncoming, +) +from boltz.distributed.model.layers.triangular_mult import ( + TriangleMultiplicationOutgoing as DistributedTriangleMultiplicationOutgoing, +) +from boltz.distributed.model.layers.triangular_mult import _Direction +from boltz.model.layers.triangular_mult import TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_triangle_multiplication( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + dim, + direction, + layer_state_dict, + input_x_global_host, + mask_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_x_expected_global_host, + grad_params_expected_global_host, + output_global_fp32_host: torch.Tensor | None = None, + d_input_x_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_expected_global_host.dtype}" + ) + + if ((output_global_fp32_host is None) != (d_input_x_global_fp32_host is None)) or ( + (output_global_fp32_host is not None) != (grad_params_fp32_global_host is not None) + ): + raise ValueError( + "output_global_fp32_host, d_input_x_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_global_fp32_host is not None + + layout_map = manager.layout_subgroups["cp"] + ring_comm = Ring2DComm(manager.group["cp"], manager.subgroups["cp"][0], layout_map) + + if direction == _Direction.Outgoing: + module_serial = TriangleMultiplicationOutgoing(dim) + elif direction == _Direction.Incoming: + module_serial = TriangleMultiplicationIncoming(dim) + else: + raise ValueError(f"Invalid direction {direction}") + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + + if direction == _Direction.Outgoing: + module = DistributedTriangleMultiplicationOutgoing(module_serial, manager.device_mesh_subgroups, ring_comm) + elif direction == _Direction.Incoming: + module = DistributedTriangleMultiplicationIncoming(module_serial, manager.device_mesh_subgroups, ring_comm) + else: + raise ValueError(f"Invalid direction {direction}") + module = module.train() + + # Input tensors have the same sharding pattern: + # x: (B, N, N, D) - sharded on dims 1 and 2 (N and N) + # mask: (B, N, N) - sharded on dims 1 and 2 (N and N) + placements = (Shard(0), Shard(1), Shard(2)) + + # Distribute input tensors + input_x_dtensor = distribute_tensor( + input_x_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ).requires_grad_(True) + + mask_dtensor = distribute_tensor( + mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_input_x_expected_dtensor = distribute_tensor( + d_input_x_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + input_x_dtensor_copy = input_x_dtensor.detach().clone().requires_grad_(True) + mask_dtensor_copy = mask_dtensor.detach().clone() + + if check_error_hist: + # Forward and backward pass for error histogram checking + output_dtensor_result = module(input_x_dtensor, mask_dtensor) + output_dtensor_result.backward(d_output_expected_dtensor) + + output_fp32_dtensor = distribute_tensor( + output_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + d_input_x_fp32_dtensor = distribute_tensor( + d_input_x_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Check that the output tensor has the correct shape + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output DTensor has shape {output_dtensor_result.shape} but expected shape {output_expected_dtensor.shape}" + + # Check that the output tensor has the correct shape + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output DTensor has stride {output_dtensor_result.stride()} but expected stride {output_expected_dtensor.stride()}" + + assert ( + input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + ), f"Input DTensor grad has shape {input_x_dtensor.grad.shape} but expected shape {d_input_x_expected_dtensor.shape}" + + assert ( + input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + ), f"Input DTensor grad has stride {input_x_dtensor.grad.stride()} but expected stride {d_input_x_expected_dtensor.stride()}" + + assert_no_percentile_upshift( + output_dtensor_result.to_local(), + output_expected_dtensor.to_local(), + output_fp32_dtensor.to_local(), + names_input=("output_cp_fp32", "output_serial_fp64", "output_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_x_dtensor.grad.to_local(), + d_input_x_expected_dtensor.to_local(), + d_input_x_fp32_dtensor.to_local(), + names_input=("d_input_x_cp_fp32", "d_input_x_serial_fp64", "d_input_x_serial_fp32"), + ) + + # Check parameter gradients error histograms + for name, grad_param_expected_global in grad_params_expected_global_host.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + # Forward pass + output_dtensor_result = module(input_x_dtensor, mask_dtensor) + + # Check that the output tensor has the correct shape + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output DTensor has shape {output_dtensor_result.shape} but expected shape {output_expected_dtensor.shape}" + + # Check that the output tensor has the correct shape + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output DTensor has stride {output_dtensor_result.stride()} but expected stride {output_expected_dtensor.stride()}" + + # Verify inputs weren't modified + assert_tensors_identical( + input_x_dtensor_copy.to_local(), input_x_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(mask_dtensor_copy.to_local(), mask_dtensor.to_local()) + + # Test forward pass results + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + assert ( + input_x_dtensor.grad.shape == d_input_x_expected_dtensor.shape + ), f"Input DTensor grad has shape {input_x_dtensor.grad.shape} but expected shape {d_input_x_expected_dtensor.shape}" + + assert ( + input_x_dtensor.grad.stride() == d_input_x_expected_dtensor.stride() + ), f"Input DTensor grad has stride {input_x_dtensor.grad.stride()} but expected stride {d_input_x_expected_dtensor.stride()}" + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + torch.testing.assert_close(input_x_dtensor.grad.to_local(), d_input_x_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_x_global_result_host = input_x_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_x_global_result_host, d_input_x_expected_global_host.to(dtype=dtype)) + + # Test parameter gradients + grad_params_result_dtensors = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in grad_params_expected_global_host: + # do an extra check here to make sure the parallel computation don't result in extra gradients + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + grad_params_result_dtensors[name] = param.grad + + for name, grad_param_expected_global_host in grad_params_expected_global_host.items(): + assert name in grad_params_result_dtensors, f"Parameter {name}'s gradient is not found in result gradients" + grad_params_result = grad_params_result_dtensors[name] + grad_params_result_global = grad_params_result.full_tensor() + torch.testing.assert_close(grad_params_result_global.cpu(), grad_param_expected_global_host.to(dtype=dtype)) + assert_all_identical(grad_params_result_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((1, (3, 3)), True, "cuda", "ENV"), torch.float32, False), + (((1, (3, 3)), True, "cpu", "ENV"), torch.float32, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +@pytest.mark.parametrize("direction", [_Direction.Outgoing, _Direction.Incoming]) +def test_triangle_multiplication_parallel(setup_env, dtype, check_error_hist, direction): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + # dtype is the dtype used by the parallel computation + # check_error_hist determine whether to compare the error histograms between + # (CP_in_FP32, serial_in_FP64) and (serial_in_FP32, serial_in_FP64) + # Typically, check_error_hist will use large input dimensions to emulate + # the real-world use cases. Same with dtype==torch.float64. + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist: + if grid_group_sizes["dp"] > 1: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 128 # Number of tokens + dim = 128 # Hidden dimension + min_val_init = -5e-2 if dtype == torch.float64 else -1e-3 + max_val_init = -min_val_init + else: + N = size_ring * 4 # Number of tokens + dim = 8 # Hidden dimension + min_val_init = -0.5 + max_val_init = 0.5 + + seed = 42 + seed_by_rank(0, seed=seed) + + # compute reference results with FP64 + input_x_global_fp64 = torch.empty((B, N, N, dim), dtype=torch.float64, requires_grad=True, device=device_type) + mask_global_fp64 = torch.randint(0, 2, (B, N, N), dtype=torch.float64, requires_grad=False, device=device_type) + + # emulate blocks of pure padding + mask_global_fp64[0, N // size_ring :, :] = 0 + mask_global_fp64[0, :, N // size_ring :] = 0 + + # Create reference serial module + if direction == _Direction.Outgoing: + reference_module = TriangleMultiplicationOutgoing(dim) + elif direction == _Direction.Incoming: + reference_module = TriangleMultiplicationIncoming(dim) + else: + raise ValueError(f"Invalid direction {direction}") + + # The output activation and gradient of the layer weights typically increase by 2 to 3 orders of magnitude, + # where the ULP would be too large and numerical error distribution becomes very wide, i.e., we would have + # very unpredictable numerical errors. That would make the test results very noisy and not very useful to + # detect logical bugs in the code. To avoid this, we use a smaller range for the input and layer weights. + init_tensors_uniform([input_x_global_fp64], low=min_val_init, high=max_val_init) + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + + layer_state_dict_fp64 = reference_module.state_dict() + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + + # Run forward pass + output_expected_global_fp64 = reference_module(input_x_global_fp64, mask_global_fp64) + d_output_expected_global_fp64 = torch.rand_like(output_expected_global_fp64) + output_expected_global_fp64.backward(d_output_expected_global_fp64) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module.named_parameters() + } + + if check_error_hist: + input_x_global_fp32 = input_x_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(True) + mask_global_fp32 = mask_global_fp64.detach().clone().to(dtype=torch.float32).requires_grad_(False) + + if direction == _Direction.Outgoing: + reference_module_fp32 = TriangleMultiplicationOutgoing(dim) + elif direction == _Direction.Incoming: + reference_module_fp32 = TriangleMultiplicationIncoming(dim) + else: + raise ValueError(f"Invalid direction {direction}") + + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + + output_global_fp32 = reference_module_fp32(input_x_global_fp32, mask_global_fp32) + d_output_expected_global_fp32 = d_output_expected_global_fp64.to(dtype=torch.float32) + output_global_fp32.backward(d_output_expected_global_fp32) + + output_global_fp32_host = output_global_fp32.detach().clone().cpu() + d_input_x_global_fp32_host = input_x_global_fp32.grad.detach().clone().cpu() + grad_params_fp32_global_host = { + name: param.grad.detach().clone().cpu() for name, param in reference_module_fp32.named_parameters() + } + else: + output_global_fp32_host = None + d_input_x_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_triangle_multiplication, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + dim, + direction, + layer_state_dict_fp64, + input_x_global_fp64.detach().clone().cpu(), + mask_global_fp64.detach().clone().cpu(), + output_expected_global_fp64.detach().clone().cpu(), + d_output_expected_global_fp64.detach().clone().cpu(), + input_x_global_fp64.grad.detach().clone().cpu(), + grad_params_fp64_expected_global_host, + output_global_fp32_host, + d_input_x_global_fp32_host, + grad_params_fp32_global_host, + ) diff --git a/tests/distributed/model/layers/test_dtensor_unflatten.py b/tests/distributed/model/layers/test_dtensor_unflatten.py new file mode 100644 index 000000000..c79ff6b30 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_unflatten.py @@ -0,0 +1,742 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_unflatten, shardwise_unflatten_sharded +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def compute_global_expectation(shape, dim, sizes, device): + """Compute global expectation using standard PyTorch operations.""" + # Create tensor for unflattening + x = torch.rand(*shape, device=device, requires_grad=True) + + # Compute on global tensor using standard unflatten operation + y = torch.unflatten(x, dim=dim, sizes=sizes) + + # Create gradients for backward pass + dy = torch.rand_like(y) + + # Backward pass on global tensor + y.backward(dy) + + # Collect input gradient + input_grad = x.grad.detach().clone() + + return x.detach().clone(), y.detach().clone(), input_grad, dy.detach().clone() + + +def compute_dtensor_native( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim: int, + sizes: tuple[int, ...], +) -> tuple[DTensor, DTensor]: + """Compute DTensor native operations for comparison.""" + # Create DTensor native input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + + # Forward pass with native DTensor unflatten operation + y_dtensor_result = torch.unflatten(x_dtensor, dim=dim, sizes=sizes) + + # Backward pass with native DTensor op + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + y_dtensor_result.backward(dy_dtensor) + + x_grad_dtensor = x_dtensor.grad + + return x_grad_dtensor, y_dtensor_result + + +def compute_shardwise_unflatten_with_validation( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim: int, + sizes: tuple[int, ...], + label_test_case: str, +) -> tuple[DTensor, DTensor, DTensor]: + """ + Compute shardwise_unflatten forward and backward pass with input validation checks. + + Returns: + y_dtensor_result: Forward pass result + x_dtensor: Input tensor with computed gradient + dy_dtensor: Distributed upstream gradient + """ + # Create DTensor input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_unflatten + y_dtensor_result = shardwise_unflatten(x_dtensor, dim=dim, sizes=sizes) + + # verify no change to the fwd input + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # verify no change to the bwd input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # verify input gradient placements are consistent with input placements + assert ( + x_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + return y_dtensor_result, x_dtensor, dy_dtensor + + +def parallel_assert_dtensor_unflatten( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters - 8D tensor for comprehensive testing + shape = (3, 5, grid_group_sizes["dp"] * 2, 5, size_ring * 4, 5, 3, 2) + # Shard the sequence dimension (dim=2) and another dimension (dim=4) for input tensor + # this emulates the sharded single representation in the Boltz model + input_placements = (Shard(dim=2), Shard(dim=4), Replicate()) + + # Test valid unflattening dimensions (not sharded) + # Sharded dims are 2 and 4, so valid unflatten dims must not be these dimensions + # For each dim, we need sizes that multiply to the original dimension size + valid_unflatten_params = [ + (0, (1, 3)), # unflatten dim 0 (size 3) into (1, 3) + (0, (3, 1)), # unflatten dim 0 (size 3) into (3, 1) + (1, (1, 5)), # unflatten dim 1 (size 5) into (1, 5) + (1, (5, 1)), # unflatten dim 1 (size 5) into (5, 1) + (3, (1, 5)), # unflatten dim 3 (size 5) into (1, 5) + (5, (1, 5)), # unflatten dim 5 (size 5) into (1, 5) + (6, (1, 3)), # unflatten dim 6 (size 3) into (1, 3) + (6, (3, 1)), # unflatten dim 6 (size 3) into (3, 1) + (7, (1, 2)), # unflatten dim 7 (size 2) into (1, 2) + (7, (2, 1)), # unflatten dim 7 (size 2) into (2, 1) + (-1, (1, 2)), # unflatten dim -1 (last dim, size 2) into (1, 2) + (-2, (1, 3)), # unflatten dim -2 (dim 6, size 3) into (1, 3) + ] + + # Test invalid unflattening dimensions (sharded dims 2 and 4) + invalid_unflatten_params = [ + (2, (1, grid_group_sizes["dp"] * 2)), # unflatten sharded dim 2 + (2, (grid_group_sizes["dp"], 2)), # unflatten sharded dim 2 (different split) + (4, (1, size_ring * 4)), # unflatten sharded dim 4 + (4, (size_ring, 4)), # unflatten sharded dim 4 (different split) + (-6, (1, grid_group_sizes["dp"] * 2)), # unflatten dim -6 (equivalent to dim 2) + (-4, (1, size_ring * 4)), # unflatten dim -4 (equivalent to dim 4) + ] + + # Test valid unflattening dimensions + for dim, sizes in valid_unflatten_params: + label_test_case = f"for dim={dim}, sizes={sizes}\n" + + # Compute global expectations + x_global, y_expected_global, x_grad_expected_global, dy_global = compute_global_expectation( + shape, dim, sizes, manager.device + ) + + # use DTensor native op as an alternative reference + x_grad_dtensor_native, y_dtensor_result_native = compute_dtensor_native( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, dim, sizes + ) + + # Compute shardwise_unflatten forward and backward with validation + y_dtensor_result, x_dtensor, dy_dtensor = compute_shardwise_unflatten_with_validation( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, dim, sizes, label_test_case + ) + + # =================================================================== + # BLOCK 1: Check against DTensor native reference + # =================================================================== + + # check metadata against DTensor native + assert ( + y_dtensor_result.placements == y_dtensor_result_native.placements + ), f"{label_test_case} placements mismatch" + assert y_dtensor_result.shape == y_dtensor_result_native.shape, f"{label_test_case} shape mismatch" + assert y_dtensor_result.stride() == y_dtensor_result_native.stride(), f"{label_test_case} stride mismatch" + + # compare forward result with native DTensor op + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_result_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare global tensors between shardwise_unflatten and native DTensor results + y_result_global = y_dtensor_result.full_tensor() + y_result_global_native = y_dtensor_result_native.full_tensor() + + torch.testing.assert_close( + y_result_global, + y_result_global_native, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs native: {m}", + ) + + # assert input gradient metadata and values against DTensor native + assert ( + x_dtensor.grad.placements == x_grad_dtensor_native.placements + ), f"{label_test_case} input gradient placements mismatch" + assert x_dtensor.grad.shape == x_grad_dtensor_native.shape, f"{label_test_case} input gradient shape mismatch" + assert ( + x_dtensor.grad.stride() == x_grad_dtensor_native.stride() + ), f"{label_test_case} input gradient stride mismatch" + + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_dtensor_native.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_dtensor_native.full_tensor(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient mismatch: {m}", + ) + + # =================================================================== + # BLOCK 2: Check against global serial expectation + # =================================================================== + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close( + y_dtensor_result.to_local(), + y_dtensor_expected.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} {m}", + ) + + # compare forward result with global expectation + torch.testing.assert_close( + y_result_global, + y_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} output vs global expectation: {m}", + ) + + # create distributed tensor from global result for local shard comparison + x_grad_expected_dtensor = distribute_tensor( + x_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # compare local shard with expected + torch.testing.assert_close( + x_dtensor.grad.to_local(), + x_grad_expected_dtensor.to_local(), + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + torch.testing.assert_close( + x_dtensor.grad.full_tensor(), + x_grad_expected_global, + atol=0, + rtol=0, + msg=lambda m: f"{label_test_case} input gradient vs global expectation: {m}", + ) + + # Test invalid unflattening dimensions (should raise NotImplementedError) + for dim, sizes in invalid_unflatten_params: + label_test_case = f"for invalid dim={dim}, sizes={sizes}\n" + + # Compute global expectations (this should work fine) + x_global, _, _, _ = compute_global_expectation(shape, dim, sizes, manager.device) + + # Create DTensor input + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad = True + + # This should raise due to sharded dimension being unflattened + with pytest.raises(NotImplementedError, match="Unflattening dimension .* shared by device_mesh axis"): + shardwise_unflatten(x_dtensor, dim=dim, sizes=sizes) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_unflatten(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_unflatten, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +# ============================================================================== +# Tests for shardwise_unflatten_sharded +# ============================================================================== + + +def compute_shardwise_unflatten_sharded_with_validation( + x_global: torch.Tensor, + dy_global: torch.Tensor, + device_mesh: DeviceMesh, + input_placements: tuple[Placement, ...], + dim: int, + sizes: tuple[int, ...], + label_test_case: str, +) -> tuple[DTensor, DTensor, DTensor]: + """ + Compute shardwise_unflatten_sharded forward and backward pass with input validation checks. + + Returns: + y_dtensor_result: Forward pass result + x_dtensor: Input tensor with computed gradient + dy_dtensor: Distributed upstream gradient + """ + # Create DTensor input + x_dtensor = distribute_tensor(x_global.detach().clone(), device_mesh, input_placements).requires_grad_(True) + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensor using shardwise_unflatten_sharded + y_dtensor_result = shardwise_unflatten_sharded(x_dtensor, axis=dim, sizes=sizes) + + # verify no change to the fwd input + assert_tensors_identical(x_dtensor.to_local(), x_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global.detach().clone(), device_mesh, y_dtensor_result.placements) + + # Perform backward pass + dy_dtensor_copy = dy_dtensor.detach().clone() + y_dtensor_result.backward(dy_dtensor) + + # verify no change to the bwd input + assert_tensors_identical(dy_dtensor.to_local(), dy_dtensor_copy.to_local(), check_grad=False, check_grad_fn=False) + + # verify input gradient placements are consistent with input placements + assert ( + x_dtensor.grad.placements == input_placements + ), f"{label_test_case} inconsistent input gradient placements with input placements" + + return y_dtensor_result, x_dtensor, dy_dtensor + + +def parallel_assert_dtensor_unflatten_sharded( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """ + Test shardwise_unflatten_sharded which unflattens a sharded dimension. + + Unlike shardwise_unflatten, this function is designed to unflatten a dimension that is + itself sharded. DTensor native op doesn't support this, so we only compare + against the global serial version as reference. + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # each rank uses the same seed to generate the same input tensors + seed_by_rank(0, seed=42) + + size_dp = grid_group_sizes["dp"] + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters - 5D tensor + # Shape designed so that sharded dims can be evenly divided + # dim=1 is sharded by dp with size (size_dp * 4 * 3) + # dim=2 is sharded by ring with size (size_ring * 6 * 2) + shape = (2, size_dp * 4 * 3, size_ring * 6 * 2, 5, 4) + + # Test cases: unflatten a sharded dimension + # shardwise_unflatten_sharded requires dim to be the sharded dimension + # and sizes[0] must be evenly shardable by the device mesh size + + # Test Case 1: Shard on dim=1, unflatten into 2 dims + # shape[1] = size_dp * 4 * 3 = size_dp * 12 + # unflatten into (size_dp * 4, 3) -> sizes[0] = size_dp * 4 is evenly shardable by size_dp + test_cases_dim1 = [ + # (input_placements, dim, sizes, description) + ((Shard(dim=1), Replicate(), Replicate()), 1, (size_dp * 4, 3), "unflatten dim=1 into (size_dp*4, 3)"), + ((Shard(dim=1), Replicate(), Replicate()), 1, (size_dp * 2, 6), "unflatten dim=1 into (size_dp*2, 6)"), + ((Shard(dim=1), Replicate(), Replicate()), 1, (size_dp * 12, 1), "unflatten dim=1 into (size_dp*12, 1)"), + ((Shard(dim=1), Replicate(), Replicate()), 1, (size_dp * 4, 3, 1), "unflatten dim=1 into (size_dp*4, 3, 1)"), + ((Shard(dim=1), Replicate(), Replicate()), 1, (size_dp * 2, 2, 3), "unflatten dim=1 into (size_dp*2, 2, 3)"), + ] + + # Test Case 2: Shard on dim=2, unflatten into 2 dims + # shape[2] = size_ring * 6 * 2 = size_ring * 12 + # unflatten into (size_ring * 6, 2) -> sizes[0] = size_ring * 6 is evenly shardable by size_ring + test_cases_dim2 = [ + ((Replicate(), Shard(dim=2), Replicate()), 2, (size_ring * 6, 2), "unflatten dim=2 into (size_ring*6, 2)"), + ((Replicate(), Shard(dim=2), Replicate()), 2, (size_ring * 4, 3), "unflatten dim=2 into (size_ring*4, 3)"), + ((Replicate(), Shard(dim=2), Replicate()), 2, (size_ring * 12, 1), "unflatten dim=2 into (size_ring*12, 1)"), + ( + (Replicate(), Shard(dim=2), Replicate()), + 2, + (size_ring * 6, 2, 1), + "unflatten dim=2 into (size_ring*6, 2, 1)", + ), + ( + (Replicate(), Shard(dim=2), Replicate()), + 2, + (size_ring * 2, 3, 2), + "unflatten dim=2 into (size_ring*2, 3, 2)", + ), + ] + + # Test Case 3: Both dim=1 and dim=2 are sharded + # dim=1 sharded by dp (mesh dim 0), dim=2 sharded by ring (mesh dim 1) + # unflatten dim=2 so dim=1's shard placement is unaffected + test_cases_both_sharded = [ + ( + (Shard(dim=1), Shard(dim=2), Replicate()), + 2, + (size_ring * 6, 2), + "unflatten dim=2 with both dim=1,2 sharded", + ), + ( + (Shard(dim=1), Shard(dim=2), Replicate()), + 2, + (size_ring * 4, 3), + "unflatten dim=2 into (size_ring*4, 3) with both sharded", + ), + ( + (Shard(dim=1), Shard(dim=2), Replicate()), + 2, + (size_ring * 2, 3, 2), + "unflatten dim=2 into (size_ring*2, 3, 2) with both sharded", + ), + ] + + # Test Case 4: Both dim=1 and dim=2 are sharded, unflatten dim=1. + # Unflattening at a lower dim shifts the higher shard's placement index. + # E.g., unflatten dim=1 into (a, b) adds 1 dim → Shard(dim=2) must become Shard(dim=3). + # Format: (input_placements, dim, sizes, expected_output_placements, description) + test_cases_placement_shift = [ + # unflatten dim=1 into 2 parts: adds 1 dim → Shard(2) shifts to Shard(3) + ( + (Shard(dim=1), Shard(dim=2), Replicate()), + 1, + (size_dp * 4, 3), + (Shard(dim=1), Shard(dim=3), Replicate()), + "unflatten dim=1 into (size_dp*4, 3) shifting Shard(2)->Shard(3)", + ), + # unflatten dim=1 into 3 parts: adds 2 dims → Shard(2) shifts to Shard(4) + ( + (Shard(dim=1), Shard(dim=2), Replicate()), + 1, + (size_dp * 2, 2, 3), + (Shard(dim=1), Shard(dim=4), Replicate()), + "unflatten dim=1 into (size_dp*2, 2, 3) shifting Shard(2)->Shard(4)", + ), + ] + + all_test_cases = [ + (pl, d, s, pl, desc) for pl, d, s, desc in test_cases_dim1 + test_cases_dim2 + test_cases_both_sharded + ] + test_cases_placement_shift + + for input_placements, dim, sizes, expected_output_placements, description in all_test_cases: + label_test_case = f"{description} (dim={dim}, sizes={sizes})\n" + + # Compute global expectations using standard PyTorch operations + x_global, y_expected_global, x_grad_expected_global, dy_global = compute_global_expectation( + shape, dim, sizes, manager.device + ) + + # NOTE: DTensor native op doesn't support unflattening a sharded dimension, + # so we skip DTensor native comparison and only use global serial version as reference. + + # Compute shardwise_unflatten_sharded forward and backward with validation + y_dtensor_result, x_dtensor, dy_dtensor = compute_shardwise_unflatten_sharded_with_validation( + x_global, dy_global, manager.device_mesh_subgroups, input_placements, dim, sizes, label_test_case + ) + + # =================================================================== + # Check output shape and placements + # =================================================================== + # Verify output shape matches expected global shape + assert ( + y_dtensor_result.shape == y_expected_global.shape + ), f"{label_test_case} output shape mismatch: got {y_dtensor_result.shape}, expected {y_expected_global.shape}" + + # Verify output placements: Shard dims beyond the unflatten point shift up + # by (len(sizes) - 1) because that many new dims are introduced. + assert y_dtensor_result.placements == expected_output_placements, ( + f"{label_test_case} output placements mismatch: got {y_dtensor_result.placements}, " + f"expected {expected_output_placements}" + ) + + # =================================================================== + # Check against global serial expectation + # =================================================================== + # Distribute expected output to compare local shards + y_dtensor_expected = distribute_tensor( + y_expected_global, manager.device_mesh_subgroups, y_dtensor_result.placements + ) + + # Compare forward result local shards + assert_tensors_identical( + y_dtensor_result.to_local().detach(), + y_dtensor_expected.to_local().detach(), + ) + + # Compare forward result global tensor + y_result_global = y_dtensor_result.full_tensor() + assert_tensors_identical( + y_result_global.detach(), + y_expected_global.detach(), + ) + + # =================================================================== + # Check backward pass against global serial expectation + # =================================================================== + # Verify input gradient shape + assert x_dtensor.grad.shape == x_grad_expected_global.shape, ( + f"{label_test_case} input gradient shape mismatch: got {x_dtensor.grad.shape}, " + f"expected {x_grad_expected_global.shape}" + ) + + # Distribute expected input gradient for local shard comparison + x_grad_expected_dtensor = distribute_tensor( + x_grad_expected_global, manager.device_mesh_subgroups, input_placements + ) + + # Compare input gradient local shards + assert_tensors_identical( + x_dtensor.grad.to_local().detach(), + x_grad_expected_dtensor.to_local().detach(), + ) + + # Compare input gradient global tensor + assert_tensors_identical( + x_dtensor.grad.full_tensor().detach(), + x_grad_expected_global.detach(), + ) + + # =================================================================== + # Test invalid cases that should raise ValueError + # =================================================================== + + # Create a test tensor for invalid cases + x_global_invalid = torch.rand(*shape, device=manager.device) + + # Invalid Case 1: dim is NOT sharded + invalid_not_sharded_cases = [ + # (input_placements, dim, sizes, expected_error_pattern) + ((Shard(dim=1), Replicate(), Replicate()), 0, (1, 2), "input is not sharded along dim"), + ((Shard(dim=1), Replicate(), Replicate()), 2, (size_ring * 6, 2), "input is not sharded along dim"), + ((Replicate(), Shard(dim=2), Replicate()), 0, (1, 2), "input is not sharded along dim"), + ((Replicate(), Shard(dim=2), Replicate()), 1, (size_dp * 4, 3), "input is not sharded along dim"), + ] + + for input_placements, dim, sizes, error_pattern in invalid_not_sharded_cases: + x_dtensor = distribute_tensor(x_global_invalid.clone(), manager.device_mesh_subgroups, input_placements) + with pytest.raises(ValueError, match=error_pattern): + shardwise_unflatten_sharded(x_dtensor, axis=dim, sizes=sizes) + + # Invalid Case 2: sizes[0] not evenly shardable + # Need tensors with shapes that allow sizes product to match but sizes[0] not evenly shardable + # We need to pick sizes[0] that is NOT divisible by the mesh size + # For size_dp=2: use 3 (3 % 2 != 0), shape[1] = 2*12 = 24 = 3*8 + # For size_ring=2: use 3 (3 % 2 != 0), shape[2] = 2*12 = 24 = 3*8 + # For size_ring=3: use 4 (4 % 3 != 0), shape[2] = 3*12 = 36 = 4*9 + + # Test 1: dim=1 sharded by size_dp, sizes[0] not evenly shardable (when size_dp>1 and 3 % size_dp != 0) + if size_dp > 1 and 3 % size_dp != 0: + shape_for_uneven1 = (2, size_dp * 12, size_ring * 12, 5, 4) + x_for_uneven1 = torch.rand(*shape_for_uneven1, device=manager.device) + x_dtensor_uneven1 = distribute_tensor( + x_for_uneven1.clone(), manager.device_mesh_subgroups, (Shard(dim=1), Replicate(), Replicate()) + ) + # sizes = (3, size_dp*4) so 3 * size_dp*4 = size_dp*12 = shape[1], but 3 % size_dp != 0 + with pytest.raises(ValueError, match="must be evenly sharded"): + shardwise_unflatten_sharded(x_dtensor_uneven1, axis=1, sizes=(3, size_dp * 4)) + + # Test 2: dim=2 sharded by size_ring, sizes[0] not evenly shardable + if size_ring > 1: + # Choose uneven_val such that uneven_val % size_ring != 0 + # For size_ring=2: use 3, for size_ring=3: use 4 + if size_ring == 2: + uneven_val = 3 + other_factor = 8 # 3 * 8 = 24 = size_ring * 12 + elif size_ring == 3: + uneven_val = 4 + other_factor = 9 # 4 * 9 = 36 = size_ring * 12 + else: + # General case: use (size_ring + 1) if it divides size_ring * 12 + # This test may be skipped for some unusual mesh sizes + uneven_val = size_ring + 1 + product = size_ring * 12 + if product % uneven_val == 0: + other_factor = product // uneven_val + else: + uneven_val = None # Skip this test + + if uneven_val is not None and uneven_val % size_ring != 0: + shape_for_uneven2 = (2, size_dp * 12, size_ring * 12, 5, 4) + x_for_uneven2 = torch.rand(*shape_for_uneven2, device=manager.device) + x_dtensor_uneven2 = distribute_tensor( + x_for_uneven2.clone(), manager.device_mesh_subgroups, (Replicate(), Shard(dim=2), Replicate()) + ) + with pytest.raises(ValueError, match="must be evenly sharded"): + shardwise_unflatten_sharded(x_dtensor_uneven2, axis=2, sizes=(uneven_val, other_factor)) + + # Invalid Case 3: Dimension out of range + invalid_out_of_range_cases = [ + ((Shard(dim=1), Replicate(), Replicate()), 10, (1, 2), "out of range"), + ] + + for input_placements, dim, sizes, error_pattern in invalid_out_of_range_cases: + x_dtensor = distribute_tensor(x_global_invalid.clone(), manager.device_mesh_subgroups, input_placements) + with pytest.raises(ValueError, match=error_pattern): + shardwise_unflatten_sharded(x_dtensor, axis=dim, sizes=sizes) + + # Invalid Case 4: sizes has less than 2 elements + invalid_sizes_cases = [ + ((Shard(dim=1), Replicate(), Replicate()), 1, (size_dp * 12,), "at least two dimensions"), + ] + + for input_placements, dim, sizes, error_pattern in invalid_sizes_cases: + x_dtensor = distribute_tensor(x_global_invalid.clone(), manager.device_mesh_subgroups, input_placements) + with pytest.raises(ValueError, match=error_pattern): + shardwise_unflatten_sharded(x_dtensor, axis=dim, sizes=sizes) + + # Invalid Case 5: Product of sizes doesn't match dim size + invalid_product_cases = [ + ((Shard(dim=1), Replicate(), Replicate()), 1, (size_dp * 4, 5), "Expected size"), + ] + + for input_placements, dim, sizes, error_pattern in invalid_product_cases: + x_dtensor = distribute_tensor(x_global_invalid.clone(), manager.device_mesh_subgroups, input_placements) + with pytest.raises(ValueError, match=error_pattern): + shardwise_unflatten_sharded(x_dtensor, axis=dim, sizes=sizes) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_unflatten_sharded(setup_env): + """Test shardwise_unflatten_sharded for unflattening a sharded dimension.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_unflatten_sharded, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/layers/test_dtensor_where.py b/tests/distributed/model/layers/test_dtensor_where.py new file mode 100644 index 000000000..01be4c510 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_where.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import unittest + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.where import where +from boltz.testing.utils import assert_tensors_identical, spawn_multiprocessing + + +def serial_where(condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Serial implementation of where operation for comparison.""" + return torch.where(condition, x, y) + + +def parallel_assert_where( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + condition_global_host, + x_global_host, + y_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_x_expected_global_host, + d_y_expected_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Distribute input tensors + condition_dtensor = distribute_tensor( + condition_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ) + x_dtensor = distribute_tensor( + x_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + y_dtensor = distribute_tensor( + y_global_host.to(manager.device), device_mesh=manager.device_mesh_subgroups, placements=placements + ).requires_grad_(True) + + # Distribute expected outputs + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_x_expected_dtensor = distribute_tensor( + d_x_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + d_y_expected_dtensor = distribute_tensor( + d_y_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + condition_dtensor_copy = condition_dtensor.detach().clone() + x_dtensor_copy = x_dtensor.detach().clone().requires_grad_(True) + y_dtensor_copy = y_dtensor.detach().clone().requires_grad_(True) + + # Forward pass + output_dtensor_result = where(condition_dtensor, x_dtensor, y_dtensor) + + # Verify inputs weren't modified + assert_tensors_identical( + condition_dtensor_copy.to_local(), condition_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(x_dtensor_copy.to_local(), x_dtensor.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(y_dtensor_copy.to_local(), y_dtensor.to_local(), check_grad=False, check_grad_fn=False) + + # Test forward pass results + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output shape mismatch: {output_dtensor_result.shape} != {output_expected_dtensor.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output stride mismatch: {output_dtensor_result.stride()} != {output_expected_dtensor.stride()}" + torch.testing.assert_close(output_dtensor_result.to_local(), output_expected_dtensor.to_local()) + + # Backward pass + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + # Verify upstream gradient wasn't modified + assert_tensors_identical(d_output_expected_dtensor_copy.to_local(), d_output_expected_dtensor.to_local()) + + # Test input gradients + assert ( + x_dtensor.grad.shape == d_x_expected_dtensor.shape + ), f"Input gradient shape mismatch: {x_dtensor.grad.shape} != {d_x_expected_dtensor.shape}" + assert ( + x_dtensor.grad.stride() == d_x_expected_dtensor.stride() + ), f"Input gradient stride mismatch: {x_dtensor.grad.stride()} != {d_x_expected_dtensor.stride()}" + torch.testing.assert_close(x_dtensor.grad.to_local(), d_x_expected_dtensor.to_local()) + + assert ( + y_dtensor.grad.shape == d_y_expected_dtensor.shape + ), f"Input gradient shape mismatch: {y_dtensor.grad.shape} != {d_y_expected_dtensor.shape}" + assert ( + y_dtensor.grad.stride() == d_y_expected_dtensor.stride() + ), f"Input gradient stride mismatch: {y_dtensor.grad.stride()} != {d_y_expected_dtensor.stride()}" + torch.testing.assert_close(y_dtensor.grad.to_local(), d_y_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_x_global_result_host = x_dtensor.grad.full_tensor().cpu() + d_y_global_result_host = y_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_global_result_host, output_expected_global_host) + torch.testing.assert_close(d_x_global_result_host, d_x_expected_global_host) + torch.testing.assert_close(d_y_global_result_host, d_y_expected_global_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "placements", [(Shard(0), Shard(1), Shard(2)), (Shard(0), Shard(1), Replicate())], ids=["shard", "replicate"] +) +@pytest.mark.parametrize( + "condition_type", + ["random", "threshold"], +) +def test_where_parallel(setup_env, placements, condition_type): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 100 # Number of tokens + D = 32 # Hidden dimension + + seed = 42 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + # Create input tensors + x_global = torch.empty((B, N, N, D), requires_grad=True, device=device_type) + y_global = torch.empty((B, N, N, D), requires_grad=True, device=device_type) + with torch.no_grad(): + x_global.uniform_(-5, 5, generator=rng) + y_global.uniform_(-3, 7, generator=rng) + + # Create different types of conditions for testing + if condition_type == "random": + condition_global = torch.randint(0, 2, (B, N, N, D), dtype=torch.bool, device=device_type, generator=rng) + elif condition_type == "threshold": + # Create condition based on x values (test gradient flow) + condition_global = x_global > 0.0 + else: + raise ValueError(f"Invalid condition type: {condition_type}") + + # Run serial forward pass + condition_global_host = condition_global.detach().clone().cpu() + x_global_host = x_global.detach().clone().cpu() + y_global_host = y_global.detach().clone().cpu() + output_expected_global = serial_where(condition_global, x_global, y_global) + output_expected_global_host = output_expected_global.detach().clone().cpu() + + # Create upstream gradient and run backward pass + d_output_expected_global = torch.rand_like(output_expected_global) + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_where, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + condition_global_host, + x_global_host, + y_global_host, + output_expected_global_host, + d_output_expected_global_host, + x_global.grad.detach().clone().cpu(), + y_global.grad.detach().clone().cpu(), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/model/layers/test_dtensor_window_batch_utils.py b/tests/distributed/model/layers/test_dtensor_window_batch_utils.py new file mode 100644 index 000000000..13f89b7e7 --- /dev/null +++ b/tests/distributed/model/layers/test_dtensor_window_batch_utils.py @@ -0,0 +1,1021 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import pytest +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_unflatten_sharded +from boltz.distributed.model.layers.utils import ( + convert_single_repr_to_window_batched_key, + distributed_gather_sliding_windows, + distributed_pack_and_pad, + distributed_unpad_and_unpack, + gather_sliding_windows, + pack_and_pad, +) +from boltz.model.modules.encoders import get_indexing_matrix, single_to_keys +from boltz.testing.utils import assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def parallel_assert_gather_sliding_windows(rank, grid_group_sizes, device_type, backend, env_map): + """Run distributed version on each rank.""" + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh + + seed_by_rank(0, 42) + + for W, H, K_per_rank in [(8, 24, 3), (32, 128, 8), (32, 128, 1)]: + assert H % (W // 2) == 0, f"H must be divisible by W // 2 but got {H} % {W // 2} != 0" + h = H // (W // 2) + assert h % 2 == 0, f"h := H // (W // 2) must be divisible by 2 but got {h} % 2 != 0" + K = K_per_rank * grid_group_sizes["cp"] + for n_per_rank in [4, 100]: + # always shard the (2 * K, ) axis by the "cp" submesh + for input_shape_extra in [(None,), (4, None), (None, 3), (2, None, 3)]: + # The "None" element indicates the axis to be used as "half-window" axis + # there can be only one "None" element in the input_shape + if sum(1 for x in input_shape_extra if x is None) != 1: + raise ValueError( + f"There can be one and only one 'None' element in the input_shape but got {input_shape_extra}" + ) + axis = input_shape_extra.index(None) + input_shape = input_shape_extra[:axis] + (2 * K, W // 2) + input_shape_extra[axis + 1 :] + + label = f"W:{W}, H:{H}, K:{K}, input_shape:{input_shape}" + + # Shard the leading axis if they exist and evenly divisible by the mesh shape + # otherwise replicate + ndim = len(input_shape) + if ndim > 2 and axis != 0 and device_mesh.ndim >= 2 and input_shape[0] % device_mesh.size(0) == 0: + placements = (Shard(0),) + (Replicate(),) * (device_mesh.ndim - 2) + (Shard(axis),) + else: + placements = (Replicate(),) * (device_mesh.ndim - 1) + (Shard(axis),) + + input_global = torch.randn(input_shape, dtype=torch.float32, device=device, requires_grad=True) + + offset_start = 1 - h // 2 + offsets = torch.arange(offset_start, offset_start + 2 * K, 2, device=device) + + output_ref = gather_sliding_windows(input_global, offsets, h, axis) + + # Backward + grad_output_ref = torch.randn_like(output_ref) + output_ref.backward(grad_output_ref) + + # Create sharded DTensor + input_dtensor = distribute_tensor( + input_global.detach().clone(), device_mesh, placements + ).requires_grad_(True) + + # Distributed forward + output_dtensor = distributed_gather_sliding_windows(input_dtensor, h, axis) + + # Verify forward + output_result_global = output_dtensor.full_tensor() + # due the complicated reshaping involved and the potential concatenation of halo + # between ranks, it's difficult to guarantee identical strides as in the global + # case, which is not very useful in practice, so we just check the values + # Also, compute_global_expectation returns detached tensors which voids grad existence check + assert_tensors_identical( + output_result_global, + output_ref, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} fwd output mismatch:\n {m}", + ) + + # Distributed backward + grad_output_dtensor = distribute_tensor( + grad_output_ref.detach().clone(), device_mesh, output_dtensor.placements + ) + output_dtensor.backward(grad_output_dtensor) + + # Verify backward (with tolerance for numerical precision) + grad_input_result_global = input_dtensor.grad.full_tensor() + + torch.testing.assert_close( + grad_input_result_global, + input_global.grad, + msg=lambda m: f"{label} input gradient mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, 4), True, "cuda", "ENV"), + ((1, 7), True, "cuda", "ENV"), + ((2, 4), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_distributed_gather_sliding_windows(setup_env): + """Test distributed gather sliding windows""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_gather_sliding_windows, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def parallel_assert_pack_and_pad(rank, grid_group_sizes, device_type, backend, env_map): + """Run distributed pack and pad test.""" + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh + + # 1. Create global data + seed_by_rank(0, 42) + + for W in [8, 32]: + for n_per_rank in [4, 100]: + # For testing purpose, always shard the "axis" along last dimension of the device mesh + n_global = n_per_rank * device_mesh.size(-1) + for input_shape_extra in [(None,), (4, None), (None, 3), (2, None, 3)]: + if sum(1 for x in input_shape_extra if x is None) != 1: + raise ValueError( + f"There can be one and only one 'None' element in the input_shape but got {input_shape_extra}" + ) + axis = input_shape_extra.index(None) + ndim = len(input_shape_extra) + input_shape = input_shape_extra[:axis] + (n_global,) + input_shape_extra[axis + 1 :] + + axes_mask_broadcast = torch.arange(ndim).tolist() + for axis_mask_broadcast in axes_mask_broadcast: + if axis_mask_broadcast == axis: + # mask has same shape as input + shape_mask = input_shape + else: + # mask has shape 1 along the broadcast axis + shape_mask = input_shape[:axis_mask_broadcast] + (1,) + input_shape[axis_mask_broadcast + 1 :] + + for keep_input_padding in [False, True]: + label = f"W:{W}, input_shape:{input_shape}, shape_mask:{shape_mask}, keep_input_padding:{keep_input_padding}" + + input_global = torch.randn(input_shape, requires_grad=True, device=device) + mask_global = torch.randint(0, 2, shape_mask, dtype=torch.bool, device=device) + # 1. reference fwd and bwd + output_ref, _, mask_ref = pack_and_pad(input_global, mask_global, axis, W, keep_input_padding) + grad_output = torch.randn_like(output_ref) + # backprop without padding in the reference, which additionally affirms that the extra padding + # in DTensor version doesn't contribute to its input gradients regardless of the padding values + + # We can only backward one of them or both summed. + # Let's verify sum + torch.autograd.backward([output_ref], [grad_output.detach().clone()]) + grad_input_ref = input_global.grad.clone() + input_global.grad.zero_() + + # 2. Distribute + # Shard the leading axis if they exist and evenly divisible by the mesh shape + # otherwise replicate + if ( + ndim > 1 + and axis != 0 + and device_mesh.ndim >= 2 + and input_shape[0] % device_mesh.size(0) == 0 + and shape_mask[0] % device_mesh.size(0) == 0 + ): + placements = (Shard(0),) + (Replicate(),) * (device_mesh.ndim - 2) + (Shard(axis),) + else: + placements = (Replicate(),) * (device_mesh.ndim - 1) + (Shard(axis),) + i_dim_mesh_shard_axis = placements.index(Shard(axis)) + world_size_shard_axis = device_mesh.shape[i_dim_mesh_shard_axis] + + # Use detach().clone() to ensure leaf tensor + input_dtensor = distribute_tensor( + input_global.detach().clone(), device_mesh, placements + ).requires_grad_(True) + mask_dtensor = distribute_tensor(mask_global.detach().clone(), device_mesh, placements) + + # 3. Distributed Unmask + input_dtensor_copy_local = input_dtensor.detach().clone().to_local().requires_grad_(True) + mask_dtensor_copy_local = mask_dtensor.detach().clone().to_local() + output_dtensor, output_mask_dtensor = distributed_pack_and_pad( + input_dtensor, mask_dtensor, W, axis, keep_input_padding=keep_input_padding + ) + + # assert no change to the fwd inputs + assert_tensors_identical( + input_dtensor_copy_local, input_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + mask_dtensor_copy_local, mask_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + + # 4. verify fwd output + output_gathered = output_dtensor.full_tensor() + mask_gathered = output_mask_dtensor.full_tensor() + + # verify the output.shape[axis] is padded to (W * world_size) + n_pad_target = W * world_size_shard_axis + if keep_input_padding: + n_expected = ((input_dtensor.shape[axis] + n_pad_target - 1) // n_pad_target) * n_pad_target + else: + n_valid_max = mask_global.sum(dim=axis).max().item() + n_expected = ((n_valid_max + n_pad_target - 1) // n_pad_target) * n_pad_target + + assert ( + output_dtensor.shape[axis] == n_expected + ), f"{label} output shape mismatch: {output_dtensor.shape[axis]} != {n_expected}" + + # The distributed output might be padded more than the reference output + # because distributed logic aligns to (W * world_size) + pad_len = output_gathered.shape[axis] - output_ref.shape[axis] + assert pad_len >= 0, "Distributed output should be at least as large as reference" + + if pad_len > 0: + # Pad output_ref along axis + pad_arg = [0] * (2 * output_ref.ndim) + pad_idx = (output_ref.ndim - 1 - axis) * 2 + 1 + pad_arg[pad_idx] = pad_len + output_ref_padded = torch.nn.functional.pad(output_ref, pad_arg) + mask_ref_padded = torch.nn.functional.pad(mask_ref, pad_arg) + grad_output_padded = torch.nn.functional.pad(grad_output, pad_arg) + else: + output_ref_padded = output_ref + mask_ref_padded = mask_ref + grad_output_padded = grad_output + + # target function involves no FLOPS so we can use strict equality + assert_tensors_identical( + output_gathered, + output_ref_padded, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} output mismatch:\n {m}", + ) + + assert_tensors_identical( + mask_gathered, + mask_ref_padded, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} mask mismatch:\n {m}", + ) + + # 5. Backward + + # Distribute grad + grad_dtensor = distribute_tensor(grad_output_padded.detach().clone(), device_mesh, placements) + # We need dummy grads for the masks + + # Backward with both + # We need to manually call backward because output_dtensor.backward() only accepts one gradient? + # No, torch.autograd.backward accepts tensors and grad_tensors + torch.autograd.backward([output_dtensor], [grad_dtensor]) + + grad_gathered = input_dtensor.grad.full_tensor() + + assert_tensors_identical( + grad_gathered, + grad_input_ref, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} input gradient mismatch:\n {m}", + ) + + # 6. Test Inverse Forward + # We detach output_dtensor_qw to test Inverse backward in isolation + output_dtensor_detached = output_dtensor.detach().clone().requires_grad_(True) + input_recovered_dtensor = distributed_unpad_and_unpack( + output_dtensor_detached, + output_mask_dtensor, + mask_dtensor, + axis, + keep_input_padding=keep_input_padding, + ) + + # Verify Forward (Recovered Input should match original Input where mask is True) + # The inverse operation recovers the valid elements into their original positions. + # Invalid positions (where mask is False) are zeroed out by scatter. + input_recovered_global = input_recovered_dtensor.full_tensor() + + # Apply mask to original input for comparison + input_global_masked = input_global * mask_global.to(input_global.dtype) + # Inverse output also has zeros at invalid positions naturally + + assert_tensors_identical( + input_recovered_global, + input_global_masked, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} inverse forward mismatch:\n {m}", + ) + + # 7. Test Inverse Backward + grad_input_recovered = torch.randn_like(input_global) + # Distribute the gradient + grad_input_recovered_dtensor = distribute_tensor( + grad_input_recovered.detach().clone(), device_mesh, placements + ) + + input_recovered_dtensor.backward(grad_input_recovered_dtensor) + + grad_output_result = output_dtensor_detached.grad.full_tensor() + + # Reference Backward: The backward of Inverse is Unmask (Gather) + # We use the pack and pad reference implementation on the gradient + grad_output_ref_qw, _, _ = pack_and_pad( + grad_input_recovered, mask_global, axis, W, keep_input_padding + ) + + # Pad reference to match distributed output shape + pad_len = grad_output_result.shape[axis] - grad_output_ref_qw.shape[axis] + if pad_len > 0: + pad_arg = [0] * (2 * grad_output_ref_qw.ndim) + pad_idx = (grad_output_ref_qw.ndim - 1 - axis) * 2 + 1 + pad_arg[pad_idx] = pad_len + grad_output_ref_padded = torch.nn.functional.pad(grad_output_ref_qw, pad_arg) + else: + grad_output_ref_padded = grad_output_ref_qw + + assert_tensors_identical( + grad_output_result, + grad_output_ref_padded, + check_stride=False, + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} inverse backward mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, 4), True, "cuda", "ENV"), + ((1, 7), True, "cuda", "ENV"), + ((2, 4), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_distributed_pack_and_pad(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_pack_and_pad, world_size, grid_group_sizes, device_type, backend, env_per_rank + ) + + +def parallel_assert_pack_and_pad_and_gather_sliding_windows(rank, grid_group_sizes, device_type, backend, env_map): + """Test integration of distributed pack and pad -> gather sliding windows.""" + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh + + # 1. Create global data + seed_by_rank(0, 42) + + for W, H, K_per_rank in [(8, 24, 3), (32, 128, 8), (32, 128, 1)]: + assert H % (W // 2) == 0, f"H must be divisible by W // 2 but got {H} % {W // 2} != 0" + h = H // (W // 2) + assert h % 2 == 0, f"h := H // (W // 2) must be divisible by 2 but got {h} % 2 != 0" + K = K_per_rank * grid_group_sizes["cp"] + h = H // (W // 2) + for n_total in [K * W, K * W + 10 * grid_group_sizes["cp"]]: + # always shard the (2 * K, ) axis by the "cp" submesh + for input_shape_extra in [(None,), (4, None), (None, 3), (2, None, 3)]: + # The "None" element indicates the axis to be used as "half-window" axis + # there can be only one "None" element in the input_shape + if sum(1 for x in input_shape_extra if x is None) != 1: + raise ValueError( + f"There can be one and only one 'None' element in the input_shape but got {input_shape_extra}" + ) + axis = input_shape_extra.index(None) + input_shape = input_shape_extra[:axis] + (n_total,) + input_shape_extra[axis + 1 :] + shape_mask = input_shape + ndim = len(input_shape) + # Shard the leading axis if they exist and evenly divisible by the mesh shape + # otherwise replicate + if ( + ndim > 1 + and axis != 0 + and device_mesh.ndim >= 2 + and input_shape[0] % device_mesh.size(0) == 0 + and shape_mask[0] % device_mesh.size(0) == 0 + ): + # don't shard the dim after "axis" so that these placements can be reused for + # sharding the upstream adjoint for testing the backward because the fwd output + # inserts new axes after the "axis" dimension + placements = (Shard(0),) + (Replicate(),) * (device_mesh.ndim - 2) + (Shard(axis),) + else: + placements = (Replicate(),) * (device_mesh.ndim - 1) + (Shard(axis),) + + label = f"W:{W}, H:{H}, K:{K}, input_shape:{input_shape}" + + # 1. Setup Global Data + input_global = torch.randn(input_shape, requires_grad=True, device=device) + mask_global = torch.randint(0, 2, shape_mask, dtype=torch.bool, device=device) + + # 2. Distribute + input_dtensor = distribute_tensor( + input_global.detach().clone().requires_grad_(True), device_mesh, placements + ) + mask_dtensor = distribute_tensor(mask_global.detach().clone(), device_mesh, placements) + + # 3. Distributed Chain + # Unmask + packed_dtensor, packed_mask = distributed_pack_and_pad(input_dtensor, mask_dtensor, W, axis) + + # Due to keep_input_padding=False, input_dtensor.shape[axis] is trimmed to the maximum number of valid elements + # which is then padded to the multiple of W * size_group_sharding_axis so we need to recompute + # K in order to reshape the tensors + K_global = packed_dtensor.shape[axis] // W + + packed_dtensor_hw = shardwise_unflatten_sharded(packed_dtensor, axis, (2 * K_global, W // 2)) + + # Gather + output_dtensor = distributed_gather_sliding_windows(packed_dtensor_hw, h, axis) + + # 4. Single Device Chain + packed_ref, _, mask_ref = pack_and_pad(input_global, mask_global, axis, W) + # Apply masks + packed_ref = packed_ref * mask_ref.to(packed_ref.dtype) + + # packed_ref_hw: (2K, W/2, F) + + # Need to match the global padding logic of distributed_pack_and_pad! + # Distributed pads total_valid to multiple of W*world_size. + # Single device pads to multiple of W. + # We need to pad packed_ref to match packed_dtensor shape for comparison. + + if packed_dtensor.shape[axis] > packed_ref.shape[axis]: + # Pad along axis + pad_arg = [0] * (2 * packed_ref.ndim) + idx = (packed_ref.ndim - 1 - axis) * 2 + 1 + pad_arg[idx] = packed_dtensor.shape[axis] - packed_ref.shape[axis] + packed_ref_padded = torch.nn.functional.pad(packed_ref, pad_arg) + + else: + packed_ref_padded = packed_ref + + packed_ref_hw_padded = packed_ref_padded.unflatten(axis, (2 * K_global, W // 2)) + + # Gather Ref + # We need offsets + offset_start = 1 - h // 2 + offsets_ref = torch.arange(offset_start, offset_start + 2 * K_global, 2, device=device) + output_ref = gather_sliding_windows(packed_ref_hw_padded, offsets_ref, h, axis) + + # 5. Verify Forward + output_gathered = output_dtensor.full_tensor() + torch.testing.assert_close( + output_gathered, output_ref, atol=0, rtol=0, msg=lambda m: f"{label} output mismatch:\n {m}" + ) + + # 6. Verify Backward + grad_out = torch.randn_like(output_gathered) + grad_dtensor = distribute_tensor(grad_out.detach().clone(), device_mesh, placements) + + # dummy gradients for masks (required by autograd since we return them) + # Note: we don't pass them to backward() because masks are non-differentiable + # and PyTorch will pass None to the backward method for them. + + torch.autograd.backward([output_dtensor], [grad_dtensor]) + + # Single backward + # We use output_ref and packed_ref_qw as roots + # We directly backprop with the padding, assuring that the extra padding in DTensor case + # behave similarly + torch.autograd.backward([output_ref], [grad_out]) + + # Verify input gradients + grad_input_gathered = input_dtensor.grad.full_tensor() + torch.testing.assert_close( + grad_input_gathered, + input_global.grad, + msg=lambda m: f"{label} input gradient mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, 4), True, "cuda", "ENV"), + ((1, 7), True, "cuda", "ENV"), + ((2, 4), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_distributed_pack_and_pad_and_gather_sliding_windows(setup_env): + """ + Test integration of distributed pack and pad -> gather sliding windows. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_pack_and_pad_and_gather_sliding_windows, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def parallel_assert_window_batch_attention(rank, grid_group_sizes, device_type, backend, env_map, dtype): + """Test integration of distributed pack and pad -> gather sliding windows -> attention.""" + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh + + # 1. Create global data + seed_by_rank(0, 42) + + for W, H, K_per_rank in [(8, 16, 1), (32, 128, 8), (32, 128, 1)]: + assert H % (W // 2) == 0, f"H must be divisible by W // 2 but got {H} % {W // 2} != 0" + h = H // (W // 2) + assert h % 2 == 0, f"h := H // (W // 2) must be divisible by 2 but got {h} % 2 != 0" + K = K_per_rank * grid_group_sizes["cp"] + h = H // (W // 2) + for n_total in [K * W]: # Boltz window batching only supports K * W size input + for input_shape_extra in [ + (2, None, 3), + (16, None, 128), + ]: # single_to_keys requires (B, n, D) shape + # The "None" element indicates the axis to be used as "half-window" axis + # there can be only one "None" element in the input_shape + if sum(1 for x in input_shape_extra if x is None) != 1: + raise ValueError( + f"There can be one and only one 'None' element in the input_shape but got {input_shape_extra}" + ) + axis = input_shape_extra.index(None) + input_shape = input_shape_extra[:axis] + (n_total,) + input_shape_extra[axis + 1 :] + shape_mask = input_shape + ndim = len(input_shape) + # Shard the leading axis if they exist and evenly divisible by the mesh shape + # otherwise replicate + if ( + ndim > 1 + and axis != 0 + and device_mesh.ndim >= 2 + and input_shape[0] % device_mesh.size(0) == 0 + and shape_mask[0] % device_mesh.size(0) == 0 + ): + # don't shard the dim after "axis" so that these placements can be reused for + # sharding the upstream adjoint for testing the backward because the fwd output + # inserts new axes after the "axis" dimension + placements = (Shard(0),) + (Replicate(),) * (device_mesh.ndim - 2) + (Shard(axis),) + else: + placements = (Replicate(),) * (device_mesh.ndim - 1) + (Shard(axis),) + + label = f"W:{W}, H:{H}, K:{K}, input_shape:{input_shape}" + + # 1. Setup Global Data + input_global = torch.randn(input_shape, dtype=dtype, requires_grad=True, device=device) + mask_global = torch.randint(0, 2, shape_mask, dtype=torch.bool, device=device, requires_grad=False) + + # sort the input and mask so that valid elements are leading the axis + mask_global_sorted, argsort_mask_global = torch.sort( + mask_global, dim=axis, descending=True, stable=True + ) + input_global_sorted = torch.gather(input_global, axis, argsort_mask_global) + + indexing_matrix = get_indexing_matrix(K, W, H, device).to(dtype=input_global.dtype) + + # Get key and query from Boltz window batching (..., K, H, ...) + key_wb_expected = single_to_keys(input_global_sorted, indexing_matrix, W, H) + mask_key_wb_expected = single_to_keys( + mask_global_sorted.to(dtype=indexing_matrix.dtype), indexing_matrix, W, H + ).to(dtype=torch.bool) + + query_wb_expected = input_global_sorted.unflatten(axis, (K, W)) + mask_query_wb_expected = mask_global_sorted.unflatten(axis, (K, W)) + + # perform linear attention with single iteration (to avoid numerical complications from softmax etc but enough to + # prove the DTensor logics are correct) + a_wb_expected = torch.einsum( + "bkid,bkjd->bkij", + query_wb_expected * mask_query_wb_expected, + key_wb_expected * mask_key_wb_expected, + ) + o_wb_expected = torch.einsum("bkml,bkld->bkmd", a_wb_expected, key_wb_expected * mask_key_wb_expected) + + # 2. Distribute + input_dtensor = distribute_tensor( + input_global.detach().clone().requires_grad_(True), device_mesh, placements + ) + mask_dtensor = distribute_tensor(mask_global.detach().clone(), device_mesh, placements) + + # 3. Distributed Chain + # Unmask + packed_dtensor, packed_mask = distributed_pack_and_pad(input_dtensor, mask_dtensor, W, axis) + + # Due to keep_input_padding=False, input_dtensor.shape[axis] is trimmed to the maximum number of valid elements + # which is then padded to the multiple of W * size_group_sharding_axis so we need to recompute + # K in order to reshape the tensors + K_global = packed_dtensor.shape[axis] // W + + packed_dtensor_hw = shardwise_unflatten_sharded(packed_dtensor, axis, (2 * K_global, W // 2)) + packed_mask_hw = shardwise_unflatten_sharded(packed_mask, axis, (2 * K_global, W // 2)) + + # Gather + # (..., K, h, W//2) -> (..., K, H, ...) + key_wb_result_reshaped = distributed_gather_sliding_windows(packed_dtensor_hw, h, axis) + key_wb_result = key_wb_result_reshaped.flatten(axis + 1, axis + 2) + mask_key_wb_result_reshaped = distributed_gather_sliding_windows(packed_mask_hw, h, axis) + mask_key_wb_result = mask_key_wb_result_reshaped.flatten(axis + 1, axis + 2) + + query_wb_result = shardwise_unflatten_sharded(packed_dtensor, axis, (K_global, W)) + mask_query_wb_result = shardwise_unflatten_sharded(packed_mask, axis, (K_global, W)) + + # DTensor linear attention + # The DTensor native einsum doesn't support the involved implicit unflattening of two + # leading batch dimensions so we need to do via shard-wise local attention + a_wb_result_local = torch.einsum( + "bkid,bkjd->bkij", + query_wb_result.to_local() * mask_query_wb_result.to_local(), + key_wb_result.to_local() * mask_key_wb_result.to_local(), + ) + o_wb_result_local = torch.einsum( + "bkml,bkld->bkmd", + a_wb_result_local, + key_wb_result.to_local() * mask_key_wb_result.to_local(), + ) + a_wb_result = DTensor.from_local(a_wb_result_local, device_mesh, placements) + o_wb_result = DTensor.from_local(o_wb_result_local, device_mesh, placements) + + # verify key and query consistency + # this requires potentially padding the reference to the next multiple of W * world_size + key_wb_result_global = key_wb_result.full_tensor() + mask_key_wb_result_global = mask_key_wb_result.full_tensor() + query_wb_result_global = query_wb_result.full_tensor() + mask_query_wb_result_global = mask_query_wb_result.full_tensor() + a_wb_result_global = a_wb_result.full_tensor() + o_wb_result_global = o_wb_result.full_tensor() + pad_result = False + pad_expected = False + if key_wb_result_global.shape[axis] > key_wb_expected.shape[axis]: + pad_result = True + # pad expected + pad_arg = [0] * (2 * key_wb_expected.ndim) + idx = (key_wb_expected.ndim - 1 - axis) * 2 + 1 + pad_arg[idx] = key_wb_result_global.shape[axis] - key_wb_expected.shape[axis] + key_wb_expected_padded = torch.nn.functional.pad(key_wb_expected, pad_arg) + mask_key_wb_expected_padded = torch.nn.functional.pad(mask_key_wb_expected, pad_arg) + query_wb_expected_padded = torch.nn.functional.pad(query_wb_expected, pad_arg) + mask_query_wb_expected_padded = torch.nn.functional.pad(mask_query_wb_expected, pad_arg) + a_wb_expected_padded = torch.nn.functional.pad(a_wb_expected, pad_arg) + o_wb_expected_padded = torch.nn.functional.pad(o_wb_expected, pad_arg) + # use result + key_wb_result_global_padded = key_wb_result_global + mask_key_wb_result_global_padded = mask_key_wb_result_global + query_wb_result_global_padded = query_wb_result_global + mask_query_wb_result_global_padded = mask_query_wb_result_global + a_wb_result_global_padded = a_wb_result_global + o_wb_result_global_padded = o_wb_result_global + elif key_wb_result_global.shape[axis] < key_wb_expected.shape[axis]: + pad_result = True + # pad result + pad_arg = [0] * (2 * key_wb_result_global.ndim) + idx = (key_wb_result_global.ndim - 1 - axis) * 2 + 1 + pad_arg[idx] = key_wb_expected.shape[axis] - key_wb_result_global.shape[axis] + key_wb_result_global_padded = torch.nn.functional.pad(key_wb_result_global, pad_arg) + mask_key_wb_result_global_padded = torch.nn.functional.pad(mask_key_wb_result_global, pad_arg) + query_wb_result_global_padded = torch.nn.functional.pad(query_wb_result_global, pad_arg) + mask_query_wb_result_global_padded = torch.nn.functional.pad(mask_query_wb_result_global, pad_arg) + a_wb_result_global_padded = torch.nn.functional.pad(a_wb_result_global, pad_arg) + o_wb_result_global_padded = torch.nn.functional.pad(o_wb_result_global, pad_arg) + # use expected + key_wb_expected_padded = key_wb_expected + mask_key_wb_expected_padded = mask_key_wb_expected + query_wb_expected_padded = query_wb_expected + mask_query_wb_expected_padded = mask_query_wb_expected + a_wb_expected_padded = a_wb_expected + o_wb_expected_padded = o_wb_expected + else: + pad_arg = None # prevent accidental reuse from previous for loop iteration + key_wb_expected_padded = key_wb_expected + mask_key_wb_expected_padded = mask_key_wb_expected + query_wb_expected_padded = query_wb_expected + mask_query_wb_expected_padded = mask_query_wb_expected + a_wb_expected_padded = a_wb_expected + o_wb_expected_padded = o_wb_expected + key_wb_result_global_padded = key_wb_result_global + mask_key_wb_result_global_padded = mask_key_wb_result_global + query_wb_result_global_padded = query_wb_result_global + mask_query_wb_result_global_padded = mask_query_wb_result_global + a_wb_result_global_padded = a_wb_result_global + o_wb_result_global_padded = o_wb_result_global + + torch.testing.assert_close( + mask_key_wb_result_global_padded, + mask_key_wb_expected_padded, + msg=lambda m: f"{label} mask_key_wb mismatch:\n {m}", + ) + torch.testing.assert_close( + key_wb_result_global_padded * mask_key_wb_result_global_padded, + key_wb_expected_padded * mask_key_wb_expected_padded, + msg=lambda m: f"{label} key_wb mismatch:\n {m}", + ) + + torch.testing.assert_close( + mask_query_wb_result_global_padded, + mask_query_wb_expected_padded, + msg=lambda m: f"{label} mask_query_wb mismatch:\n {m}", + ) + torch.testing.assert_close( + query_wb_result_global_padded * mask_query_wb_result_global_padded, + query_wb_expected_padded * mask_query_wb_expected_padded, + msg=lambda m: f"{label} query_wb mismatch:\n {m}", + ) + torch.testing.assert_close( + a_wb_result_global_padded, + a_wb_expected_padded, + msg=lambda m: f"{label} a_wb mismatch:\n {m}", + ) + torch.testing.assert_close( + o_wb_result_global_padded, + o_wb_expected_padded, + msg=lambda m: f"{label} o_wb mismatch:\n {m}", + ) + + # check backward pass + # To make sharding the upstream adjoint easy, we always use the DTensor full_tensor version + # to generate the upstream adjoints and make necessary padding to the expected version + # We also zeros out the upstream adjoints to check if the invalid elements in input.grad are zeros + grad_o_wb_result_global = torch.randn_like(o_wb_result_global) * mask_query_wb_result_global + if pad_result: + # grad_o_wb_result_global is shorter along 'axis' than o_wb_expected_padded + grad_o_wb_expected_padded = torch.nn.functional.pad( + grad_o_wb_result_global.detach().clone(), pad_arg + ) + elif pad_expected: + # grad_o_wb_result_global is longer along 'axis' + grad_o_wb_expected_padded = ( + grad_o_wb_result_global.detach().clone().narrow(axis, 0, o_wb_expected_padded.shape[axis]) + ) + else: + grad_o_wb_expected_padded = grad_o_wb_result_global.detach().clone() + o_wb_expected_padded.backward(grad_o_wb_expected_padded) + + grad_o_wb_dtensor = distribute_tensor(grad_o_wb_result_global.detach().clone(), device_mesh, placements) + o_wb_result.backward(grad_o_wb_dtensor) + + # verify input gradients + grad_input_result_global = input_dtensor.grad.full_tensor() + mask_result_global = mask_dtensor.full_tensor() + torch.testing.assert_close( + grad_input_result_global * mask_result_global, + input_global.grad * mask_global, + msg=lambda m: f"{label} input gradient mismatch:\n {m}", + ) + + assert_tensors_identical( + grad_input_result_global * (~mask_result_global), + torch.zeros_like(grad_input_result_global), + check_grad=False, + check_grad_fn=False, + msg=lambda m: f"{label} input gradient mismatch for invalid elements:\n {m}", + ) + + DistributedManager.cleanup() + + +def parallel_assert_single_to_key(rank, grid_group_sizes, device_type, backend, env_map, dtype): + """Test convert_single_repr_to_window_batched_key against single_to_keys (fwd+bwd).""" + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device = manager.device + device_mesh = manager.device_mesh + + seed_by_rank(0, 42) + + # Keep the same window-batching tuples used elsewhere: they satisfy Boltz constraints. + for W, H, K_per_rank in [(8, 16, 1), (32, 128, 8), (32, 128, 1)]: + assert H % (W // 2) == 0, f"H must be divisible by W // 2 but got {H} % {W // 2} != 0" + h = H // (W // 2) + assert h % 2 == 0, f"h := H // (W // 2) must be divisible by 2 but got {h} % 2 != 0" + + K = K_per_rank * grid_group_sizes["cp"] + N = K * W + + indexing_matrix = get_indexing_matrix(K, W, H, device).to(dtype=dtype) + + # convert_single_repr_to_window_batched_key assumes the sequence axis is dim=1. + for input_shape_extra in [ + (2, None, 3), + (16, None, 128), + ]: + axis = input_shape_extra.index(None) + assert axis == 1, "This test assumes the sequence axis is dim=1." + input_shape = input_shape_extra[:axis] + (N,) + input_shape_extra[axis + 1 :] + ndim = len(input_shape) + label = f"W:{W}, H:{H}, K:{K}, input_shape:{input_shape}" + + input_global = torch.randn(input_shape, dtype=dtype, requires_grad=True, device=device) + + # Reference: Boltz window-batched keys. + key_expected = single_to_keys(input_global, indexing_matrix, W, H) + + # Distribute input on device mesh; shard the sequence axis (dim=1) on the last mesh dim. + if ndim > 2 and device_mesh.ndim >= 2 and input_shape[0] % device_mesh.size(0) == 0: + placements = (Shard(0),) + (Replicate(),) * (device_mesh.ndim - 2) + (Shard(1),) + else: + placements = (Replicate(),) * (device_mesh.ndim - 1) + (Shard(1),) + + input_dtensor = distribute_tensor( + input_global.detach().clone().requires_grad_(True), + device_mesh, + placements, + ) + + # Distributed: window-batched keys. + key_dtensor = convert_single_repr_to_window_batched_key(input_dtensor, W, H) + key_gathered = key_dtensor.full_tensor() + + torch.testing.assert_close( + key_gathered, + key_expected, + msg=lambda m: f"{label} fwd key mismatch:\n {m}", + ) + + # Backward: compare input gradients. + grad_key = torch.randn_like(key_expected) + key_expected.backward(grad_key) + + grad_key_dtensor = distribute_tensor(grad_key.detach().clone(), device_mesh, key_dtensor.placements) + key_dtensor.backward(grad_key_dtensor) + + grad_input_gathered = input_dtensor.grad.full_tensor() + torch.testing.assert_close( + grad_input_gathered, + input_global.grad, + msg=lambda m: f"{label} input gradient mismatch:\n {m}", + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, 4), True, "cuda", "ENV"), + ((1, 7), True, "cuda", "ENV"), + ((2, 4), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_distributed_window_batch_attention(setup_env): + """ + Test integration of distributed pack and pad -> gather sliding windows -> attention. + This test uses Boltz get_indexing_matrix and single_to_keys() to create the window batched query and key + to perform the reference attention. It then uses the distributed pack and pad -> gather sliding windows + to perform the distributed attention. Both forward and backward passes are tested. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_window_batch_attention, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + torch.float64, # use FP64 to avoid dealing with numerical tolerance + ) + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, 4), True, "cuda", "ENV"), + ((1, 7), True, "cuda", "ENV"), + ((2, 4), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_distributed_single_to_key(setup_env): + """Test distributed convert_single_repr_to_window_batched_key vs single_to_keys.""" + + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + parallel_assert_single_to_key, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + torch.float64, # use FP64 to avoid dealing with numerical tolerance + ) diff --git a/tests/distributed/model/layers/test_redistribute_transpose.py b/tests/distributed/model/layers/test_redistribute_transpose.py new file mode 100755 index 000000000..4f7a5f56b --- /dev/null +++ b/tests/distributed/model/layers/test_redistribute_transpose.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import itertools +import random +from typing import Dict, Optional + +import pytest +import torch + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.redistribute_transpose_without_dtensor import redistribute_transpose +from boltz.testing.utils import assert_tensors_identical + + +def seed_by_rank(rank: int, seed: int = 42) -> None: + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def parallel_assert_redistribute_transpose( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + n_tokens_per_rank: int, + input_host: torch.Tensor, + output_expected_host: torch.Tensor, + d_output_expected_host: torch.Tensor, + d_input_host: torch.Tensor, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + assert input_host.device.type == "cpu" + + layout_map = manager.layout_subgroups["cp"] + transpose_comm = TransposeComm(manager.group["cp"], layout_map) + + rank_coords = layout_map.unravel(manager.group_rank["cp"]) + i_chunk_begins = [rank_coords[i] * n_tokens_per_rank for i in range(len(rank_coords))] + i_chunk_ends = [(rank_coords[i] + 1) * n_tokens_per_rank for i in range(len(rank_coords))] + + # Extract the chunk of the input tensor for this rank + # The .contiguous() is necessary because the equivalent op is single-device scatter + # of the input, which requires contiguity + input_chunk = input_host[ + :, i_chunk_begins[0] : i_chunk_ends[0], i_chunk_begins[1] : i_chunk_ends[1], : + ].contiguous() + input_chunk = input_chunk.to(device=manager.device) + input_chunk.requires_grad = True + + d_output_chunk = d_output_expected_host[ + :, i_chunk_begins[0] : i_chunk_ends[0], i_chunk_begins[1] : i_chunk_ends[1], : + ].contiguous() + d_output_chunk = d_output_chunk.to(device=manager.device) + + input_chunk_clone = input_chunk.clone() + # Perform distributed transpose operation + result = redistribute_transpose(input_chunk, 1, 2, transpose_comm) + + # must not modify the input tensor + assert_tensors_identical(input_chunk_clone, input_chunk, check_grad_fn=False) + + # Perform backward pass + d_output_chunk_clone = d_output_chunk.clone() + torch.autograd.backward([result], [d_output_chunk]) + # no modification to any input + assert_tensors_identical(d_output_chunk, d_output_chunk_clone) + # no modification to any input, except that input_chunk now have gradients + assert_tensors_identical(input_chunk, input_chunk_clone, check_grad_fn=False, check_grad=False) + + # Check forward pass output + # Extract the expected chunk for this rank + # The .contiguous() is necessary because the equivalent op is single-device transpose + # then scatter, where the scatter op requires contiguity + output_expected_chunk = output_expected_host[ + :, i_chunk_begins[0] : i_chunk_ends[0], i_chunk_begins[1] : i_chunk_ends[1], : + ].contiguous() + + # Move result to host for comparison + result_host = result.detach().to(device=output_expected_chunk.device) + assert_tensors_identical(result_host, output_expected_chunk, check_stride=False) + + # check backward pass + # The .contiguous() is necessary because the equivalent op is single-device transpose + # backward then scatter, where the scatter op requires contiguity + d_input_expected_chunk = d_input_host[ + :, i_chunk_begins[0] : i_chunk_ends[0], i_chunk_begins[1] : i_chunk_ends[1], : + ].contiguous() + assert_tensors_identical(input_chunk.grad.detach().cpu(), d_input_expected_chunk) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.skip_jet_nightly +@pytest.mark.parametrize( + "setup_env", + itertools.product( + [(1, (1, 1)), (2, (1, 1)), (1, (2, 2)), (2, (2, 2)), (1, (3, 3))], [True], ["cpu", "cuda"], ["ENV"] + ), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_redistribute_transpose_without_dtensor(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + + n_tokens_global = size_ring * size_ring + embed_dim = 8 + batch_size = 4 + dtype = torch.float32 + device_host = torch.device("cpu") + + assert ( + n_tokens_global % size_ring == 0 + ), f"n_tokens_global {n_tokens_global} is not a multiple of size_ring {size_ring}" + + seed_by_rank(0) # same seed for all ranks + + # Create input tensor of shape [B, N, M, D] + input = torch.randn( + (batch_size, n_tokens_global, n_tokens_global, embed_dim), + dtype=dtype, + device=device_type, + ) + input.requires_grad = True + + # For comparison, compute the single-device transpose of dimensions 1 and 2 + output_expected = redistribute_transpose(input, 1, 2) + + d_output_expected = torch.rand_like(output_expected) + torch.autograd.backward([output_expected], [d_output_expected]) + + input_host = input.detach().clone().to(device=device_host) + output_expected_host = output_expected.detach().clone().to(device=device_host) + + d_output_expected_host = d_output_expected.detach().clone().to(device=device_host) + d_input_host = input.grad.detach().clone().to(device=device_host) + + n_tokens_per_rank = n_tokens_global // size_ring + + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=parallel_assert_redistribute_transpose, + args=( + grid_group_sizes, + device_type, + backend, + n_tokens_per_rank, + input_host, + output_expected_host, + d_output_expected_host, + d_input_host, + env_per_rank, + ), + nprocs=world_size, + join=True, + ) diff --git a/tests/distributed/model/layers/test_window_batch_utils.py b/tests/distributed/model/layers/test_window_batch_utils.py new file mode 100644 index 000000000..977f2965d --- /dev/null +++ b/tests/distributed/model/layers/test_window_batch_utils.py @@ -0,0 +1,551 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Unit tests for window batching utility functions.""" + +import math + +import pytest +import torch + +from boltz.distributed.model.layers.utils import ( + gather_sliding_windows, + gather_sliding_windows_backward, + get_query_window_key_range, + pack_and_pad, + pack_and_pad_backward, +) +from boltz.model.modules.encoders import get_indexing_matrix +from boltz.testing.utils import assert_tensors_identical + + +def set_batch_diagonal(batch_matrix, k_values, fill_value): + """ + Sets the k-th diagonal for each matrix in a batch. + Space complexity: O(B * min(H,W)) + """ + B, H, W = batch_matrix.shape + device = batch_matrix.device + + # 1. The maximum possible length of any diagonal is min(H, W) + max_diag_len = min(H, W) + + # 2. Create a base sequence [0, 1, 2, ..., max_len-1] + # Shape: (1, max_diag_len) + seq = torch.arange(max_diag_len, device=device).unsqueeze(0) + + # 3. Calculate starting coordinates (r, c) for each k + # If k > 0: start at (0, k) + # If k < 0: start at (|k|, 0) + # Shape: (B, 1) + start_row = (-k_values).clamp(min=0).unsqueeze(1) + start_col = k_values.clamp(min=0).unsqueeze(1) + + # 4. Generate the full coordinate grids + # We broaden the starting points by adding the sequence + # Shape: (B, max_diag_len) + rows = start_row + seq + cols = start_col + seq + + # 5. Create a mask for valid coordinates + # Because diagonals shift, they might hit the boundary before max_diag_len + valid_mask = (rows < H) & (cols < W) + + # 6. Create Batch indices to match + # Shape: (B, max_diag_len) + batch_idx = torch.arange(B, device=device).unsqueeze(1).expand(-1, max_diag_len) + + # 7. Apply Advanced Indexing + # We select only the valid coordinates using the boolean mask. + # PyTorch handles the memory layout mapping here internally. + batch_matrix[batch_idx[valid_mask], rows[valid_mask], cols[valid_mask]] = fill_value + + return batch_matrix + + +@pytest.fixture(params=["cpu", "cuda"]) +def device(request): + if request.param == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + return torch.device(request.param) + + +@pytest.fixture( + params=[ + (8, 16, 1), + (8, 16, 2), + (8, 16, 5), + (8, 16, 10), + (8, 32, 2), + (8, 32, 5), + (8, 32, 20), + (16, 64, 3), + (16, 64, 10), + (16, 64, 50), + (32, 64, 5), + (32, 64, 20), + (32, 128, 1), + (32, 128, 2), + (32, 128, 3), + (32, 128, 10), + (32, 128, 50), + (32, 128, 100), + (32, 256, 5), + (32, 256, 20), + (64, 256, 10), + (64, 256, 50), + (128, 512, 20), + (128, 512, 100), + ], + ids=lambda x: f"W:{x[0]}, H:{x[1]}, K:{x[2]}", +) +def get_toeplitz(request, device): + """ + Fixture that creates onehot tensor and verifies Toeplitz property. + + Returns (W, H, K, h, batched_toeplitz) where batched_toeplitz has shape (K, h, 2*K). + + Verifies one-hot property: each (query_window, slot) has at most one non-zero. + """ + W, H, K = request.param + h = H // (W // 2) + + idx_mat = get_indexing_matrix(K, W, H, device) + batched_toeplitz = idx_mat.unflatten(-1, (K, h)).transpose(0, 1).transpose(-2, -1).to(dtype=torch.int32) + # Shape: (K, h, 2*K) + + assert ( + batched_toeplitz.shape == (K, h, 2 * K) + ), f"get_indexing_matrix({K}, {W}, {H}) produces incorrect shape (post transformation): {batched_toeplitz.shape} != (K, h, 2 * K)" + + # Verify the Toeplitz property + # All Toeplitz inside batched batched_toeplitz should have one non-zero diagonal of ones, + # with the first Toeplitz matrix has diagonal at offset[0] = 1 - h // 2 and subsequent + # Toeplitz matrices have diagonal at offset[i] = offset[i-1] + 2. + batched_toeplitz_expected = set_batch_diagonal( + torch.zeros((K, h, 2 * K), device=batched_toeplitz.device, dtype=batched_toeplitz.dtype), + torch.arange(1 - h // 2, 1 - h // 2 + (K - 1) * 2 + 1, 2, device=batched_toeplitz.device), + 1, + ) + + assert torch.all( + batched_toeplitz == batched_toeplitz_expected + ), f"get_indexing_matrix({K}, {W}, {H}) does not produce the expected Toeplitz matrix" + + return W, H, K, h, batched_toeplitz + + +def test_range_formula_all_windows(get_toeplitz): + """Test that range formula is correct for ALL query windows using batched call.""" + W, H, K, h, batched_toeplitz = get_toeplitz + device = batched_toeplitz.device + + # Get ranges for ALL query windows in one batched call + all_ids = torch.arange(K, device=device) + ranges = get_query_window_key_range(W, H, K, all_ids) + + # Verify shape + assert ranges.shape == (2, K) + + # Extract ground truth from onehot using batched sparse COO (no loops!) + onehot_sparse = batched_toeplitz.to_sparse_coo() + indices = onehot_sparse.indices() # Shape: (3, num_nonzeros) + # indices[0] = query window index (i) + # indices[1] = slot index + # indices[2] = half-window index (j) + + qw_idx = indices[0] + j_idx = indices[2] + + # Use scatter_reduce to compute min/max per query window (PyTorch 1.12+) + expected_j_min = torch.full((K,), 2 * K, dtype=torch.long, device=device) + expected_j_max = torch.full((K,), -1, dtype=torch.long, device=device) + + expected_j_min.scatter_reduce_(0, qw_idx, j_idx, reduce="amin", include_self=False) + expected_j_max.scatter_reduce_(0, qw_idx, j_idx, reduce="amax", include_self=False) + + # Filter to valid windows (those with at least one non-zero) + valid_mask = expected_j_max >= 0 + + assert torch.all(ranges[0, valid_mask] == expected_j_min[valid_mask]), f"W={W},H={H},K={K}: j_min mismatch" + assert torch.all(ranges[1, valid_mask] == expected_j_max[valid_mask]), f"W={W},H={H},K={K}: j_max mismatch" + + +@pytest.mark.parametrize("ndim,axis", [(2, 0), (3, 1), (3, -2), (4, 2), (5, 3), (5, -2)]) +def test_efficient_unfold_equivalence(get_toeplitz, ndim, axis): + """ + Test gather_sliding_windows matches einsum with Toeplitz matrix. + + Tests with inputs of varying dimensions and axis positions. + """ + W, H, K, h, batched_toeplitz = get_toeplitz + device = batched_toeplitz.device + + # Build shape with 2*K at the specified axis + shape_list = [2, 3, 4, 5, 6][:ndim] + + # Normalize axis + norm_axis = axis if axis >= 0 else ndim + axis + shape_list[norm_axis] = 2 * K + + input_shape = tuple(shape_list) + dense_input = torch.arange(math.prod(input_shape), device=device, dtype=torch.float32).reshape(input_shape) + dense_input.requires_grad_(True) + + # Method 1: Generic einsum by moving axis to front + # Move axis dimension to position 0 + input_permuted = dense_input.moveaxis(norm_axis, 0) # (2*K, ...) + + # Generic einsum: (K, h, 2*K) × (2*K, ...) → (K, h, ...) + result_einsum = torch.einsum("kij,j...->ki...", batched_toeplitz.float(), input_permuted) + + # Move K and h dimensions to where axis was + # result_einsum is (K, h, ...) - move to (...[:axis], K, h, ...[axis:]) + result_einsum = result_einsum.moveaxis([0, 1], [norm_axis, norm_axis + 1]) + + # Method 2: efficient_toeplitz_matmul_unfold + offset_start = 1 - h // 2 + offsets = torch.arange(offset_start, offset_start + 2 * (K - 1) + 1, 2, device=device) + + dense_input_clone = dense_input.detach().clone().requires_grad_(True) + result_unfold = gather_sliding_windows(dense_input_clone, offsets, h, axis) + + # Verify equivalence + # Forward is exact copy vs einsum (multiply by 1.0) + # Should be close, maybe not bitwise identical on GPU + torch.testing.assert_close(result_einsum, result_unfold) + + # Verify backward pass + grad_output = torch.arange(result_einsum.numel(), device=device, dtype=result_einsum.dtype).reshape_as( + result_einsum + ) + + result_einsum.backward(grad_output.detach().clone()) + result_unfold.backward(grad_output.detach().clone()) + + torch.testing.assert_close(dense_input.grad, dense_input_clone.grad) + + +@pytest.mark.parametrize( + "W,H,K,qw_start,qw_end", + [ + # Test first windows (includes QW0 with negative offset) + (32, 128, 10, 0, 3), + (32, 128, 20, 0, 5), + # Test last windows (includes final QW with boundary) + (32, 128, 10, 7, 10), + (32, 128, 20, 15, 20), + # Test middle windows (interior, no boundaries) + (32, 128, 20, 5, 15), + (32, 128, 50, 20, 30), + # Test single window + (32, 128, 10, 3, 4), + # Different h values + (32, 64, 10, 2, 6), # h=4 + (32, 256, 10, 2, 6), # h=16 + # Larger K + (32, 128, 100, 40, 60), + ], +) +def test_translational_symmetry(W, H, K, qw_start, qw_end, device): + """ + Test Theorem 6: Translational symmetry of Toeplitz multiplication. + + Verifies: T(x[δ:δ+n], offsets - δ) == T(x, offsets)[slice] + + For a subset of query windows computed on a translated input slice, + the result equals slicing the full computation. + """ + h = H // (W // 2) + + # Full computation + torch.manual_seed(42) + input_full = torch.randn(2 * K, 16, device=device, requires_grad=True) + + offset_start = 1 - h // 2 + offsets_full = torch.arange(offset_start, offset_start + 2 * K, 2, device=device) + result_full = gather_sliding_windows(input_full, offsets_full, h, axis=0) + + # Determine input span needed for subset + subset_qw_ids = torch.arange(qw_start, qw_end, device=device) + ranges = get_query_window_key_range(W, H, K, subset_qw_ids) + hw_need_start = ranges[0].min().item() + hw_need_end = ranges[1].max().item() + 1 + + # Extract input slice and translate offsets + input_slice = input_full[hw_need_start:hw_need_end].detach().clone().requires_grad_(True) + offsets_subset = offsets_full[qw_start:qw_end] - hw_need_start + + # Compute on translated slice + result_subset = gather_sliding_windows(input_slice, offsets_subset, h, axis=0) + + # Verify: T(x[δ:], offsets-δ) == T(x, offsets)[slice] + expected_subset = result_full[qw_start:qw_end] + assert torch.all(result_subset == expected_subset), ( + f"W={W},H={H},K={K},QW[{qw_start},{qw_end}): Forward mismatch: \n" + f" {result_subset} \n vs. \n" + f" {expected_subset}" + ) + + # Verify backward pass + grad_output = torch.randn_like(result_subset) + + grad_full = torch.zeros_like(result_full) + grad_full[qw_start:qw_end] = grad_output + result_full.backward(grad_full) + + result_subset.backward(grad_output.clone()) + + expected_grad_slice = input_full.grad[hw_need_start:hw_need_end] + # Backward involves gradient accumulation which can have small numerical differences on GPU + torch.testing.assert_close(input_slice.grad, expected_grad_slice) + + +def test_backward_validation_errors(device): + """Test that gather_sliding_windows_backward raises appropriate errors for invalid inputs.""" + + # Valid baseline + grad_output = torch.randn(5, 8, 16, device=device) # (n_windows=5, window_size=8, features=16) + window_start_offsets = torch.tensor([-3, -1, 1, 3, 5], device=device) + window_size = 8 + axis = 0 + input_shape = (12, 16) # (2*K=12, features=16) + + # Test 1: grad_output not a tensor + with pytest.raises(TypeError, match="grad_output must be a torch.Tensor"): + gather_sliding_windows_backward([1, 2, 3], window_start_offsets, window_size, axis, input_shape) + + # Test 2: window_start_offsets not a tensor + with pytest.raises(TypeError, match="window_start_offsets must be a torch.Tensor"): + gather_sliding_windows_backward(grad_output, [1, 2, 3], window_size, axis, input_shape) + + # Test 3: window_start_offsets not 1D + with pytest.raises(ValueError, match="window_start_offsets must be 1D"): + bad_offsets = torch.tensor([[1, 2], [3, 4]], device=device) + gather_sliding_windows_backward(grad_output, bad_offsets, window_size, axis, input_shape) + + # Test 4: axis out of range + with pytest.raises(ValueError, match="axis .* out of range"): + gather_sliding_windows_backward(grad_output, window_start_offsets, window_size, axis=5, input_shape=input_shape) + + # Test 5: grad_output shape mismatch (wrong n_windows dimension) + with pytest.raises(ValueError, match="grad_output shape mismatch"): + bad_grad = torch.randn(7, 8, 16, device=device) # Wrong n_windows (7 instead of 5) + gather_sliding_windows_backward(bad_grad, window_start_offsets, window_size, axis, input_shape) + + # Test 6: grad_output shape mismatch (wrong window_size dimension) + with pytest.raises(ValueError, match="grad_output shape mismatch"): + bad_grad = torch.randn(5, 10, 16, device=device) # Wrong window_size (10 instead of 8) + gather_sliding_windows_backward(bad_grad, window_start_offsets, window_size, axis, input_shape) + + # Test 7: grad_output shape mismatch (wrong feature dimension) + with pytest.raises(ValueError, match="grad_output shape mismatch"): + bad_grad = torch.randn(5, 8, 32, device=device) # Wrong features (32 instead of 16) + gather_sliding_windows_backward(bad_grad, window_start_offsets, window_size, axis, input_shape) + + # Test 8: grad_output wrong ndim + with pytest.raises(ValueError, match="grad_output shape mismatch"): + bad_grad = torch.randn(5, 8, device=device) # Missing feature dimension + gather_sliding_windows_backward(bad_grad, window_start_offsets, window_size, axis, input_shape) + + +@pytest.mark.parametrize( + "input_shape_extra", [(None,), (4, None), (None, 3), (2, None, 3)], ids=lambda x: f"input_shape_extra:{x}" +) +@pytest.mark.parametrize("keep_input_padding", [False, True], ids=lambda x: f"keep_input_padding:{x}") +def test_pack_and_pad_equivalence(get_toeplitz, input_shape_extra, keep_input_padding): + """ + Test pack_and_pad utility. + + `None` in fixture `input_shape_extra` is eventually replaced with K * W. + For example, `input_shape_extra (2, None, 3)` will be replaced with `(2, K * W, 3)` in the test. + """ + W, H, K, h, batched_toeplitz = get_toeplitz + device = batched_toeplitz.device + + # Setup parameters + n_axes_none = sum(1 for x in input_shape_extra if x is None) + if n_axes_none != 1: + raise ValueError(f"There can be one and only one 'None' element in the input_shape but got {input_shape_extra}") + axis = input_shape_extra.index(None) + + input_shape = input_shape_extra[:axis] + (K * W,) + input_shape_extra[axis + 1 :] + + # 1. Generate clean input of shape (K*W, features) + # This represents the "perfectly padded" sequence + torch.manual_seed(42) + input = torch.randn(input_shape, device=device, requires_grad=True) + input_copy = input.detach().clone().requires_grad_(True) + + mask = torch.randint(0, 2, input_shape, dtype=torch.bool, device=device, requires_grad=False) + mask_copy = mask.detach().clone().requires_grad_(False) + + mask_sorted, argsort_mask = torch.sort(mask, dim=axis, descending=True, stable=True) + input_sorted = torch.gather(input, axis, argsort_mask) + + # --- Reference Computation (on clean input) --- + # Reshape to (2*K, W//2, features) for einsum + # Note: K*W elements -> reshaped to (2*K, W/2) + # 2*K * W/2 = K*W. Correct. + + # input_clean: (K*W, F) -> (2*K, W//2, F) + # We need to ensure the layout matches what the reshape produces. + # The utility reshapes (K*W) -> (2*K, W//2) along axis. + # This splits the sequence into chunks of size W/2. + input_sorted_reshaped = input_sorted.moveaxis(axis, -1).unflatten(-1, (2 * K, W // 2)) + + # Einsum: (K, h, 2*K) x (..., 2*K, W//2) -> (..., K, h, W//2) + # Sum over 2*K dimension + # batched_toeplitz: (K, h, 2*K) + ref_output_reshaped = torch.einsum("khj,...jb->...khb", batched_toeplitz.float(), input_sorted_reshaped) + # (..., K, h, W//2) -> (..., K, h, W//2, ...) + ref_output = ref_output_reshaped.moveaxis([-3, -2, -1], [axis, axis + 1, axis + 2]) + with torch.no_grad(): + mask_prepared_ref = mask_sorted.unflatten(axis, (2 * K, W // 2)) + mask_output_ref = ( + torch.einsum( + "khj,...jb->...khb", + batched_toeplitz.to(dtype=torch.float32), + # (..., 2 * K, W // 2, ...) -> (..., 2*K, W//2) + mask_prepared_ref.to(dtype=torch.float32).moveaxis([axis, axis + 1], [-2, -1]), + ).moveaxis([-3, -2, -1], [axis, axis + 1, axis + 2]) + ).to(dtype=mask_prepared_ref.dtype) + + # Reshape reference result to match gather_sliding_windows output structure + # Reference einsum output: (..., K, h, W//2, ...) + # gather_sliding_windows output: (..., n_windows, window_size, ...) + # For axis=0 input (2*K, W//2, F), it returns (K, h, W//2, F) + # Structure matches exactly. + + # --- Target Computation --- + # 1. pack the valid elements and pad to the next multiple of W + input_prepared, _, mask_prepared = pack_and_pad( + input_copy, mask_copy, axis, W, keep_input_padding=keep_input_padding + ) + + # check the mask + assert mask_prepared.shape == input_prepared.shape + n_valid = mask_copy.expand_as(input_copy).sum(dim=axis) + # there must be leading n_valid elements and trailing zeros along axis + # NOTE: cumprod zeros out non-leading True elements + assert_tensors_identical(mask_prepared.cumprod(dim=axis).sum(dim=axis), n_valid) + + if keep_input_padding: + mask_prepared_padded = mask_prepared + input_prepared_padded = input_prepared + else: + # when keep_input_padding is False, the mask_prepared will be shorter than the reference + # along 'axis'. We pad them before along 'axis' with zeros towards the reference length + pad_len = mask_prepared_ref.flatten(axis, axis + 1).shape[axis] - mask_prepared.shape[axis] + assert pad_len >= 0, "Padding length should be non-negative for the result when keep_input_padding is False" + pad_arg = [0] * (2 * mask_prepared.ndim) + pad_idx = (mask_prepared.ndim - 1 - axis) * 2 + 1 + pad_arg[pad_idx] = pad_len + input_prepared_padded = torch.nn.functional.pad(input_prepared, pad_arg) + mask_prepared_padded = torch.nn.functional.pad(mask_prepared, pad_arg) + + # reshape to (..., 2*K, W//2, ...) + input_prepared_padded = input_prepared_padded.unflatten(axis, (2 * K, W // 2)) + mask_prepared_padded = mask_prepared_padded.unflatten(axis, (2 * K, W // 2)) + + assert_tensors_identical(mask_prepared_padded, mask_prepared_ref) + + # 2. Gather Sliding Windows + # We need offsets for K windows. + # offset_start = 1 - h // 2 + offset_start = 1 - h // 2 + window_start_offsets = torch.arange(offset_start, offset_start + 2 * K, 2, device=device) + + target_output = gather_sliding_windows(input_prepared_padded, window_start_offsets, h, axis) + + # --- Verification --- + + # 1. Verify forward output + # target_output: (..., K, h, W//2, ...) + # Should be binary identical (no fp math involved, just gather/move) + torch.testing.assert_close(target_output, ref_output * mask_output_ref, atol=0, rtol=0) + + # 2. Verify backward (gradients) + # We'll compute gradients w.r.t. input_clean vs input_masked + + grad_out = torch.randn_like(ref_output) + + # NOTE: the reference computation doesn't use mask at all so we need to mask out the + # invalid upstream adjoints + with torch.no_grad(): + grad_out = grad_out * mask_output_ref + + # Reference backward + # 1. Backprop through Toeplitz path + torch.autograd.backward([ref_output], [grad_out]) + + # Target backward + torch.autograd.backward([target_output], [grad_out]) + + # Backward involves gradient accumulation (summation) which is not bitwise identical + # between einsum (reference) and index_add_ (target) due to floating point associativity. + torch.testing.assert_close(input_copy.grad, input.grad) + + # Verify padded/invalid elements have zero gradient + assert_tensors_identical(torch.zeros_like(input_copy.grad), input_copy.grad * ~mask_copy) + + +def test_pack_and_pad_backward_manual(get_toeplitz): + """Test manual backward pass for pack_and_pad.""" + W, H, K, h, batched_toeplitz = get_toeplitz + device = batched_toeplitz.device + axis = 0 + features = 16 + + # Setup inputs + torch.manual_seed(42) + input_clean = torch.randn(K * W, features, device=device, requires_grad=True) + + num_invalid = 5 + total_len = K * W + num_invalid + mask_indices = torch.randperm(total_len, device=device) + valid_indices = mask_indices[: K * W] + valid_indices_sorted, _ = torch.sort(valid_indices) + + input_masked = torch.zeros(total_len, features, device=device) + input_masked[valid_indices_sorted] = input_clean.detach() + input_masked.requires_grad_(True) + + mask = torch.zeros(total_len, dtype=torch.bool, device=device) + mask[valid_indices_sorted] = True + mask_reshaped = mask.reshape(total_len, 1) + + # Forward + output, indices, mask_output = pack_and_pad(input_masked, mask_reshaped, axis, W) + + # Compute grad using autograd + grad_out = torch.randn_like(output) + + torch.autograd.backward([output], [grad_out]) + + grad_autograd = input_masked.grad.clone() + input_masked.grad.zero_() + + # Compute grad using manual backward + grad_manual = pack_and_pad_backward(grad_out, mask_output, indices, input_masked.shape, axis) + + # Compare + torch.testing.assert_close(grad_manual, grad_autograd) diff --git a/tests/distributed/model/layers/test_window_ownership.py b/tests/distributed/model/layers/test_window_ownership.py new file mode 100644 index 000000000..2a1858cb3 --- /dev/null +++ b/tests/distributed/model/layers/test_window_ownership.py @@ -0,0 +1,361 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for query window ownership computation in distributed windowed attention.""" + +import pytest + +from boltz.distributed.model.layers.utils import ( + compute_query_window_ownership, + get_halo_from_neighbors, +) + + +def compute_rank_ranges(K: int, n_ranks: int, rank: int) -> tuple[tuple[int, int], tuple[int, int]]: + """ + Test helper to compute rank ranges for simulating DTensor sharding. + + IMPORTANT: Half-window ownership is derived from query window ownership, + not computed independently. This ensures hw_start = 2*qw_start always holds. + """ + # Query windows (contiguous assignment) + qw_per_rank = K // n_ranks + qw_remainder = K % n_ranks + + if rank < qw_remainder: + qw_start = rank * (qw_per_rank + 1) + qw_end = qw_start + qw_per_rank + 1 + else: + qw_start = rank * qw_per_rank + qw_remainder + qw_end = qw_start + qw_per_rank + + # Half-windows: DERIVED from query windows (QW i owns HW [2i, 2i+1]) + hw_start = 2 * qw_start + hw_end = 2 * qw_end + + return (qw_start, qw_end), (hw_start, hw_end) + + +@pytest.mark.parametrize("W,H", params := [(32, 128), (16, 64), (32, 256)], ids=[f"W:{x[0]}-H:{x[1]}" for x in params]) +@pytest.mark.parametrize("K_per_rank", [1, 2, 10, 12, 32, 100], ids=lambda x: f"K_per_rank:{x}") +@pytest.mark.parametrize("n_ranks", [2, 3, 4, 8], ids=lambda x: f"n_ranks:{x}") +def test_ownership_coverage_and_validity(W, H, K_per_rank, n_ranks): + """ + Test that ownership mapping covers all query windows exactly once and is valid. + + NOTE: this test doesn't cover the case K % n_ranks != 0, i.e., + the number of query windows is not a multiple of the number of ranks, which is + not supported by the DistributedGatherSlidingWindows currently. + + Verifies: + 1. All query windows [0, K) are owned by exactly one rank + 2. No gaps or overlaps in ownership + 3. All ranges are valid (start <= end) + 4. Halos are non-negative + """ + K = K_per_rank * n_ranks + all_owned_qws = [] + + # Assuming uniform distribution as per DistributedGatherSlidingWindows requirement K % n_ranks == 0 + # The compute_rank_ranges handles non-uniform but we restrict here to uniform to match utils.py logic + # Actually utils.py requires K % n_ranks == 0. + local_hw_len = (2 * K) // n_ranks + + for r in range(n_ranks): + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, r) + ownership = compute_query_window_ownership(W, H, K, qw_start, qw_end) + + hw_need_start, hw_need_end = ownership["hw_needed"] + left_halo = ownership["left_halo_size"] + right_halo = ownership["right_halo_size"] + + # Valid ranges + assert qw_start <= qw_end, f"Rank {r}: invalid qw_range" + assert hw_start <= hw_end, f"Rank {r}: invalid hw_owned" + assert hw_need_start <= hw_need_end, f"Rank {r}: invalid hw_needed" + + # Non-negative halos + assert left_halo >= 0, f"Rank {r}: negative left_halo" + assert right_halo >= 0, f"Rank {r}: negative right_halo" + + # Verify sufficiency of neighbors using get_halo_from_neighbors + recv_meta, _ = get_halo_from_neighbors(r, n_ranks, local_hw_len, W, H, K) + + received_left = sum(length for _, htype, _, length in recv_meta if htype == "left") + received_right = sum(length for _, htype, _, length in recv_meta if htype == "right") + + assert ( + received_left == left_halo + ), f"Rank {r}: insufficient left halo from neighbors. Need {left_halo}, received {received_left}" + assert ( + received_right == right_halo + ), f"Rank {r}: insufficient right halo from neighbors. Need {right_halo}, received {received_right}" + + # Collect owned windows + if qw_start < qw_end: + all_owned_qws.extend(range(qw_start, qw_end)) + + # Verify complete coverage + assert sorted(all_owned_qws) == list(range(K)), f"W={W},H={H},K={K},n_ranks={n_ranks}: Coverage mismatch" + + +@pytest.mark.parametrize("W,H", params := [(32, 128), (16, 64), (32, 256)], ids=[f"W:{x[0]}-H:{x[1]}" for x in params]) +@pytest.mark.parametrize("K_per_rank", [1, 2, 10, 12, 32, 100], ids=lambda x: f"K_per_rank:{x}") +@pytest.mark.parametrize("n_ranks", [2, 3, 4, 8], ids=lambda x: f"n_ranks:{x}") +def test_get_halo_from_neighbors_symmetry(W, H, K_per_rank, n_ranks): + """ + Test that get_halo_from_neighbors produces symmetric recv/send patterns. + If Rank A expects to receive from Rank B, Rank B must expect to send to Rank A. + """ + K = K_per_rank * n_ranks + local_hw_len = (2 * K) // n_ranks + + for r in range(n_ranks): + recv_meta, send_meta = get_halo_from_neighbors(r, n_ranks, local_hw_len, W, H, K) + + # Check each receive item + for peer, htype, offset_in_halo, length in recv_meta: + # Check peer's send list + peer_recv, peer_send = get_halo_from_neighbors(peer, n_ranks, local_hw_len, W, H, K) + + # Look for a matching send: to me (r), same length + # Note: offset logic is trickier to verify cross-rank without full reconstruction, + # but length and existence are critical. + matches = [s for s in peer_send if s[0] == r and s[2] == length] + assert matches, ( + f"Rank {r} expects to recv {length} items (type {htype}) from {peer}, " + f"but {peer} has no matching send to {r}. Peer sends: {peer_send}" + ) + + # Check each send item + for peer, offset_in_local, length in send_meta: + # Check peer's recv list + peer_recv, peer_send = get_halo_from_neighbors(peer, n_ranks, local_hw_len, W, H, K) + + matches = [r_item for r_item in peer_recv if r_item[0] == r and r_item[3] == length] + assert matches, ( + f"Rank {r} expects to send {length} items to {peer}, " + f"but {peer} has no matching recv from {r}. Peer recvs: {peer_recv}" + ) + + +@pytest.mark.parametrize("W,H", params := [(32, 128), (16, 64), (32, 256)], ids=[f"W:{x[0]}-H:{x[1]}" for x in params]) +@pytest.mark.parametrize("K_per_rank", [1, 2, 10, 12, 32, 100], ids=lambda x: f"K_per_rank:{x}") +@pytest.mark.parametrize("n_ranks", [2, 3, 4, 8], ids=lambda x: f"n_ranks:{x}") +def test_halo_sufficiency(W, H, K_per_rank, n_ranks): + """Test that halos are sufficient to cover all needed half-windows.""" + K = K_per_rank * n_ranks + for r in range(n_ranks): + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, r) + ownership = compute_query_window_ownership(W, H, K, qw_start, qw_end) + + hw_need_start, hw_need_end = ownership["hw_needed"] + left_halo = ownership["left_halo_size"] + right_halo = ownership["right_halo_size"] + + # With halos, we should cover all needed half-windows + available_start = hw_start - left_halo + available_end = hw_end + right_halo + + assert available_start <= hw_need_start, f"Rank {r}: left halo insufficient" + assert available_end >= hw_need_end, f"Rank {r}: right halo insufficient" + + +@pytest.mark.parametrize( + "W,H,K,n_ranks", + [ + (32, 128, 10, 2), + (32, 128, 12, 3), + (32, 128, 32, 4), + ], +) +def test_boundary_ranks_one_sided_halos(W, H, K, n_ranks): + """Test that first and last ranks have one-sided halos.""" + # First rank + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, 0) + ownership_0 = compute_query_window_ownership(W, H, K, qw_start, qw_end) + assert ownership_0["left_halo_size"] == 0, "First rank should have no left halo" + + # Last rank + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, n_ranks - 1) + ownership_last = compute_query_window_ownership(W, H, K, qw_start, qw_end) + assert ownership_last["right_halo_size"] == 0, "Last rank should have no right halo" + + +def test_ownership_example_k12_n3(): + """Concrete example test for K=12, n_ranks=3.""" + W, H, K, n_ranks = 32, 128, 12, 3 + + # Rank 0 + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, 0) + assert (qw_start, qw_end) == (0, 4) + assert (hw_start, hw_end) == (0, 8) + own0 = compute_query_window_ownership(W, H, K, qw_start, qw_end) + assert own0["left_halo_size"] == 0 # First rank + + # Rank 1 + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, 1) + assert (qw_start, qw_end) == (4, 8) + assert (hw_start, hw_end) == (8, 16) + + # Rank 2 + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, 2) + assert (qw_start, qw_end) == (8, 12) + assert (hw_start, hw_end) == (16, 24) + own2 = compute_query_window_ownership(W, H, K, qw_start, qw_end) + assert own2["right_halo_size"] == 0 # Last rank + + +def test_n_ranks_greater_than_K(): + """Test edge case where we have more ranks than query windows.""" + W, H, K, n_ranks = 32, 128, 5, 10 + + ranks_with_windows = 0 + ranks_without_windows = 0 + + for r in range(n_ranks): + (qw_start, qw_end), (hw_start, hw_end) = compute_rank_ranges(K, n_ranks, r) + ownership = compute_query_window_ownership(W, H, K, qw_start, qw_end) + + if qw_start < qw_end: + ranks_with_windows += 1 + else: + ranks_without_windows += 1 + # Ranks without windows should have zero halos + assert ownership["left_halo_size"] == 0 + assert ownership["right_halo_size"] == 0 + + assert ranks_with_windows == K, "Should have K ranks with windows" + assert ranks_without_windows == n_ranks - K + + +def get_halo_from_neighbors_iterative( + rank: int, + size_group: int, + n_half_windows_local: int, + W: int, + H: int, + K: int, +) -> tuple[list, list]: + """ + Reference iterative implementation of get_halo_from_neighbors for testing purposes. + Copied from previous implementation. + """ + # 2. Compute ownership for ALL ranks + all_ownerships = [] + for r in range(size_group): + r_hw_start = r * n_half_windows_local + r_hw_end = (r + 1) * n_half_windows_local + r_qw_start = r_hw_start // 2 + r_qw_end = r_hw_end // 2 + if r_qw_start < r_qw_end: + all_ownerships.append(compute_query_window_ownership(W, H, K, r_qw_start, r_qw_end)) + else: + all_ownerships.append(None) # Handle empty ranks if K < size_group + + my_own = all_ownerships[rank] + # Handle case where rank has no windows + if my_own is None: + hw_start, hw_end = rank * n_half_windows_local, (rank + 1) * n_half_windows_local + hw_need_start, hw_need_end = hw_start, hw_start + else: + hw_start, hw_end = my_own["hw_owned"] + hw_need_start, hw_need_end = my_own["hw_needed"] + + recv_meta = [] # (peer, halo_type, offset_in_halo, length) + send_meta = [] # (peer, offset_in_local, length) + + for peer in range(size_group): + if peer == rank: + continue + + peer_own = all_ownerships[peer] + peer_hw_start = peer * n_half_windows_local + peer_hw_end = (peer + 1) * n_half_windows_local + + # --- RECV Logic (What I need from peer) --- + # overlap between two intervals: + # I need: [hw_need_start, hw_start) and + # Peer owns: [peer_hw_start, peer_hw_end) + l_start = max(hw_need_start, peer_hw_start) + l_end = min(hw_start, peer_hw_end) + if l_start < l_end: + recv_meta.append((peer, "left", l_start - hw_need_start, l_end - l_start)) + + # overlap between two intervals: + # I need: [hw_end, hw_need_end) and + # Peer owns: [peer_hw_start, peer_hw_end) + r_start = max(hw_end, peer_hw_start) + r_end = min(hw_need_end, peer_hw_end) + if r_start < r_end: + recv_meta.append((peer, "right", r_start - hw_end, r_end - r_start)) + + # --- SEND Logic (What peer needs from me) --- + if peer_own is None: + continue + p_need_start, p_need_end = peer_own["hw_needed"] + p_hw_start, p_hw_end = peer_own["hw_owned"] + + # overlap between two intervals: + # Peer needs: [p_need_start, p_hw_start) and + # I own: [hw_start, hw_end) + l_start = max(p_need_start, hw_start) + l_end = min(p_hw_start, hw_end) + if l_start < l_end: + send_meta.append((peer, l_start - hw_start, l_end - l_start)) + + # overlap between two intervals: + # Peer needs: [p_hw_end, p_need_end) and + # I own: [hw_start, hw_end) + r_start = max(p_hw_end, hw_start) + r_end = min(p_need_end, hw_end) + if r_start < r_end: + send_meta.append((peer, r_start - hw_start, r_end - r_start)) + + return recv_meta, send_meta + + +@pytest.mark.parametrize("W,H", params := [(32, 128), (16, 64), (32, 256)], ids=[f"W:{x[0]}-H:{x[1]}" for x in params]) +@pytest.mark.parametrize("K_per_rank", [1, 2, 10, 12, 32, 100], ids=lambda x: f"K_per_rank:{x}") +@pytest.mark.parametrize("n_ranks", [2, 3, 4, 8], ids=lambda x: f"n_ranks:{x}") +def test_get_halo_vectorized_vs_iterative(W, H, K_per_rank, n_ranks): + """ + Compare new vectorized get_halo_from_neighbors against the old iterative implementation. + """ + K = K_per_rank * n_ranks + + local_hw_len = (2 * K) // n_ranks + + for r in range(n_ranks): + recv_vec, send_vec = get_halo_from_neighbors(r, n_ranks, local_hw_len, W, H, K) + recv_iter, send_iter = get_halo_from_neighbors_iterative(r, n_ranks, local_hw_len, W, H, K) + + # Sort to ensure order consistency for comparison + # recv tuple: (peer, type, offset, length) + recv_vec.sort() + recv_iter.sort() + + # send tuple: (peer, offset, length) + send_vec.sort() + send_iter.sort() + + assert recv_vec == recv_iter, f"Rank {r}: Recv mismatch.\nVectorized: {recv_vec}\nIterative: {recv_iter}" + assert send_vec == send_iter, f"Rank {r}: Send mismatch.\nVectorized: {send_vec}\nIterative: {send_iter}" diff --git a/tests/distributed/model/loss/__init__.py b/tests/distributed/model/loss/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/tests/distributed/model/loss/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/tests/distributed/model/loss/benchmark_smooth_lddt_loss_triton.py b/tests/distributed/model/loss/benchmark_smooth_lddt_loss_triton.py new file mode 100644 index 000000000..18d509b57 --- /dev/null +++ b/tests/distributed/model/loss/benchmark_smooth_lddt_loss_triton.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import argparse +import time +from contextlib import contextmanager + +import torch + +from boltz.distributed.model.loss.diffusion import ( + _smooth_lddt_loss_backward_local, + _smooth_lddt_loss_forward_local, + _smooth_lddt_loss_local_triton_backward, + _smooth_lddt_loss_local_triton_forward, +) +from boltz.distributed.model.modules.utils import Precision, setup_tf32_env + +try: + import triton # noqa: F401 + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +@contextmanager +def benchmark_peak_memory_and_runtime(num_warmup=5, num_iter=10): + """Benchmark with proper warmup and averaging. + + Args: + num_warmup: Number of warmup iterations (for kernel compilation/autotuning) + num_iter: Number of timed iterations for averaging + """ + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start_mem = torch.cuda.memory_allocated() + + stats = {"warmup_iters": num_warmup, "timed_iters": num_iter} + yield stats + + torch.cuda.synchronize() + peak_mem = torch.cuda.max_memory_allocated() + + # We want the peak memory *induced* by the function, relative to start. + peak_usage_mb = (peak_mem - start_mem) / 1024 / 1024 + stats["peak_mem"] = peak_usage_mb + + +def run_benchmark(B_local, N_atoms, multiplicity=1, dtype_str="fp32"): + match dtype_str: + case "fp32": + dtype = torch.float32 + precision = Precision.FP32 + case "bf16": + dtype = torch.bfloat16 + precision = Precision.BF16 + case "tf32": + dtype = torch.float32 + precision = Precision.TF32 + case _: + raise ValueError(f"Unsupported dtype: {dtype_str}") + + print(f"\nBenchmark Config: B_local={B_local}, N_atoms={N_atoms}, Multiplicity={multiplicity}, Dtype={dtype_str}") + device = torch.device("cuda") + + # Calculate grid sizes for diagnostic + effective_batch = B_local * multiplicity + fwd_block = 64 # Default from kernel + bwd_block = 32 # Default from kernel + fwd_grid_size = (effective_batch, (N_atoms + fwd_block - 1) // fwd_block, (N_atoms + fwd_block - 1) // fwd_block) + bwd_grid_size = (effective_batch, (N_atoms + bwd_block - 1) // bwd_block, (N_atoms + bwd_block - 1) // bwd_block) + fwd_total_programs = fwd_grid_size[0] * fwd_grid_size[1] * fwd_grid_size[2] + bwd_total_programs = bwd_grid_size[0] * bwd_grid_size[1] * bwd_grid_size[2] + + print(f"Effective batch size: {effective_batch}") + print(f"Forward grid (BLOCK={fwd_block}): {fwd_grid_size} = {fwd_total_programs:,} programs") + print(f"Backward grid (BLOCK={bwd_block}): {bwd_grid_size} = {bwd_total_programs:,} programs") + + if fwd_total_programs > 100000: + print("WARNING: Very large grid size may cause atomic contention!") + + print("NOTE: First Triton run includes kernel compilation + autotuning overhead") + + # Setup inputs + pred_coords_local = torch.randn(B_local * multiplicity, N_atoms, 3, device=device, dtype=dtype) + true_coords_local = torch.randn(B_local * multiplicity, N_atoms, 3, device=device, dtype=dtype) + pred_coords_t_local = torch.randn(B_local * multiplicity, N_atoms, 3, device=device, dtype=dtype) + true_coords_t_local = torch.randn(B_local * multiplicity, N_atoms, 3, device=device, dtype=dtype) + + # Expand inside function, but inputs here match function signature expected input + # Function expects (B, M) for these, and expands them inside. + is_nucleotide_local = torch.randint(0, 2, (B_local, N_atoms), device=device, dtype=torch.bool) + coords_mask_local = torch.randint(0, 2, (B_local, N_atoms), device=device, dtype=dtype) + coords_mask_t_local = torch.randint(0, 2, (B_local, N_atoms), device=device, dtype=dtype) + + nucleic_acid_cutoff = 5.0 + other_cutoff = 3.0 + is_self_comm = False # Simplifies mask logic + + print("-" * 60) + print("FORWARD PASS") + + with setup_tf32_env(precision): + # PyTorch Forward + def run_pytorch_fwd(): + return _smooth_lddt_loss_forward_local( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + with benchmark_peak_memory_and_runtime(num_warmup=3, num_iter=10) as stats: + # Warmup + for _ in range(stats["warmup_iters"]): + run_pytorch_fwd() + torch.cuda.synchronize() + + # Timed runs + start_time = time.time() + for _ in range(stats["timed_iters"]): + run_pytorch_fwd() + torch.cuda.synchronize() + end_time = time.time() + + stats["time"] = ((end_time - start_time) / stats["timed_iters"]) * 1000 + + print( + f"PyTorch Forward: Peak Memory = {stats['peak_mem']:.2f} MB, Time = {stats['time']:.2f} ms " + f"(avg of {stats['timed_iters']} runs)" + ) + + if HAS_TRITON: + # Triton Forward + def run_triton_fwd(): + return _smooth_lddt_loss_local_triton_forward( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + print(" (First run may be slow due to kernel compilation and autotuning...)") + with benchmark_peak_memory_and_runtime(num_warmup=10, num_iter=10) as stats: + # Warmup (includes autotuning on first run) + for i in range(stats["warmup_iters"]): + if i == 0: + print(" Warming up (compiling/autotuning)...", end="", flush=True) + run_triton_fwd() + if i == 0: + print(" done") + torch.cuda.synchronize() + + # Timed runs + start_time = time.time() + for _ in range(stats["timed_iters"]): + run_triton_fwd() + torch.cuda.synchronize() + end_time = time.time() + + stats["time"] = ((end_time - start_time) / stats["timed_iters"]) * 1000 + + print( + f"Triton Forward : Peak Memory = {stats['peak_mem']:.2f} MB, Time = {stats['time']:.2f} ms " + f"(avg of {stats['timed_iters']} runs)" + ) + else: + print("Triton not available, skipping Triton Forward") + + print("-" * 60) + print("BACKWARD PASS") + + # Inputs for backward + grad_num_reduced = torch.randn(B_local * multiplicity, device=device, dtype=dtype) + grad_den_reduced = torch.randn(B_local * multiplicity, device=device, dtype=dtype) + + # PyTorch Backward + def run_pytorch_bwd(): + return _smooth_lddt_loss_backward_local( + grad_num_reduced, + grad_den_reduced, + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + with benchmark_peak_memory_and_runtime(num_warmup=3, num_iter=10) as stats: + # Warmup + for _ in range(stats["warmup_iters"]): + run_pytorch_bwd() + torch.cuda.synchronize() + + # Timed runs + start_time = time.time() + for _ in range(stats["timed_iters"]): + run_pytorch_bwd() + torch.cuda.synchronize() + end_time = time.time() + + stats["time"] = ((end_time - start_time) / stats["timed_iters"]) * 1000 + + print( + f"PyTorch Backward: Peak Memory = {stats['peak_mem']:.2f} MB, Time = {stats['time']:.2f} ms " + f"(avg of {stats['timed_iters']} runs)" + ) + + if HAS_TRITON: + # Triton Backward + def run_triton_bwd(): + return _smooth_lddt_loss_local_triton_backward( + grad_num_reduced, + grad_den_reduced, + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + print(" (First run may be slow due to kernel compilation and autotuning...)") + with benchmark_peak_memory_and_runtime(num_warmup=10, num_iter=10) as stats: + # Warmup (includes autotuning on first run) + for i in range(stats["warmup_iters"]): + if i == 0: + print(" Warming up (compiling/autotuning)...", end="", flush=True) + run_triton_bwd() + if i == 0: + print(" done") + torch.cuda.synchronize() + + # Timed runs + start_time = time.time() + for _ in range(stats["timed_iters"]): + run_triton_bwd() + torch.cuda.synchronize() + end_time = time.time() + + stats["time"] = ((end_time - start_time) / stats["timed_iters"]) * 1000 + + print( + f"Triton Backward : Peak Memory = {stats['peak_mem']:.2f} MB, Time = {stats['time']:.2f} ms " + f"(avg of {stats['timed_iters']} runs)" + ) + else: + print("Triton not available, skipping Triton Backward") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--B_local", type=int, default=1) + parser.add_argument("--N_atoms", type=int, default=9 * 512) + parser.add_argument("--multiplicity", type=int, default=16) + parser.add_argument("--dtype", type=str, default="fp32", choices=["fp32", "bf16", "tf32"]) + + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("CUDA not available, exiting") + exit(0) + + run_benchmark(args.B_local, args.N_atoms, args.multiplicity, args.dtype) diff --git a/tests/distributed/model/loss/test_cdist_lddt_triton.py b/tests/distributed/model/loss/test_cdist_lddt_triton.py new file mode 100644 index 000000000..f0e5fd367 --- /dev/null +++ b/tests/distributed/model/loss/test_cdist_lddt_triton.py @@ -0,0 +1,430 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import os +import re +import subprocess +from pathlib import Path + +import pytest +import torch + +from boltz.distributed.model.loss.triton.cdist_lddt import cdist_lddt +from boltz.model.loss.confidencev2 import lddt_dist +from boltz.testing.utils import init_tensors_uniform + + +def lddt_dist_reference( + pred_coords_row, + pred_coords_col, + true_coords_row, + true_coords_col, + resolved_mask_row, + resolved_mask_col, + atom_indices_row=None, # [B, N_row] + atom_indices_col=None, # [B, N_col] + cutoff=15.0, + cutoff_col=None, # [B, N_col] (optional, per-column per-batch cutoff values) + do_mask_diagonal=True, + per_atom=False, + return_denom=False, +): + """Reference implementation using torch.cdist and lddt_dist""" + B_mul, N_row, _ = pred_coords_row.shape + _, N_col, _ = pred_coords_col.shape + B, _ = resolved_mask_row.shape + multiplicity = B_mul // B + device = pred_coords_row.device + + # Broadcast masks + mask_row = resolved_mask_row.repeat_interleave(multiplicity, dim=0) + mask_col = resolved_mask_col.repeat_interleave(multiplicity, dim=0) + + # Pair mask: (B_mul, N_row, N_col) + pair_mask = mask_row.unsqueeze(-1) * mask_col.unsqueeze(-2) + + # Diagonal mask (conditional on do_mask_diagonal) + if do_mask_diagonal: + # Default to arange if indices not provided + # atom_indices are [B, N_row] and [B, N_col], broadcast to B_mul + if atom_indices_row is not None: + idx_row = atom_indices_row.repeat_interleave(multiplicity, dim=0) # [B_mul, N_row] + else: + idx_row = torch.arange(N_row, device=device).unsqueeze(0).expand(B_mul, -1) # [B_mul, N_row] + if atom_indices_col is not None: + idx_col = atom_indices_col.repeat_interleave(multiplicity, dim=0) # [B_mul, N_col] + else: + idx_col = torch.arange(N_col, device=device).unsqueeze(0).expand(B_mul, -1) # [B_mul, N_col] + # is_diagonal: [B_mul, N_row, N_col] + is_diagonal = idx_row.unsqueeze(-1) == idx_col.unsqueeze(-2) + pair_mask = pair_mask * (~is_diagonal) + + dmat_pred = torch.cdist(pred_coords_row, pred_coords_col) + dmat_true = torch.cdist(true_coords_row, true_coords_col) + + # Compute cutoff tensor: if cutoff_col is provided, broadcast to [B_mul, N_row, N_col] + if cutoff_col is not None: + # cutoff_col is [B, N_col], broadcast to [B_mul, 1, N_col] then to [B_mul, N_row, N_col] + cutoff_expanded = cutoff_col.repeat_interleave(multiplicity, dim=0).unsqueeze(1) # [B_mul, 1, N_col] + cutoff_tensor = cutoff_expanded.expand(-1, N_row, -1) # [B_mul, N_row, N_col] + else: + cutoff_tensor = cutoff + + dists_to_score = (dmat_true < cutoff_tensor).float() * pair_mask + + # Use existing reference implementation + # lddt_dist expects [B, N, N] inputs + result = lddt_dist(dmat_pred, dmat_true, pair_mask, cutoff=cutoff_tensor, per_atom=per_atom) + + if per_atom: + score, mask_no_match = result + if return_denom: + denom = torch.sum(dists_to_score, dim=-1) + return score, mask_no_match, denom + return score, mask_no_match + + score, _total = result + if return_denom: + denom = torch.sum(dists_to_score, dim=(-2, -1)) + return score, denom + return score + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +# Order: last decorator = outermost loop (slowest varying) +# Put constexpr params (per_atom, do_mask_diagonal, return_unnormalized_score) last to minimize Triton recompilations +@pytest.mark.parametrize("B, multiplicity", [(2, 4)], ids=["B:2, M:4"]) +@pytest.mark.parametrize("N_row", [32, 100], ids=lambda x: f"Nr:{x}") +@pytest.mark.parametrize("N_col", [32, 100], ids=lambda x: f"Nc:{x}") +@pytest.mark.parametrize("use_indices_row", [True, False], ids=lambda x: f"IdxR:{x}") +@pytest.mark.parametrize("use_indices_col", [True, False], ids=lambda x: f"IdxC:{x}") +@pytest.mark.parametrize("use_cutoff_col", [True, False], ids=lambda x: f"CutCol:{x}") +@pytest.mark.parametrize("do_mask_diagonal", [True, False], ids=lambda x: f"DiagMask:{x}") +@pytest.mark.parametrize("per_atom", [False, True], ids=lambda x: f"PerAtom:{x}") +@pytest.mark.parametrize("return_unnormalized_score", [False, True], ids=lambda x: f"Unnorm:{x}") +@pytest.mark.parametrize("return_denom", [True, False], ids=lambda x: f"Denom:{x}") +@pytest.mark.parametrize( + "dtype", + [torch.bfloat16, torch.float32, torch.float64], + ids=lambda x: f"Dtype:{x}", +) +def test_cdist_lddt_correctness( + B, + multiplicity, + N_row, + N_col, + use_indices_row, + use_indices_col, + use_cutoff_col, + do_mask_diagonal, + per_atom, + return_unnormalized_score, + return_denom, + dtype, +): + if torch.promote_types(dtype, torch.float32) != dtype: + pytest.xfail(f"cdist_lddt requires at least float32 precision but got {dtype}") + + # return_unnormalized_score and per_atom are orthogonal options: + # - per_atom controls output shape: [B_mul, N_row] vs [B_mul] + # - return_unnormalized_score controls whether to return raw (out_num, out_denom) vs normalized score + + device = torch.device("cuda") + B_mul = B * multiplicity + + # Value range for coordinate initialization (controls numerical stability) + min_val_init = -0.5 + max_val_init = 0.5 + + # Generate random data with controlled value range + pred_coords_row = torch.empty(B_mul, N_row, 3, device=device, dtype=dtype) + pred_coords_col = torch.empty(B_mul, N_col, 3, device=device, dtype=dtype) + true_coords_row = torch.empty(B_mul, N_row, 3, device=device, dtype=dtype) + true_coords_col = torch.empty(B_mul, N_col, 3, device=device, dtype=dtype) + init_tensors_uniform( + [pred_coords_row, pred_coords_col, true_coords_row, true_coords_col], + low=min_val_init, + high=max_val_init, + ) + + resolved_mask_row = torch.randint(0, 2, (B, N_row), device=device, dtype=dtype) + resolved_mask_col = torch.randint(0, 2, (B, N_col), device=device, dtype=dtype) + + # Each index can independently be explicit or implicit (arange) + # atom_indices have shape [B, N_row] and [B, N_col] (batch dimension B, not B_mul) + # When the other dimension is larger, use a random non-duplicated subset from that + # dimension to cover the rectangular subset case from minimum_lddt_symmetry_coords + atom_indices_row = None + atom_indices_col = None + + if use_indices_row: + if N_col > N_row: + # Random non-duplicated subset of arange(N_col) with size N_row, per batch + # Each batch sample gets a different random permutation + atom_indices_row = torch.stack( + [torch.randperm(N_col, device=device)[:N_row].sort().values for _ in range(B)] + ) # [B, N_row] + else: + # N_row >= N_col: use simple arange(N_row) for each batch + atom_indices_row = torch.arange(N_row, device=device).unsqueeze(0).expand(B, -1).contiguous() # [B, N_row] + + if use_indices_col: + if N_row > N_col: + # Random non-duplicated subset of arange(N_row) with size N_col + # Each batch sample gets a different random permutation + atom_indices_col = torch.stack( + [torch.randperm(N_row, device=device)[:N_col].sort().values for _ in range(B)] + ) # [B, N_col] + else: + # N_col >= N_row: use simple arange(N_col) for each batch + atom_indices_col = torch.arange(N_col, device=device).unsqueeze(0).expand(B, -1).contiguous() # [B, N_col] + + # Generate cutoff_col with shape [B, N_col] if use_cutoff_col is True + # Coordinates are in [-0.5, 0.5], so max distance is ~sqrt(3) ≈ 1.73 + # Use cutoff values in [0.3, 1.2] range to meaningfully filter distances + cutoff_col = None + if use_cutoff_col: + # Generate random cutoff values per batch, per column + cutoff_col = torch.empty(B, N_col, device=device).uniform_(0.3, 1.2) + + if return_unnormalized_score and return_denom: + pytest.skip("return_denom is invalid when return_unnormalized_score=True") + + # Reference + ref_result = lddt_dist_reference( + pred_coords_row, + pred_coords_col, + true_coords_row, + true_coords_col, + resolved_mask_row, + resolved_mask_col, + atom_indices_row, + atom_indices_col, + cutoff_col=cutoff_col, + do_mask_diagonal=do_mask_diagonal, + per_atom=per_atom, + return_denom=return_denom, + ) + + # Triton + triton_result = cdist_lddt( + pred_coords_row, + pred_coords_col, + true_coords_row, + true_coords_col, + resolved_mask_row, + resolved_mask_col, + multiplicity, + atom_indices_row=atom_indices_row, + atom_indices_col=atom_indices_col, + cutoff_col=cutoff_col, + do_mask_diagonal=do_mask_diagonal, + per_atom=per_atom, + return_unnormalized_score=return_unnormalized_score, + return_denom=return_denom, + ) + + eps = 1e-10 + if return_unnormalized_score: + # return_unnormalized_score returns unnormalized scores before normalization + # We verify by manually computing the normalized result and comparing to reference + if per_atom: + out_num, out_denom, mask_no_match_triton = triton_result + norm = 1.0 / (eps + out_denom) + computed_score = norm * (eps + out_num) + score_ref, mask_no_match_ref = ref_result + torch.testing.assert_close(computed_score, score_ref) + torch.testing.assert_close(mask_no_match_triton.to(mask_no_match_ref.dtype), mask_no_match_ref) + else: + out_num, out_denom = triton_result + score_ref = ref_result + computed_score = out_num / (out_denom + eps) + computed_score = torch.where(out_denom > 0, computed_score, torch.zeros_like(computed_score)) + torch.testing.assert_close(computed_score, score_ref) + else: + if per_atom and return_denom: + score_ref, mask_no_match_ref, denom_ref = ref_result + score_triton, mask_no_match_triton, denom_triton = triton_result + torch.testing.assert_close(score_triton.to(score_ref.dtype), score_ref) + torch.testing.assert_close(mask_no_match_triton.to(mask_no_match_ref.dtype), mask_no_match_ref) + torch.testing.assert_close(denom_triton.to(denom_ref.dtype), denom_ref) + elif per_atom and not return_denom: + score_ref, mask_no_match_ref = ref_result + score_triton, mask_no_match_triton = triton_result + torch.testing.assert_close(score_triton.to(score_ref.dtype), score_ref) + torch.testing.assert_close(mask_no_match_triton.to(mask_no_match_ref.dtype), mask_no_match_ref) + elif not per_atom and return_denom: + score_ref, denom_ref = ref_result + score_triton, denom_triton = triton_result + torch.testing.assert_close(score_triton.to(score_ref.dtype), score_ref) + torch.testing.assert_close(denom_triton.to(denom_ref.dtype), denom_ref) + else: # not per_atom and not return_denom + score_ref = ref_result + score_triton = triton_result + torch.testing.assert_close(score_triton.to(score_ref.dtype), score_ref) + + +@pytest.fixture( + params=[ + # (modifications_dict, expected_error_pattern) + ({}, None), # valid inputs + ({"mask_row": (3, 32), "mask_col": (3, 48)}, "mask_row batch dimension"), # Neither B nor B_mul + ({"coord_dim": 4}, "Coordinate dimension must be 3"), + ({"pred_coords_col": (4, 48, 3)}, "pred_coords_col shape"), + ({"true_coords_row": (8, 64, 3)}, "true_coords_row shape"), + ({"true_coords_col": (8, 64, 3)}, "true_coords_col shape"), + ({"mask_row": (2, 64)}, "mask_row N dimension"), + ({"mask_col": (2, 64)}, "mask_col N dimension"), + ( + {"mask_row": (2, 32), "mask_col": (4, 48)}, + "mask_col batch dimension", + ), # mask_col batch is 4, neither B=2 nor B_mul=8 + ({"atom_indices_row": (2, 64)}, "atom_indices_row shape"), + ({"atom_indices_col": (2, 64)}, "atom_indices_col shape"), + ({"atom_indices_row": None, "atom_indices_col": None}, None), # None indices valid + ( + {"return_unnormalized_score": True, "return_denom": True}, + "return_denom is not valid when return_unnormalized_score=True", + ), + ], + ids=[ + "valid_inputs", + "mask_batch_neither_B_nor_B_mul", + "coord_dim_not_3", + "pred_coords_col_wrong_batch", + "true_coords_row_wrong_n_row", + "true_coords_col_wrong_n_col", + "mask_row_wrong_n_row", + "mask_col_wrong_n_col", + "mask_col_batch_neither_B_nor_B_mul", + "atom_indices_row_wrong_shape", + "atom_indices_col_wrong_shape", + "none_indices_valid", + "return_unnormalized_score_and_denom_invalid", + ], +) +def validation_case(request): + """Fixture providing (modifications, expected_error) for validation tests""" + return request.param + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cdist_lddt_validation(validation_case): + """Test input validation for cdist_lddt()""" + modifications, expected_error = validation_case + device = torch.device("cuda") + + # Create valid inputs inline + B, multiplicity, N_row, N_col = 2, 4, 32, 48 + B_mul = B * multiplicity + coord_dim = modifications.get("coord_dim", 3) + + rng = torch.Generator(device=device) + rng.manual_seed(0) + inputs = { + "pred_coords_row": torch.randn(B_mul, N_row, coord_dim, device=device, generator=rng), + "pred_coords_col": torch.randn(B_mul, N_col, coord_dim, device=device, generator=rng), + "true_coords_row": torch.randn(B_mul, N_row, coord_dim, device=device, generator=rng), + "true_coords_col": torch.randn(B_mul, N_col, coord_dim, device=device, generator=rng), + "mask_row": torch.ones(B, N_row, device=device), + "mask_col": torch.ones(B, N_col, device=device), + "atom_indices_row": torch.arange(N_row, device=device).unsqueeze(0).expand(B, -1).contiguous(), + "atom_indices_col": torch.arange(N_col, device=device).unsqueeze(0).expand(B, -1).contiguous(), + "return_unnormalized_score": False, + "return_denom": False, + } + + # Apply modifications + for key, val in modifications.items(): + if key == "coord_dim": + continue # Already handled above + elif key in ("mask_row", "mask_col"): + inputs[key] = torch.ones(*val, device=device) + elif key in ("pred_coords_row", "pred_coords_col", "true_coords_row", "true_coords_col"): + inputs[key] = torch.randn(*val, device=device) + elif key in ("atom_indices_row", "atom_indices_col"): + if val is None: + inputs[key] = None + elif isinstance(val, tuple): + # val is (B, N) shape + inputs[key] = torch.arange(val[1], device=device).unsqueeze(0).expand(val[0], -1).contiguous() + else: + # legacy: val is just N + inputs[key] = torch.arange(val, device=device).unsqueeze(0).expand(B, -1).contiguous() + elif key in ("return_unnormalized_score", "return_denom"): + inputs[key] = val + + if expected_error is None: + result = cdist_lddt(**inputs, multiplicity=multiplicity) + assert result.shape == (inputs["pred_coords_row"].shape[0],) + else: + with pytest.raises(ValueError, match=expected_error): + cdist_lddt(**inputs, multiplicity=multiplicity) + + +def assert_no_register_spilling(path_to_ptx_file: Path): + ptx_code = path_to_ptx_file.read_text() + + # get the ".target sm_{arch}a" directive from the ptx code + sm_arch_match = re.search(r"\.target (sm_\w+)", ptx_code) + if not sm_arch_match: + raise RuntimeError(f"No .target directive found in {path_to_ptx_file}") + sm_arch = sm_arch_match.group(1) + + # Run ptxas + ptxas_path = os.environ.get("TRITON_PTXAS_PATH", "ptxas") + + cmd = [ptxas_path, "-v", f"--gpu-name={sm_arch}", str(path_to_ptx_file)] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + output = result.stderr + if "0 bytes spill stores, 0 bytes spill loads" not in output: + raise RuntimeError(f"Register spilling detected in {path_to_ptx_file}:\n{output}") + except subprocess.CalledProcessError as e: + raise RuntimeError(f"ptxas failed with error:\n{e.stderr}") from e + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_no_register_spilling(tmp_path, monkeypatch): + # Setup env for dumping PTX + monkeypatch.setenv("TRITON_KERNEL_DUMP", "1") + monkeypatch.setenv("TRITON_DUMP_DIR", str(tmp_path)) + monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1") + monkeypatch.setenv("TRITON_CACHE_DIR", str(tmp_path / "cache")) + monkeypatch.setenv("TRITON_PTXAS_PATH", os.environ.get("TRITON_PTXAS_PATH", "ptxas")) + + device = torch.device("cuda") + B_mul, B, N = 16, 1, 100 + + pred_coords = torch.randn(B_mul, N, 3, device=device) + true_coords = torch.randn(B_mul, N, 3, device=device) + mask = torch.ones(B, N, device=device) + + # Run kernel (multiplicity = B_mul // B = 16 // 1 = 16) + cdist_lddt(pred_coords, pred_coords, true_coords, true_coords, mask, mask, multiplicity=B_mul // B) + + # Check PTX + ptx_files = list(tmp_path.glob("**/_cdist_lddt_kernel.ptx")) + if not ptx_files: + raise RuntimeError(f"No PTX file found in {tmp_path}") + + assert_no_register_spilling(ptx_files[0]) diff --git a/tests/distributed/model/loss/test_cdist_pde_triton.py b/tests/distributed/model/loss/test_cdist_pde_triton.py new file mode 100644 index 000000000..9d2aae1c1 --- /dev/null +++ b/tests/distributed/model/loss/test_cdist_pde_triton.py @@ -0,0 +1,378 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import os +import re +import subprocess +from pathlib import Path + +import pytest +import torch +import torch.nn.functional as F + +from boltz.distributed.model.loss.triton.cdist_pde import cdist_pde +from boltz.testing.utils import init_tensors_uniform + + +def cdist_pde_reference( + pred_pde, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + multiplicity, + num_bins=64, + max_dist=32.0, +): + """ + Reference implementation for PDE cross-entropy loss. + + This implements the equivalent computation as the Triton kernel but using + standard PyTorch operations, materializing the full distance matrices. + + Returns fully summed outputs [B_mul] (sum over both row and column dimensions). + """ + B_mul, N_row, N_col, _ = pred_pde.shape + B = B_mul // multiplicity + + # Broadcast masks to B_mul if needed + if mask_row.shape[0] == B: + mask_row_expanded = mask_row.repeat_interleave(multiplicity, dim=0) + else: + mask_row_expanded = mask_row + + if mask_col.shape[0] == B: + mask_col_expanded = mask_col.repeat_interleave(multiplicity, dim=0) + else: + mask_col_expanded = mask_col + + # Compute pair mask [B_mul, N_row, N_col] + mask = mask_row_expanded.unsqueeze(-1) * mask_col_expanded.unsqueeze(-2) + + # Compute distances + true_d = torch.cdist(true_coords_row, true_coords_col) + pred_d = torch.cdist(pred_coords_row, pred_coords_col) + target_pde = torch.abs(true_d - pred_d) + + # Compute bin indices + bin_index = torch.floor(target_pde * num_bins / max_dist).long() + bin_index = torch.clamp(bin_index, max=(num_bins - 1)) + + # Compute cross-entropy + pde_one_hot = F.one_hot(bin_index, num_classes=num_bins).float() + log_probs = F.log_softmax(pred_pde, dim=-1) + errors = -1 * torch.sum(pde_one_hot * log_probs, dim=-1) + + # Full sum over both row and column dimensions + out_loss_num = torch.sum(errors * mask, dim=(-2, -1)) # [B_mul] + out_mask_denom = torch.sum(mask, dim=(-2, -1)) # [B_mul] + + return out_loss_num, out_mask_denom + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("B", [1, 2], ids=lambda x: f"B:{x}") +@pytest.mark.parametrize("multiplicity", [1, 4], ids=lambda x: f"M:{x}") +@pytest.mark.parametrize("N_row", [32, 64], ids=lambda x: f"Nr:{x}") +@pytest.mark.parametrize("N_col", [32, 64], ids=lambda x: f"Nc:{x}") +@pytest.mark.parametrize("mask_row_has_mul", [False, True], ids=lambda x: f"MaskRowMul:{x}") +@pytest.mark.parametrize("mask_col_has_mul", [False, True], ids=lambda x: f"MaskColMul:{x}") +@pytest.mark.parametrize("all_zero_mask", [False, True], ids=lambda x: f"AllZeroMask:{x}") +def test_cdist_pde_correctness(B, multiplicity, N_row, N_col, mask_row_has_mul, mask_col_has_mul, all_zero_mask): + """Test forward and backward pass correctness against reference implementation.""" + device = torch.device("cuda") + B_mul = B * multiplicity + num_bins = 64 + max_dist = 32.0 + min_val, max_val = -1.0, 1.0 + + # Generate random coordinates using init_tensors_uniform + true_coords_row = torch.empty(B_mul, N_row, 3, device=device) + true_coords_col = torch.empty(B_mul, N_col, 3, device=device) + pred_coords_row = torch.empty(B_mul, N_row, 3, device=device) + pred_coords_col = torch.empty(B_mul, N_col, 3, device=device) + init_tensors_uniform( + [true_coords_row, true_coords_col, pred_coords_row, pred_coords_col], + low=min_val, + high=max_val, + ) + + # pred_pde needs gradient for backward test + pred_pde_ref = torch.empty(B_mul, N_row, N_col, num_bins, device=device) + init_tensors_uniform([pred_pde_ref], low=min_val, high=max_val) + pred_pde_ref.requires_grad_(True) + pred_pde_triton = pred_pde_ref.detach().clone().requires_grad_(True) + + # Create masks with specified shapes + if all_zero_mask: + # All-zero masks for edge case testing + mask_row_shape = (B_mul, N_row) if mask_row_has_mul else (B, N_row) + mask_col_shape = (B_mul, N_col) if mask_col_has_mul else (B, N_col) + mask_row = torch.zeros(mask_row_shape, device=device) + mask_col = torch.zeros(mask_col_shape, device=device) + else: + # Random binary masks + if mask_row_has_mul: + mask_row = torch.randint(0, 2, (B_mul, N_row), device=device).float() + else: + mask_row = torch.randint(0, 2, (B, N_row), device=device).float() + + if mask_col_has_mul: + mask_col = torch.randint(0, 2, (B_mul, N_col), device=device).float() + else: + mask_col = torch.randint(0, 2, (B, N_col), device=device).float() + + # Ensure at least some masks are 1 for gradient flow + mask_row.view(-1)[0] = 1.0 + mask_col.view(-1)[0] = 1.0 + + # ===== Forward pass test ===== + # Reference computation + ref_loss_num, ref_mask_denom = cdist_pde_reference( + pred_pde_ref, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + multiplicity, + num_bins, + max_dist, + ) + + # Triton computation + triton_loss_num, triton_mask_denom = cdist_pde( + pred_pde_triton, + true_coords_row, + true_coords_col, + pred_coords_row, + pred_coords_col, + mask_row, + mask_col, + multiplicity, + num_bins, + max_dist, + ) + + # Compare forward outputs + torch.testing.assert_close(triton_loss_num, ref_loss_num) + torch.testing.assert_close(triton_mask_denom, ref_mask_denom) + + # ===== Backward pass test ===== + if all_zero_mask: + # With all-zero mask, outputs should be zero, skip backward test + assert torch.all(triton_loss_num == 0) + assert torch.all(triton_mask_denom == 0) + return + + # Create upstream gradient (mock adjoint) using init_tensors_uniform + grad_out = torch.empty_like(ref_loss_num) + init_tensors_uniform([grad_out], low=min_val, high=max_val) + + # Backward pass with upstream gradient + ref_loss_num.backward(grad_out) + triton_loss_num.backward(grad_out) + + # Compare gradients + assert pred_pde_ref.grad is not None + assert pred_pde_triton.grad is not None + torch.testing.assert_close(pred_pde_triton.grad, pred_pde_ref.grad) + + +@pytest.fixture( + params=[ + # (modifications_dict, expected_error_pattern) + ({}, None), # valid inputs + ({"pred_pde": (4, 32, 32, 32)}, "pred_pde num_bins mismatch"), # wrong num_bins + ({"coord_dim": 4}, "Coordinate dimension must be 3"), + ({"true_coords_row": (4, 64, 3)}, "true_coords_row shape mismatch"), + ({"true_coords_col": (4, 64, 3)}, "true_coords_col shape mismatch"), + ({"pred_coords_row": (4, 64, 3)}, "pred_coords_row shape mismatch"), + ({"pred_coords_col": (4, 64, 3)}, "pred_coords_col shape mismatch"), + ({"mask_row": (3, 32)}, "mask_row batch dimension"), # Neither B nor B_mul + ({"mask_col": (3, 32)}, "mask_col batch dimension"), # Neither B nor B_mul + ({"mask_row": (2, 64)}, "mask_row N dimension"), # Wrong N_row + ({"mask_col": (2, 64)}, "mask_col N dimension"), # Wrong N_col + # requires_grad validation (gradient flow is broken by bin_index computation) + ({"true_coords_row_requires_grad": True}, "true_coords_row should not require gradients"), + ({"true_coords_col_requires_grad": True}, "true_coords_col should not require gradients"), + ({"pred_coords_row_requires_grad": True}, "pred_coords_row should not require gradients"), + ({"pred_coords_col_requires_grad": True}, "pred_coords_col should not require gradients"), + ({"mask_row_requires_grad": True}, "mask_row should not require gradients"), + ({"mask_col_requires_grad": True}, "mask_col should not require gradients"), + ], + ids=[ + "valid_inputs", + "pred_pde_wrong_num_bins", + "coord_dim_not_3", + "true_coords_row_wrong_shape", + "true_coords_col_wrong_shape", + "pred_coords_row_wrong_shape", + "pred_coords_col_wrong_shape", + "mask_row_batch_neither_B_nor_B_mul", + "mask_col_batch_neither_B_nor_B_mul", + "mask_row_wrong_n_row", + "mask_col_wrong_n_col", + "true_coords_row_requires_grad", + "true_coords_col_requires_grad", + "pred_coords_row_requires_grad", + "pred_coords_col_requires_grad", + "mask_row_requires_grad", + "mask_col_requires_grad", + ], +) +def pde_validation_case(request): + """Fixture providing (modifications, expected_error) for validation tests""" + return request.param + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cdist_pde_input_validation(pde_validation_case): + """Test input validation for cdist_pde()""" + modifications, expected_error = pde_validation_case + device = torch.device("cuda") + + # Create valid inputs inline + B, multiplicity, N_row, N_col = 2, 2, 32, 32 + B_mul = B * multiplicity + num_bins = 64 + max_dist = 32.0 + coord_dim = modifications.get("coord_dim", 3) + + inputs = { + "pred_pde": torch.randn(B_mul, N_row, N_col, num_bins, device=device), + "true_coords_row": torch.randn(B_mul, N_row, coord_dim, device=device), + "true_coords_col": torch.randn(B_mul, N_col, coord_dim, device=device), + "pred_coords_row": torch.randn(B_mul, N_row, coord_dim, device=device), + "pred_coords_col": torch.randn(B_mul, N_col, coord_dim, device=device), + "mask_row": torch.ones(B, N_row, device=device), + "mask_col": torch.ones(B, N_col, device=device), + } + + # Apply modifications + for key, val in modifications.items(): + if key == "coord_dim": + continue # Already handled above + elif key in ("mask_row", "mask_col"): + inputs[key] = torch.ones(*val, device=device) + elif key == "pred_pde": + inputs[key] = torch.randn(*val, device=device) + elif key in ("true_coords_row", "true_coords_col", "pred_coords_row", "pred_coords_col"): + inputs[key] = torch.randn(*val, device=device) + elif key.endswith("_requires_grad"): + # Handle requires_grad modifications + tensor_key = key.replace("_requires_grad", "") + inputs[tensor_key] = inputs[tensor_key].requires_grad_(val) + + if expected_error is None: + loss_num, mask_denom = cdist_pde( + **inputs, + multiplicity=multiplicity, + num_bins=num_bins, + max_dist=max_dist, + ) + # Kernel now outputs fully summed [B_mul] instead of [B_mul, N_row] + assert loss_num.shape == (B_mul,) + assert mask_denom.shape == (B_mul,) + else: + with pytest.raises(ValueError, match=expected_error): + cdist_pde( + **inputs, + multiplicity=multiplicity, + num_bins=num_bins, + max_dist=max_dist, + ) + + +def assert_no_register_spilling(path_to_ptx_file: Path): + """Check that a PTX file shows no register spilling.""" + ptx_code = path_to_ptx_file.read_text() + + # Get the target architecture + sm_arch_match = re.search(r"\.target (sm_\w+)", ptx_code) + if not sm_arch_match: + raise RuntimeError(f"No .target directive found in {path_to_ptx_file}") + sm_arch = sm_arch_match.group(1) + + # Run ptxas + ptxas_path = os.environ.get("TRITON_PTXAS_PATH", "ptxas") + cmd = [ptxas_path, "-v", f"--gpu-name={sm_arch}", str(path_to_ptx_file)] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + output = result.stderr + if "0 bytes spill stores, 0 bytes spill loads" not in output: + raise RuntimeError(f"Register spilling detected in {path_to_ptx_file}:\n{output}") + except subprocess.CalledProcessError as e: + raise RuntimeError(f"ptxas failed with error:\n{e.stderr}") from e + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_no_register_spilling(tmp_path, monkeypatch): + """Test that kernels don't spill registers.""" + # Setup env for dumping PTX + monkeypatch.setenv("TRITON_KERNEL_DUMP", "1") + monkeypatch.setenv("TRITON_DUMP_DIR", str(tmp_path)) + monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1") + monkeypatch.setenv("TRITON_CACHE_DIR", str(tmp_path / "cache")) + monkeypatch.setenv("TRITON_PTXAS_PATH", os.environ.get("TRITON_PTXAS_PATH", "ptxas")) + + device = torch.device("cuda") + # Use small problem size (like cdist_lddt test) to ensure quick compilation + B_mul, B = 16, 1 + N = 100 + num_bins = 64 + + pred_pde = torch.randn(B_mul, N, N, num_bins, device=device, requires_grad=True) + true_coords = torch.randn(B_mul, N, 3, device=device) + pred_coords = torch.randn(B_mul, N, 3, device=device) + mask = torch.ones(B, N, device=device) + + # Run forward kernel + loss_num, mask_denom = cdist_pde( + pred_pde, + true_coords, + true_coords, + pred_coords, + pred_coords, + mask, + mask, + multiplicity=B_mul // B, + num_bins=num_bins, + ) + + # Run backward kernel + loss_num.sum().backward() + + # Check PTX files for forward kernel + fwd_ptx_files = list(tmp_path.glob("**/_cdist_pde_fwd_kernel.ptx")) + if not fwd_ptx_files: + raise RuntimeError(f"No forward kernel PTX file found in {tmp_path}") + assert_no_register_spilling(fwd_ptx_files[0]) + + # Check PTX files for backward kernel + bwd_ptx_files = list(tmp_path.glob("**/_cdist_pde_bwd_kernel.ptx")) + if not bwd_ptx_files: + raise RuntimeError(f"No backward kernel PTX file found in {tmp_path}") + assert_no_register_spilling(bwd_ptx_files[0]) diff --git a/tests/distributed/model/loss/test_compute_plddt_mae_triton.py b/tests/distributed/model/loss/test_compute_plddt_mae_triton.py new file mode 100644 index 000000000..35bd4d279 --- /dev/null +++ b/tests/distributed/model/loss/test_compute_plddt_mae_triton.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for compute_plddt_mae_triton. + +Verifies that compute_plddt_mae_triton produces scalar results matching the +serial compute_plddt_mae from boltz.model.loss.validation on a single GPU. +""" + +from __future__ import annotations + +import pytest +import torch + +from boltz.distributed.model.loss.validation import compute_plddt_mae_triton +from boltz.model.loss.validation import compute_plddt_mae +from boltz.testing.utils import random_features + +EXPECTED_PLDDT_KEYS = {"protein", "ligand", "dna", "rna"} + + +def _generate_plddt_test_data(multiplicity, resolved_mask_mode, rng): + """Generate input tensors for plddt MAE tests with batch_size=1.""" + n_tokens = 20 + n_atoms = n_tokens * 20 + feats_host = random_features( + size_batch=1, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, 20), + device=torch.device("cpu"), + float_value_range=(-5.0, 5.0), + selected_keys=[ + "token_to_rep_atom", + "r_set_to_rep_atom", + "atom_to_token", + "mol_type", + "atom_counts_per_token", + ], + rng=rng, + ) + N_atom_actual = feats_host["token_to_rep_atom"].shape[2] + B_mul = multiplicity + pred_coords = torch.randn(B_mul, N_atom_actual, 3, dtype=torch.float32) + true_coords = torch.randn(B_mul, N_atom_actual, 3, dtype=torch.float32) + pred_lddt = torch.rand(B_mul, n_tokens, dtype=torch.float32) + if resolved_mask_mode == "ones": + resolved_mask = torch.ones(B_mul, N_atom_actual, dtype=torch.float32) + elif resolved_mask_mode == "zeros": + resolved_mask = torch.zeros(B_mul, N_atom_actual, dtype=torch.float32) + else: + resolved_mask = torch.randint(0, 2, (B_mul, N_atom_actual)).float() + return feats_host, pred_coords, true_coords, pred_lddt, resolved_mask + + +def _run_comparison(multiplicity, resolved_mask_mode, seed): + """Run triton vs serial comparison and return both results.""" + torch.manual_seed(seed) + rng = torch.Generator(device="cpu") + rng.manual_seed(seed) + + feats_host, pred_coords, true_coords, pred_lddt, resolved_mask = _generate_plddt_test_data( + multiplicity, resolved_mask_mode, rng + ) + + feats = { + "r_set_to_rep_atom": feats_host["r_set_to_rep_atom"], + "mol_type": feats_host["mol_type"], + "atom_to_token": feats_host["atom_to_token"], + "token_to_rep_atom": feats_host["token_to_rep_atom"], + } + + ref_mae, ref_total = compute_plddt_mae( + pred_atom_coords=pred_coords, + feats=feats, + true_atom_coords=true_coords, + pred_lddt=pred_lddt, + true_coords_resolved_mask=resolved_mask, + multiplicity=multiplicity, + ) + + device = torch.device("cuda") + feats_cuda = {k: v.to(device) for k, v in feats.items()} + + triton_mae, triton_total = compute_plddt_mae_triton( + pred_atom_coords=pred_coords.to(device), + feats=feats_cuda, + true_atom_coords=true_coords.to(device), + pred_lddt=pred_lddt.to(device), + true_coords_resolved_mask=resolved_mask.to(device), + multiplicity=multiplicity, + ) + + return ref_mae, ref_total, triton_mae, triton_total + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("multiplicity", [1, 2]) +@pytest.mark.parametrize("resolved_mask_mode", ["ones", "zeros", "partial"]) +def test_compute_plddt_mae_triton(multiplicity, resolved_mask_mode): + """compute_plddt_mae_triton must match serial compute_plddt_mae.""" + ref_mae, ref_total, triton_mae, triton_total = _run_comparison( + multiplicity=multiplicity, + resolved_mask_mode=resolved_mask_mode, + seed=42, + ) + + assert set(triton_mae.keys()) == EXPECTED_PLDDT_KEYS + assert set(triton_total.keys()) == EXPECTED_PLDDT_KEYS + + for key in EXPECTED_PLDDT_KEYS: + if resolved_mask_mode == "zeros": + assert ref_total[key] == 0.0, f"Expected zero total for '{key}'" + torch.testing.assert_close(triton_mae[key].cpu(), ref_mae[key]) + torch.testing.assert_close(triton_total[key].cpu(), ref_total[key]) diff --git a/tests/distributed/model/loss/test_dtensor_bfactor.py b/tests/distributed/model/loss/test_dtensor_bfactor.py new file mode 100644 index 000000000..81ffe9df8 --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_bfactor.py @@ -0,0 +1,281 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-based tests for Boltz-2 B-factor loss. + +Tests the DTensor CP implementation against the serial bfactor_loss_fn. +Maps to: src/boltz/distributed/model/loss/bfactor.py + +Verification checks: + V1: single-proc serial immutability (FW/BW inputs unchanged) + V4: multi-proc FW input tensor values unchanged by FW and BW + V8: multi-proc FW loss value close-to single-proc + V9: multi-proc BW pred gradient values close-to single-proc + +bf16 coverage (CUDA-only): + The bfactor loss uses promote_types (compute in fp32) inside + autocast(enabled=False). The bf16 test verifies that bf16 pred logits + are correctly promoted and the resulting loss/gradient are fp32. +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor +from torch.testing import assert_close + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.bfactor import bfactor_loss as bfactor_loss_dtensor +from boltz.model.loss.bfactor import bfactor_loss_fn as bfactor_loss_serial +from boltz.testing.utils import ( + assert_tensors_identical, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + +SEED = 42 + + +def _assert_unchanged(actual, expected, *, serial=False): + """Shorthand for assert_tensors_identical with standard immutability kwargs.""" + assert_tensors_identical( + actual, + expected, + check_stride=True, + check_grad=False, + check_grad_fn=False, + check_storage_pointer=False, + check_storage_offset=serial, + ) + + +def _worker_bfactor_loss_parity( + rank: int, + pred_on_host: torch.Tensor, + t2ra_on_host: torch.Tensor, + bf_on_host: torch.Tensor, + loss_ref: float, + pred_grad_ref_on_host: torch.Tensor, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: dict[str, str] | None = None, +): + """Worker: compare distributed bfactor loss against serial reference. + + Performs V4 (input immutability), V8 (loss parity), V9 (grad parity). + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + dm = DistributedManager() + + single_placements = (Shard(0), Shard(1), Replicate()) + # bfactor is (B, A) — replicate A on CP dims so bmm dimensions match + # (t2ra has A in its last dim which is Replicate on cp1) + atom_placements = (Shard(0), Replicate(), Replicate()) + + pred_dt = distribute_tensor(pred_on_host.to(dm.device), dm.device_mesh_subgroups, single_placements).requires_grad_( + True + ) + t2ra_dt = distribute_tensor(t2ra_on_host.to(dm.device), dm.device_mesh_subgroups, single_placements) + bf_dt = distribute_tensor(bf_on_host.to(dm.device), dm.device_mesh_subgroups, atom_placements) + + # V4 setup: clone inputs for immutability check + pred_dt_clone = pred_dt.detach().clone().requires_grad_(pred_dt.requires_grad) + t2ra_dt_clone = t2ra_dt.detach().clone().requires_grad_(t2ra_dt.requires_grad) + bf_dt_clone = bf_dt.detach().clone().requires_grad_(bf_dt.requires_grad) + + output = {"pbfactor": pred_dt} + feats = {"token_to_rep_atom": t2ra_dt, "bfactor": bf_dt} + + dp_group = dm.device_mesh_subgroups.get_group(0) + cp0_group = dm.device_mesh_subgroups.get_group(1) + cp1_group = dm.device_mesh_subgroups.get_group(2) + + # Forward + loss_dt = bfactor_loss_dtensor( + output, + feats, + device_mesh=dm.device_mesh_subgroups, + dp_group=dp_group, + cp0_group=cp0_group, + cp1_group=cp1_group, + ) + + # V4a: FW inputs unchanged + assert_tensors_identical( + pred_dt.to_local(), + pred_dt_clone.to_local(), + check_grad=False, + check_grad_fn=False, + ) + assert_tensors_identical( + t2ra_dt.to_local(), + t2ra_dt_clone.to_local(), + check_grad=False, + check_grad_fn=False, + ) + assert_tensors_identical( + bf_dt.to_local(), + bf_dt_clone.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + # V8: forward loss parity + loss_val = loss_dt.full_tensor().item() + assert_close( + torch.tensor(loss_val), + torch.tensor(loss_ref), + atol=1e-5, + rtol=1e-5, + msg=lambda m: f"Rank {rank} loss mismatch\n{m}", + ) + + # Backward + loss_dt.backward() + + # V4b: FW inputs unchanged after backward + assert_tensors_identical( + pred_dt.to_local(), + pred_dt_clone.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + # V9: pred gradient parity + assert pred_dt.grad is not None, "pred gradient is None" + pred_grad_full = pred_dt.grad.full_tensor() + assert_close( + pred_grad_full, + pred_grad_ref_on_host.to(dm.device), + atol=5e-5, + rtol=5e-5, + msg=lambda m: f"Rank {rank} pred grad mismatch\n{m}", + ) + + # Loss and gradient should be fp32 regardless of input dtype (promote_types) + assert loss_dt.dtype == torch.float32, f"Loss dtype should be fp32, got {loss_dt.dtype}" + assert pred_dt.grad.dtype == torch.float32, f"Grad dtype should be fp32, got {pred_dt.grad.dtype}" + + # Non-vacuous: gradient must be non-zero (would be zero only if loss is + # independent of pred, which would indicate a broken implementation). + assert ( + pred_grad_full.abs().sum() > 0 + ), "Distributed pred gradient is all-zero — loss is not differentiable w.r.t. pred" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + # CPU dp=2 cp=(3,3): DP + non-power-of-two CP for CPU-only CI + ((2, (3, 3)), True, "cpu", "ENV"), + # CUDA dp=1 cp=(1,1): serial-equivalent sanity check (1 GPU) + ((1, (1, 1)), True, "cuda", "ENV"), + # CUDA dp=2 cp=(1,1): bf16-compatible path (2 GPUs) + ((2, (1, 1)), True, "cuda", "ENV"), + # CUDA dp=1 cp=(2,2): actual CP under CUDA (4 GPUs) + ((1, (2, 2)), True, "cuda", "ENV"), + # CUDA dp=2 cp=(2,2): DP + CP under CUDA (8 GPUs) + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp2-cp3x3", "cuda-dp1-cp1x1", "cuda-dp2-cp1x1", "cuda-dp1-cp2x2", "cuda-dp2-cp2x2"], +) +def test_dtensor_bfactor_loss_forward_backward(setup_env): + """BFactor loss: distributed loss and pred gradient match serial reference.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size(device_type=device_type, world_size=world_size) + + B = 2 + bins = 8 + N = 8 * grid_group_sizes["cp"][0] + A = N # one atom per token for simplicity + + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(SEED) + + # Create pred logits + pred = torch.randn(B, N, bins, requires_grad=True) + pred_copy = pred.detach().clone().requires_grad_(True) + + # Create token_to_rep_atom: identity-like (each token maps to one atom) + t2ra = torch.zeros(B, N, A, dtype=torch.float32) + for b in range(B): + for i in range(N): + t2ra[b, i, i] = 1.0 + + # Create bfactor: realistic values (0-100 range), some zeros + bf = torch.rand(B, A) * 80.0 + bf[:, : A // 4] = 0.0 # first quarter is zero (no bfactor) + + # Serial loss + output_serial = {"pbfactor": pred} + feats_serial = {"token_to_rep_atom": t2ra, "bfactor": bf} + loss_ref = bfactor_loss_serial(output_serial, feats_serial) + + # V1a: serial FW input unchanged + _assert_unchanged(pred, pred_copy, serial=True) + + loss_ref_val = loss_ref.item() + + # Non-vacuous: serial loss must be positive. A zero loss would make + # the forward parity check (V8) and gradient parity check (V9) trivial. + assert loss_ref_val > 0, ( + f"Serial bfactor loss is {loss_ref_val} — test data produced zero " + f"loss, making parity checks vacuous. Verify bf/t2ra test setup." + ) + + # Serial backward + loss_ref.backward() + + # V1b: serial FW input unchanged after backward + _assert_unchanged(pred, pred_copy, serial=True) + + pred_grad_ref = pred.grad.detach().cpu().clone() + + spawn_multiprocessing( + _worker_bfactor_loss_parity, + world_size, + pred.detach().cpu(), + t2ra.detach().cpu(), + bf.detach().cpu(), + loss_ref_val, + pred_grad_ref, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/model/loss/test_dtensor_confidence_loss.py b/tests/distributed/model/loss/test_dtensor_confidence_loss.py new file mode 100644 index 000000000..3cdfe7574 --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_confidence_loss.py @@ -0,0 +1,1039 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Tests for DTensor confidence_loss wrapper implementation. + +This module tests the distributed implementation of confidence_loss which aggregates +plddt_loss, pde_loss, and resolved_loss. The tests verify numerical correctness +against the serial implementation and proper gradient computation. + +The confidence_loss wrapper coordinates DTensor placements across sub-loss functions +and returns aggregated scalar losses. +""" + +from math import gcd + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.data.feature.featurizer import BoltzFeaturizer +from boltz.data.module.inference import load_input +from boltz.data.tokenize.boltz import BoltzTokenizer +from boltz.distributed.comm import TransposeComm +from boltz.distributed.data.feature.featurizer_utils import get_num_atoms_tokens +from boltz.distributed.data.utils import distribute_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.confidencev2 import ( + compute_frame_pred as dtensor_compute_frame_pred, +) +from boltz.distributed.model.loss.confidencev2 import confidence_loss +from boltz.model.layers.confidence_utils import compute_frame_pred as serial_compute_frame_pred +from boltz.model.loss.confidencev2 import confidence_loss as serial_confidence_loss +from boltz.testing.utils import ( + distribute_atom_features, + random_features, + spawn_multiprocessing, +) + + +def _assert_nontrivial_expected_mask_collinear( + expected_mask_collinear_host: torch.Tensor, + token_pad_mask_host: torch.Tensor, + test_name: str, +) -> None: + """Ensure expected mask_collinear has valid support and non-trivial positives.""" + token_pad_mask_bool = token_pad_mask_host.bool() + if expected_mask_collinear_host.numel() == 0: + raise AssertionError(f"{test_name}: expected_mask_collinear_host is empty") + if not token_pad_mask_bool.any(): + raise AssertionError(f"{test_name}: token_pad_mask has no valid tokens") + + for batch_idx in range(expected_mask_collinear_host.shape[0]): + valid = expected_mask_collinear_host[batch_idx, :, token_pad_mask_bool[batch_idx]] + if valid.numel() == 0: + raise AssertionError(f"{test_name}: batch {batch_idx} has no valid tokens for expected_mask_collinear") + if not valid.any(): + raise AssertionError( + f"{test_name}: batch {batch_idx} expected_mask_collinear on valid tokens is always False" + ) + + +def parallel_assert_compute_frame_pred(rank: int, payload: tuple) -> None: + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + pred_atom_coords_host, + feats_host, + atom_counts_per_token_host, + expected_frames_idx_host, + expected_mask_collinear_host, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device_mesh = manager.device_mesh_subgroups + + single_repr_placements = (Shard(0), Shard(1), Replicate()) + replicate_placements = (Shard(0), Replicate(), Replicate()) + + size_batch = feats_host["atom_pad_mask"].shape[0] + pred_coords_unflat = pred_atom_coords_host.unflatten(0, (size_batch, multiplicity)) + inputs_atom = { + "atom_counts_per_token": atom_counts_per_token_host.to(dtype=torch.int64), + "pred_atom_coords_0": pred_coords_unflat[:, 0].to(dtype=pred_atom_coords_host.dtype), + "atom_to_token": feats_host["atom_to_token"].to(dtype=pred_atom_coords_host.dtype), + "atom_pad_mask": feats_host["atom_pad_mask"].to(dtype=pred_atom_coords_host.dtype), + "atom_resolved_mask": feats_host["atom_resolved_mask"].to(dtype=pred_atom_coords_host.dtype), + "frames_idx": feats_host["frames_idx"].to(dtype=torch.int64), + } + for i_mul in range(1, multiplicity): + inputs_atom[f"pred_atom_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=pred_atom_coords_host.dtype) + + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "atom_to_token": (Shard(0), Replicate()), + "atom_pad_mask": (Shard(0), Replicate()), + "atom_resolved_mask": (Shard(0), Replicate()), + "frames_idx": (Shard(1), Replicate()), + "pred_atom_coords_0": (Shard(0), Replicate()), + } + placements_dp_cp = { + "atom_to_token": (Shard(0), Shard(1), Replicate()), + "atom_pad_mask": (Shard(0), Shard(1), Replicate()), + "atom_resolved_mask": (Shard(0), Shard(1), Replicate()), + "frames_idx": (Shard(0), Shard(1), Replicate()), + "pred_atom_coords_0": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(1, multiplicity): + placements_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={"pred_atom_coords": multiplicity}, + ) + pred_atom_coords = feats_atom["pred_atom_coords"] + frames_idx_true = feats_atom["frames_idx"] + + feats = { + "asym_id": distribute_tensor( + feats_host["asym_id"].to(manager.device), device_mesh=device_mesh, placements=single_repr_placements + ), + "atom_to_token": feats_atom["atom_to_token"], + "atom_pad_mask": feats_atom["atom_pad_mask"], + "atom_resolved_mask": feats_atom["atom_resolved_mask"], + "mol_type": distribute_tensor( + feats_host["mol_type"].to(manager.device), device_mesh=device_mesh, placements=single_repr_placements + ), + "token_pad_mask": distribute_tensor( + feats_host["token_pad_mask"].to(manager.device), device_mesh=device_mesh, placements=single_repr_placements + ), + } + + frames_idx_pred, mask_collinear_pred = dtensor_compute_frame_pred( + pred_atom_coords, + frames_idx_true, + feats, + multiplicity=multiplicity, + ) + + note = "" + try: + frames_idx_pred_local = frames_idx_pred.redistribute(device_mesh, placements=replicate_placements).to_local() + dp_rank = manager.group_rank["dp"] + local_batch_size = size_batch // manager.group["dp"].size() + dp_idx_str = dp_rank * local_batch_size + dp_idx_end = dp_idx_str + local_batch_size + + # Proxy comparison for frame indices: compare geometry implied by indices. + # This avoids coupling test correctness to a specific index-space convention. + pred_atom_coords_local = ( + pred_atom_coords.redistribute(device_mesh, placements=replicate_placements) + .to_local() + .unflatten(0, (local_batch_size, multiplicity)) + ) + expected_atom_coords_local = pred_atom_coords_host.to(manager.device).unflatten(0, (size_batch, multiplicity))[ + dp_idx_str:dp_idx_end + ] + expected_frames_idx_local = expected_frames_idx_host.to(manager.device)[dp_idx_str:dp_idx_end] + + batch_idx = torch.arange(local_batch_size, device=manager.device)[:, None, None, None] + mult_idx = torch.arange(multiplicity, device=manager.device)[None, :, None, None] + + dt_frames = pred_atom_coords_local[batch_idx, mult_idx, frames_idx_pred_local] + host_frames = expected_atom_coords_local[batch_idx, mult_idx, expected_frames_idx_local] + # Use mean frame center as an index-invariant proxy. + dt_frame_centers = dt_frames.mean(dim=-2) + host_frame_centers = host_frames.mean(dim=-2) + + token_pad_mask_non_dtensor = feats_host["token_pad_mask"][dp_idx_str:dp_idx_end].bool().to(manager.device) + token_pad_mask_dtensor = ( + feats["token_pad_mask"].redistribute(device_mesh, placements=replicate_placements).to_local().bool() + ) + for batch_i in range(local_batch_size): + non_mask = token_pad_mask_non_dtensor[batch_i] + dt_mask = token_pad_mask_dtensor[batch_i] + + # Ensure test isn't trivially passing + if non_mask.sum().item() != dt_mask.sum().item(): + raise AssertionError( + "frames_idx proxy token count mismatch: " + f"non-dtensor={non_mask.sum().item()}, dtensor={dt_mask.sum().item()}" + ) + if not non_mask.any(): + raise AssertionError(f"batch {batch_i} has no valid tokens") + + # Compare collected frame center coordinates + torch.testing.assert_close( + dt_frame_centers[batch_i, :, dt_mask].cpu(), + host_frame_centers[batch_i, :, non_mask].cpu(), + ) + + except AssertionError as e: + note += "Test failed when comparing frames_idx_pred: " + str(e) + "\n" + + try: + mask_collinear_full = mask_collinear_pred.full_tensor().cpu() + token_pad_mask_non_dtensor = feats_host["token_pad_mask"].bool().cpu() + token_pad_mask_dtensor = feats["token_pad_mask"].full_tensor().bool().cpu() + expected_mask_collinear_host_cpu = expected_mask_collinear_host.cpu() + + for batch_idx in range(mask_collinear_full.shape[0]): + non_mask = token_pad_mask_non_dtensor[batch_idx] + dt_mask = token_pad_mask_dtensor[batch_idx] + mask_collinear_valid = mask_collinear_full[batch_idx, :, dt_mask] + expected_mask_collinear_valid = expected_mask_collinear_host_cpu[batch_idx, :, non_mask] + + # Ensure test isn't trivially passing + if non_mask.sum().item() != dt_mask.sum().item(): + raise AssertionError( + "mask_collinear token count mismatch: " + f"non-dtensor={non_mask.sum().item()}, dtensor={dt_mask.sum().item()}" + ) + if not mask_collinear_valid.any() or not expected_mask_collinear_valid.any(): + raise AssertionError( + "test can trivially pass on mask_collinear_pred since mask_collinear_pred on valid tokens is always False" + ) + + # Compare mask_collinear + torch.testing.assert_close( + mask_collinear_valid, + expected_mask_collinear_valid, + ) + except AssertionError as e: + note += "Test failed when comparing mask_collinear_pred: " + str(e) + "\n" + + if note: + raise AssertionError(note) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2], ids=["multiplicity=1", "multiplicity=2"]) +@pytest.mark.parametrize("seed", [0, 42], ids=["seed=0", "seed=42"]) +def test_dtensor_compute_frame_pred(setup_env: tuple, multiplicity: int, seed: int) -> None: + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + size_ring = grid_group_sizes["cp"][0] + batch_size = grid_group_sizes["dp"] + n_tokens_per_rank = 20 + n_tokens = size_ring * n_tokens_per_rank + max_atoms_per_token = 18 + n_atoms_per_rank = n_tokens_per_rank * max_atoms_per_token + n_atoms = size_ring * n_atoms_per_rank + + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + + pred_atom_coords = torch.randn( + (batch_size * multiplicity, n_atoms, 3), device=device_type, generator=rng, requires_grad=True + ) + # enforce collinearity by setting some atoms to zero + collinear_mask = torch.randint( + 0, 2, (batch_size * multiplicity, n_atoms), device=device_type, generator=rng, dtype=torch.bool + ) + pred_atom_coords = torch.where(collinear_mask.unsqueeze(-1), pred_atom_coords, torch.zeros_like(pred_atom_coords)) + + rng_features = torch.Generator(device=pred_atom_coords.device) + rng_features.manual_seed(seed) + feats = random_features( + size_batch=batch_size, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, max_atoms_per_token), + device=pred_atom_coords.device, + float_value_range=(-1.0, 1.0), + selected_keys=[ + "asym_id", + "atom_to_token", + "atom_pad_mask", + "atom_resolved_mask", + "frames_idx", + "atom_counts_per_token", + "mol_type", + "token_pad_mask", + ], + rng=rng_features, + ) + atom_pad_mask_bool = feats["atom_pad_mask"].bool() + frames_idx = feats["frames_idx"] + atom_pad_per_token = atom_pad_mask_bool.unsqueeze(1).expand(-1, n_tokens, -1) + frame_atom_valid = torch.gather(atom_pad_per_token, dim=2, index=frames_idx) + assert torch.all(frame_atom_valid), "random_features generated frames_idx pointing to masked atoms" + feats_serial = { + "asym_id": feats["asym_id"], + "atom_to_token": feats["atom_to_token"], + "atom_pad_mask": feats["atom_pad_mask"], + "atom_resolved_mask": feats["atom_resolved_mask"], + "mol_type": feats["mol_type"], + "token_pad_mask": feats["token_pad_mask"], + } + expected_frames_idx_host, expected_mask_collinear_host = serial_compute_frame_pred( + pred_atom_coords, + feats["frames_idx"], + feats_serial, + multiplicity, + ) + _assert_nontrivial_expected_mask_collinear( + expected_mask_collinear_host.detach().cpu(), + feats["token_pad_mask"].detach().cpu(), + "test_dtensor_compute_frame_pred", + ) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + pred_atom_coords.detach().clone().cpu(), + {k: v.detach().clone().cpu() for k, v in feats.items()}, + feats["atom_counts_per_token"].detach().clone().cpu(), + expected_frames_idx_host.detach().clone().cpu(), + expected_mask_collinear_host.detach().clone().cpu(), + ) + + spawn_multiprocessing( + parallel_assert_compute_frame_pred, + world_size, + payload, + ) + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((1, (2, 2)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2], ids=["multiplicity=1", "multiplicity=2"]) +@pytest.mark.parametrize("seed", [0, 42], ids=["seed=0", "seed=42"]) +def test_dtensor_compute_frame_pred_real_data_parallel( + setup_env: tuple, + multiplicity: int, + seed: int, + create_preprocessed_handle_boltz1_v1, +) -> None: + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + processed = create_preprocessed_handle_boltz1_v1 + record = processed.manifest.records[0] + input_data = load_input(record, processed.targets_dir, processed.msa_dir) + tokenized = BoltzTokenizer().tokenize(input_data) + n_atoms_raw, n_tokens_raw = get_num_atoms_tokens(tokenized) + + ring = grid_group_sizes["cp"][0] + atoms_per_window = 32 + atom_lcm = ring * atoms_per_window // gcd(ring, atoms_per_window) + max_atoms = ((n_atoms_raw + atom_lcm - 1) // atom_lcm) * atom_lcm + max_tokens = ((n_tokens_raw + ring - 1) // ring) * ring + max_seqs = ring + + feats_single = BoltzFeaturizer().process( + tokenized, + training=False, + max_atoms=max_atoms, + max_tokens=max_tokens, + max_seqs=max_seqs, + pad_to_max_seqs=True, + ) + if not isinstance(feats_single, dict): + raise TypeError("Expected non-sharded feature dict from BoltzFeaturizer.process") + + selected_keys = [ + "asym_id", + "atom_to_token", + "atom_pad_mask", + "atom_resolved_mask", + "frames_idx", + "mol_type", + "token_pad_mask", + ] + feats_single = {k: feats_single[k] for k in selected_keys} + # v1 featurizer doesn't emit atom_counts_per_token; derive from one-hot atom_to_token + feats_single["atom_counts_per_token"] = feats_single["atom_to_token"].sum(dim=0).to(torch.int64) + + batch_size = grid_group_sizes["dp"] + feats = {k: v.unsqueeze(0).repeat_interleave(batch_size, dim=0).to(device_type) for k, v in feats_single.items()} + + n_atoms = feats["atom_pad_mask"].shape[1] + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + pred_atom_coords = torch.randn((batch_size * multiplicity, n_atoms, 3), device=device_type, generator=rng) + collinear_mask = torch.randint( + 0, + 2, + (batch_size * multiplicity, n_atoms), + device=device_type, + generator=rng, + dtype=torch.bool, + ) + pred_atom_coords = torch.where(collinear_mask.unsqueeze(-1), pred_atom_coords, torch.zeros_like(pred_atom_coords)) + feats_serial = { + "asym_id": feats["asym_id"], + "atom_to_token": feats["atom_to_token"], + "atom_pad_mask": feats["atom_pad_mask"], + "atom_resolved_mask": feats["atom_resolved_mask"], + "mol_type": feats["mol_type"], + "token_pad_mask": feats["token_pad_mask"], + } + expected_frames_idx_host, expected_mask_collinear_host = serial_compute_frame_pred( + pred_atom_coords, + feats["frames_idx"], + feats_serial, + multiplicity, + ) + _assert_nontrivial_expected_mask_collinear( + expected_mask_collinear_host.detach().cpu(), + feats["token_pad_mask"].detach().cpu(), + "test_dtensor_compute_frame_pred_real_data_parallel", + ) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + pred_atom_coords.detach().clone().cpu(), + {k: v.detach().clone().cpu() for k, v in feats.items()}, + feats["atom_counts_per_token"].detach().clone().cpu(), + expected_frames_idx_host.detach().clone().cpu(), + expected_mask_collinear_host.detach().clone().cpu(), + ) + + spawn_multiprocessing( + parallel_assert_compute_frame_pred, + world_size, + payload, + ) + + +def parallel_assert_confidence_loss( + rank: int, + payload: tuple, +): + """Worker function that runs on each rank to test confidence_loss DTensor implementation. + + This function: + 1. Initializes the distributed environment + 2. Distributes atom and token features using appropriate utilities + 3. Calls the DTensor confidence_loss function + 4. Verifies forward pass matches serial reference + 5. Verifies gradients flow correctly through all logit tensors + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + alpha_pae, + pred_lddt_global_host, + pred_pde_global_host, + pred_pae_global_host, + pred_resolved_global_host, + pred_atom_coords_global_host, + true_atom_coords_global_host, + true_coords_resolved_mask_global_host, + token_to_rep_atom_global_host, + r_set_to_rep_atom_global_host, + atom_to_token_global_host, + mol_type_global_host, + token_pad_mask_global_host, + atom_counts_per_token_host, + pae_feats_host, + expected_loss_host, + expected_loss_breakdown_host, + expected_grad_pred_lddt_host, + expected_grad_pred_pde_host, + expected_grad_pred_pae_host, + expected_grad_pred_resolved_host, + ) = payload + + # Setup environment variables for this rank + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + dtype = pred_atom_coords_global_host.dtype + + # Create TransposeComm for redistribute_transpose + comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + # --- Distribute atom features using distribute_atom_features utility --- + size_batch = token_to_rep_atom_global_host.shape[0] + inputs_atom = { + "atom_counts_per_token": atom_counts_per_token_host.to(dtype=torch.int64), + "token_to_rep_atom": token_to_rep_atom_global_host.to(dtype=dtype), + "r_set_to_rep_atom": r_set_to_rep_atom_global_host.to(dtype=dtype), + "atom_to_token": atom_to_token_global_host.to(dtype=dtype), + } + if pae_feats_host is not None: + inputs_atom["frames_idx"] = pae_feats_host["frames_idx"] + inputs_atom["atom_pad_mask"] = pae_feats_host["atom_pad_mask"].to(dtype=dtype) + inputs_atom["atom_resolved_mask"] = pae_feats_host["atom_resolved_mask"].to(dtype=dtype) + + # Add per-multiplicity coordinates and masks + pred_coords_unflat = pred_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + true_coords_unflat = true_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + resolved_mask_unflat = true_coords_resolved_mask_global_host.unflatten(0, (size_batch, multiplicity)) + + for i_mul in range(multiplicity): + inputs_atom[f"pred_atom_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_atom_coords_{i_mul}"] = true_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_coords_resolved_mask_{i_mul}"] = resolved_mask_unflat[:, i_mul].to(dtype=dtype) + + # Define placements for CP submesh and full mesh + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "token_to_rep_atom": (Shard(0), Replicate()), + "r_set_to_rep_atom": (Shard(0), Replicate()), + "atom_to_token": (Shard(0), Replicate()), + } + placements_dp_cp = { + "token_to_rep_atom": (Shard(0), Shard(1), Replicate()), + "r_set_to_rep_atom": (Shard(0), Shard(1), Replicate()), + "atom_to_token": (Shard(0), Shard(1), Replicate()), + } + if pae_feats_host is not None: + placements_cp["frames_idx"] = (Shard(1), Replicate()) + placements_cp["atom_pad_mask"] = (Shard(0), Replicate()) + placements_cp["atom_resolved_mask"] = (Shard(0), Replicate()) + placements_dp_cp["frames_idx"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp["atom_pad_mask"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp["atom_resolved_mask"] = (Shard(0), Shard(1), Replicate()) + for i_mul in range(multiplicity): + placements_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + # Distribute atom features with intersperse padding + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={ + "pred_atom_coords": multiplicity, + "true_atom_coords": multiplicity, + "true_coords_resolved_mask": multiplicity, + }, + ) + + # --- Distribute token features using distribute_features --- + if manager.group_rank["world"] == 0: + token_features = { + "mol_type": mol_type_global_host.to(device=manager.device, dtype=torch.int64), + "token_pad_mask": token_pad_mask_global_host.to(device=manager.device, dtype=dtype), + "pred_lddt": pred_lddt_global_host.to(device=manager.device, dtype=torch.float32), + "pred_resolved": pred_resolved_global_host.to(device=manager.device, dtype=torch.float32), + } + if pae_feats_host is not None: + token_features["frame_resolved_mask"] = pae_feats_host["frame_resolved_mask"].to( + device=manager.device, dtype=dtype + ) + token_features["asym_id"] = pae_feats_host["asym_id"].to(device=manager.device) + token_features["is_nonpolymer_with_frame"] = pae_feats_host["is_nonpolymer_with_frame"].to( + device=manager.device + ) + else: + token_features = None + token_placements = { + "mol_type": (Shard(0), Shard(1), Replicate()), + "token_pad_mask": (Shard(0), Shard(1), Replicate()), + "pred_lddt": (Shard(0), Shard(1), Replicate()), + "pred_resolved": (Shard(0), Shard(1), Replicate()), + } + if pae_feats_host is not None: + token_placements["frame_resolved_mask"] = (Shard(0), Shard(1), Replicate()) + token_placements["asym_id"] = (Shard(0), Shard(1), Replicate()) + token_placements["is_nonpolymer_with_frame"] = (Shard(0), Shard(1), Replicate()) + token_feats_dtensor = distribute_features( + token_features, + token_placements, + manager.group["world"], + manager.group_ranks["world"][0], + device_mesh, + ) + + # --- Distribute pair representations (pred_pde, and pred_pae when alpha_pae > 0) --- + if manager.group_rank["world"] == 0: + pair_features = { + "pred_pde": pred_pde_global_host.to(device=manager.device, dtype=torch.float32), + } + if alpha_pae > 0.0: + pair_features["pred_pae"] = pred_pae_global_host.to(device=manager.device, dtype=torch.float32) + else: + pair_features = None + pair_placements = { + "pred_pde": (Shard(0), Shard(1), Shard(2)), + } + if alpha_pae > 0.0: + pair_placements["pred_pae"] = (Shard(0), Shard(1), Shard(2)) + pair_dtensor_dict = distribute_features( + pair_features, + pair_placements, + manager.group["world"], + manager.group_ranks["world"][0], + device_mesh, + ) + + # Extract distributed tensors + pred_atom_coords_dtensor = feats_atom["pred_atom_coords"] + true_atom_coords_dtensor = feats_atom["true_atom_coords"] + true_coords_resolved_mask_dtensor = feats_atom["true_coords_resolved_mask"] + + # Create feature dictionary for confidence_loss + feats_dtensor = { + "token_to_rep_atom": feats_atom["token_to_rep_atom"], + "r_set_to_rep_atom": feats_atom["r_set_to_rep_atom"], + "atom_to_token": feats_atom["atom_to_token"], + "mol_type": token_feats_dtensor["mol_type"], + "token_pad_mask": token_feats_dtensor["token_pad_mask"], + } + if pae_feats_host is not None: + feats_dtensor["frames_idx"] = feats_atom["frames_idx"] + feats_dtensor["atom_pad_mask"] = feats_atom["atom_pad_mask"] + feats_dtensor["atom_resolved_mask"] = feats_atom["atom_resolved_mask"] + feats_dtensor["frame_resolved_mask"] = token_feats_dtensor["frame_resolved_mask"] + feats_dtensor["asym_id"] = token_feats_dtensor["asym_id"] + feats_dtensor["is_nonpolymer_with_frame"] = token_feats_dtensor["is_nonpolymer_with_frame"] + + # Get model_out tensors with gradient tracking + pred_lddt_dtensor = token_feats_dtensor["pred_lddt"].detach().requires_grad_(True) + pred_pde_dtensor = pair_dtensor_dict["pred_pde"].detach().requires_grad_(True) + pred_resolved_dtensor = token_feats_dtensor["pred_resolved"].detach().requires_grad_(True) + + # Build model_out dictionary + model_out = { + "plddt_logits": pred_lddt_dtensor, + "pde_logits": pred_pde_dtensor, + "resolved_logits": pred_resolved_dtensor, + "sample_atom_coords": pred_atom_coords_dtensor, + } + + pred_pae_dtensor = None + if alpha_pae > 0.0: + pred_pae_dtensor = pair_dtensor_dict["pred_pae"].detach().requires_grad_(True) + model_out["pae_logits"] = pred_pae_dtensor + + # Compute confidence_loss + confidence_loss_kwargs = { + "model_out": model_out, + "feats": feats_dtensor, + "true_coords": true_atom_coords_dtensor, + "true_coords_resolved_mask": true_coords_resolved_mask_dtensor, + "comm": comm, + "multiplicity": multiplicity, + "alpha_pae": alpha_pae, + } + if alpha_pae > 0.0: + confidence_loss_kwargs["dist_manager"] = manager + confidence_loss_kwargs["group_layout"] = manager.layout_subgroups["cp"] + + result = confidence_loss(**confidence_loss_kwargs) + + # Verify output structure + assert "loss" in result, "Result must contain 'loss' key" + assert "loss_breakdown" in result, "Result must contain 'loss_breakdown' key" + assert "plddt_loss" in result["loss_breakdown"], "loss_breakdown must contain 'plddt_loss'" + assert "pde_loss" in result["loss_breakdown"], "loss_breakdown must contain 'pde_loss'" + assert "resolved_loss" in result["loss_breakdown"], "loss_breakdown must contain 'resolved_loss'" + assert "pae_loss" in result["loss_breakdown"], "loss_breakdown must contain 'pae_loss'" + + # Verify individual loss values and placements first for better diagnostics + expected_placements = (Replicate(), Replicate(), Replicate()) + for loss_name in ["plddt_loss", "pde_loss", "resolved_loss", "pae_loss"]: + subloss_dtensor = result["loss_breakdown"][loss_name] + assert ( + subloss_dtensor.placements == expected_placements + ), f"{loss_name} placements {subloss_dtensor.placements} != expected {expected_placements}" + loss_value = subloss_dtensor.to_local() + expected_value = expected_loss_breakdown_host[loss_name].to(device=loss_value.device, dtype=loss_value.dtype) + torch.testing.assert_close(loss_value, expected_value, msg=f"{loss_name} mismatch") + + # Verify total loss value + loss_local = result["loss"].to_local() + expected_loss = expected_loss_host.to(device=loss_local.device, dtype=loss_local.dtype) + torch.testing.assert_close(loss_local, expected_loss) + + # Verify total loss placements are fully replicated + loss_dtensor = result["loss"] + assert ( + loss_dtensor.placements == expected_placements + ), f"Loss placements {loss_dtensor.placements} != expected {expected_placements}" + + # Backward pass on DTensor directly + loss_dtensor.backward() + + # Verify gradients for pred_lddt + grad_pred_lddt = pred_lddt_dtensor.grad + assert grad_pred_lddt is not None, "Gradient not computed for pred_lddt" + grad_pred_lddt_full = grad_pred_lddt.full_tensor() + expected_grad_lddt = expected_grad_pred_lddt_host.to( + device=grad_pred_lddt_full.device, dtype=grad_pred_lddt_full.dtype + ) + torch.testing.assert_close(grad_pred_lddt_full, expected_grad_lddt) + + # Verify gradients for pred_pde + grad_pred_pde = pred_pde_dtensor.grad + assert grad_pred_pde is not None, "Gradient not computed for pred_pde" + grad_pred_pde_full = grad_pred_pde.full_tensor() + expected_grad_pde = expected_grad_pred_pde_host.to(device=grad_pred_pde_full.device, dtype=grad_pred_pde_full.dtype) + torch.testing.assert_close(grad_pred_pde_full, expected_grad_pde) + + # Verify gradients for pred_resolved + grad_pred_resolved = pred_resolved_dtensor.grad + assert grad_pred_resolved is not None, "Gradient not computed for pred_resolved" + grad_pred_resolved_full = grad_pred_resolved.full_tensor() + expected_grad_resolved = expected_grad_pred_resolved_host.to( + device=grad_pred_resolved_full.device, dtype=grad_pred_resolved_full.dtype + ) + torch.testing.assert_close(grad_pred_resolved_full, expected_grad_resolved) + + # Verify gradients for pred_pae (when alpha_pae > 0) + if alpha_pae > 0.0: + grad_pred_pae = pred_pae_dtensor.grad + assert grad_pred_pae is not None, "Gradient not computed for pred_pae" + grad_pred_pae_full = grad_pred_pae.full_tensor() + expected_grad_pae = expected_grad_pred_pae_host.to( + device=grad_pred_pae_full.device, dtype=grad_pred_pae_full.dtype + ) + torch.testing.assert_close(grad_pred_pae_full, expected_grad_pae) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("alpha_pae", [0.0, 0.5], ids=lambda x: f"alpha_pae:{x}") +@pytest.mark.parametrize("multiplicity", [1, 2], ids=lambda x: f"multiplicity:{x}") +def test_confidence_loss(setup_env, alpha_pae: float, multiplicity: int): + """Test that DTensor confidence_loss matches serial reference. + + This test verifies: + 1. Forward pass: DTensor confidence_loss matches serial confidence_loss + 2. Backward pass: Gradients w.r.t. pred_lddt, pred_pde, pred_resolved (and pred_pae + when alpha_pae > 0) match serial gradients + 3. Output structure contains plddt_loss, pde_loss, resolved_loss, pae_loss + + The test uses realistic feature generation with proper block-diagonal structure. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + torch.manual_seed(42) + + dp_size = grid_group_sizes["dp"] + cp_size = grid_group_sizes["cp"][0] * grid_group_sizes["cp"][1] + + # Generate test data + B = dp_size # Batch size equals DP size + N_token = 32 + N_atom = 140 # Large enough to accommodate token atom counts + n_atoms_per_token_min = 1 + n_atoms_per_token_max = 4 + num_bins_lddt = 50 # Number of pLDDT bins + num_bins_pde = 64 # Number of PDE bins + dtype = torch.float32 # Use FP32 for testing + + # Make N_atom divisible by cp_size for even sharding + N_atom = ((N_atom + cp_size - 1) // cp_size) * cp_size + + # Use random_features to generate features with proper block-diagonal structure + selected_keys = [ + "token_to_rep_atom", + "r_set_to_rep_atom", + "atom_to_token", + "mol_type", + "token_pad_mask", + "atom_counts_per_token", + ] + if alpha_pae > 0.0: + selected_keys += [ + "frames_idx", + "frame_resolved_mask", + "asym_id", + "atom_pad_mask", + "atom_resolved_mask", + "is_nonpolymer_with_frame", + ] + + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=(-0.5, 0.5), + selected_keys=selected_keys, + ) + + if alpha_pae > 0.0: + feats["atom_resolved_mask"] = torch.ones_like(feats["atom_resolved_mask"]) + feats["frame_resolved_mask"] = torch.ones_like(feats["frame_resolved_mask"]) + + token_to_rep_atom_global = feats["token_to_rep_atom"].to(dtype=dtype) # [B, N_token, N_atom] + r_set_to_rep_atom_global = feats["r_set_to_rep_atom"].to(dtype=dtype) # [B, N_R, N_atom] + atom_to_token_global = feats["atom_to_token"].to(dtype=dtype) # [B, N_atom, N_token] + mol_type_global = feats["mol_type"] # [B, N_token] + token_pad_mask_global = feats["token_pad_mask"].to(dtype=dtype) # [B, N_token] + atom_counts_per_token = feats["atom_counts_per_token"] + + N_atom_actual = token_to_rep_atom_global.shape[2] + + # Generate coordinates using uniform distribution + # Use [-10, 10] range for meaningful pairwise distances + pred_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + true_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + + # Create true_coords_resolved_mask: [B*mult, N_atom] + true_coords_resolved_mask_global = torch.randint( + 0, 2, (B * multiplicity, N_atom_actual), device=device_type, dtype=dtype + ) + # Ensure at least half atoms are resolved for meaningful test + true_coords_resolved_mask_global[:, : N_atom_actual // 2] = 1.0 + + # Generate pred_lddt: (B*mult, N_token, num_bins_lddt) + pred_lddt_global = torch.randn( + B * multiplicity, N_token, num_bins_lddt, device=device_type, dtype=torch.float32 + ).requires_grad_(True) + + # Generate pred_pde: (B*mult, N_token, N_token, num_bins_pde) + pred_pde_global = torch.randn( + B * multiplicity, N_token, N_token, num_bins_pde, device=device_type, dtype=torch.float32 + ).requires_grad_(True) + + # Generate pred_resolved: (B*mult, N_token, 2) + pred_resolved_global = torch.randn( + B * multiplicity, N_token, 2, device=device_type, dtype=torch.float32 + ).requires_grad_(True) + + # Generate pred_pae when needed: (B*mult, N_token, N_token, num_bins_pde) + num_bins_pae = 64 + pred_pae_global = None + if alpha_pae > 0.0: + pred_pae_global = torch.randn( + B * multiplicity, N_token, N_token, num_bins_pae, device=device_type, dtype=torch.float32 + ).requires_grad_(True) + + # Build model_out for serial reference + model_out_serial = { + "plddt_logits": pred_lddt_global, + "pde_logits": pred_pde_global, + "resolved_logits": pred_resolved_global, + "sample_atom_coords": pred_atom_coords_global.float(), + } + if alpha_pae > 0.0: + model_out_serial["pae_logits"] = pred_pae_global + + # Build feats for serial reference + feats_serial = { + "token_to_rep_atom": token_to_rep_atom_global.clone().float(), + "r_set_to_rep_atom": r_set_to_rep_atom_global.clone().float(), + "atom_to_token": atom_to_token_global.clone().float(), + "mol_type": mol_type_global.clone(), + "token_pad_mask": token_pad_mask_global.clone().float(), + } + if alpha_pae > 0.0: + feats_serial["frames_idx"] = feats["frames_idx"].clone() + feats_serial["frame_resolved_mask"] = feats["frame_resolved_mask"].clone().float() + feats_serial["asym_id"] = feats["asym_id"].clone() + feats_serial["atom_pad_mask"] = feats["atom_pad_mask"].clone().float() + feats_serial["atom_resolved_mask"] = feats["atom_resolved_mask"].clone().float() + feats_serial["is_nonpolymer_with_frame"] = feats["is_nonpolymer_with_frame"].clone() + + # Compute serial reference + expected_result = serial_confidence_loss( + model_out=model_out_serial, + feats=feats_serial, + true_coords=true_atom_coords_global.float(), + true_coords_resolved_mask=true_coords_resolved_mask_global.float(), + token_level_confidence=True, + multiplicity=multiplicity, + alpha_pae=alpha_pae, + ) + + expected_loss = expected_result["loss"] + expected_loss_breakdown = { + "plddt_loss": expected_result["loss_breakdown"]["plddt_loss"], + "pde_loss": expected_result["loss_breakdown"]["pde_loss"], + "resolved_loss": expected_result["loss_breakdown"]["resolved_loss"], + "pae_loss": torch.tensor(expected_result["loss_breakdown"]["pae_loss"], dtype=dtype), + } + + # Compute gradients for serial reference + expected_loss.backward() + expected_grad_pred_lddt = pred_lddt_global.grad.clone() + expected_grad_pred_pde = pred_pde_global.grad.clone() + expected_grad_pred_resolved = pred_resolved_global.grad.clone() + expected_grad_pred_pae = None + if alpha_pae > 0.0: + expected_grad_pred_pae = pred_pae_global.grad.clone() + assert expected_grad_pred_pae is not None, "Serial confidence_loss should produce gradients for pred_pae" + + # Verify that serial reference produces gradients + assert expected_grad_pred_lddt is not None, "Serial confidence_loss should produce gradients for pred_lddt" + assert expected_grad_pred_pde is not None, "Serial confidence_loss should produce gradients for pred_pde" + assert expected_grad_pred_resolved is not None, "Serial confidence_loss should produce gradients for pred_resolved" + + # Collect PAE-specific features for the parallel worker + pae_feats = None + if alpha_pae > 0.0: + pae_feats = { + "frames_idx": feats["frames_idx"].clone().cpu(), + "frame_resolved_mask": feats["frame_resolved_mask"].clone().cpu(), + "asym_id": feats["asym_id"].clone().cpu(), + "atom_pad_mask": feats["atom_pad_mask"].clone().cpu(), + "atom_resolved_mask": feats["atom_resolved_mask"].clone().cpu(), + "is_nonpolymer_with_frame": feats["is_nonpolymer_with_frame"].clone().cpu(), + } + + # Prepare payload for parallel test + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + alpha_pae, + pred_lddt_global.detach().clone().cpu(), + pred_pde_global.detach().clone().cpu(), + pred_pae_global.detach().clone().cpu() if pred_pae_global is not None else None, + pred_resolved_global.detach().clone().cpu(), + pred_atom_coords_global.clone().cpu(), + true_atom_coords_global.clone().cpu(), + true_coords_resolved_mask_global.clone().cpu(), + token_to_rep_atom_global.clone().cpu(), + r_set_to_rep_atom_global.clone().cpu(), + atom_to_token_global.clone().cpu(), + mol_type_global.clone().cpu(), + token_pad_mask_global.clone().cpu(), + atom_counts_per_token.clone().cpu(), + pae_feats, + expected_loss.detach().clone().cpu(), + {k: v.detach().clone().cpu() if isinstance(v, torch.Tensor) else v for k, v in expected_loss_breakdown.items()}, + expected_grad_pred_lddt.clone().cpu(), + expected_grad_pred_pde.clone().cpu(), + expected_grad_pred_pae.clone().cpu() if expected_grad_pred_pae is not None else None, + expected_grad_pred_resolved.clone().cpu(), + ) + + # Launch parallel test + spawn_multiprocessing(parallel_assert_confidence_loss, world_size, payload) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/distributed/model/loss/test_dtensor_confidence_pde_loss.py b/tests/distributed/model/loss/test_dtensor_confidence_pde_loss.py new file mode 100644 index 000000000..d2559d68c --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_confidence_pde_loss.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Tests for DTensor pde_loss computation. + +This module tests: +1. Whether DTensor pde_loss matches serial pde_loss +2. Whether gradients flow correctly through pred_pde + +The pde_loss computation: + token_to_rep_atom = feats["token_to_rep_atom"] + token_mask = bmm(token_to_rep_atom, resolved_mask) + mask = token_mask[:,:,None] * token_mask[:,None,:] + + true_token_coords = bmm(token_to_rep_atom, true_atom_coords) + pred_token_coords = bmm(token_to_rep_atom, pred_atom_coords) + + true_d = cdist(true_token_coords, true_token_coords) + pred_d = cdist(pred_token_coords, pred_token_coords) + target_pde = abs(true_d - pred_d) + + bin_index = clamp(floor(target_pde * num_bins / max_dist), max=num_bins-1) + errors = cross_entropy(pred_pde, bin_index) + loss = sum(errors * mask) / sum(mask) +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.data.utils import distribute_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.confidencev2 import pde_loss +from boltz.model.loss.confidencev2 import pde_loss as serial_pde_loss +from boltz.testing.utils import ( + distribute_atom_features, + random_features, + spawn_multiprocessing, +) + + +def parallel_assert_pde_loss( + rank: int, + payload: tuple, +): + """Worker function that runs on each rank to test pde_loss DTensor implementation. + + Uses the same setup pattern as parallel_assert_plddt_loss. + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_pde_global_host, + pred_atom_coords_global_host, + true_atom_coords_global_host, + true_coords_resolved_mask_global_host, + token_to_rep_atom_global_host, + atom_counts_per_token_host, + expected_loss_host, + expected_grad_pred_pde_host, + multiplicity, + ) = payload + + # Setup environment variables for this rank + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + dtype = pred_atom_coords_global_host.dtype + + # Create TransposeComm for redistribute_transpose + comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + # --- Distribute atom features using distribute_atom_features utility --- + size_batch = token_to_rep_atom_global_host.shape[0] + inputs_atom = { + "atom_counts_per_token": atom_counts_per_token_host.to(dtype=torch.int64), + "token_to_rep_atom": token_to_rep_atom_global_host.to(dtype=dtype), + } + + # Add per-multiplicity coordinates and masks + pred_coords_unflat = pred_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + true_coords_unflat = true_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + resolved_mask_unflat = true_coords_resolved_mask_global_host.unflatten(0, (size_batch, multiplicity)) + + for i_mul in range(multiplicity): + inputs_atom[f"pred_atom_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_atom_coords_{i_mul}"] = true_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_coords_resolved_mask_{i_mul}"] = resolved_mask_unflat[:, i_mul].to(dtype=dtype) + + # Define placements for CP submesh and full mesh + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "token_to_rep_atom": (Shard(0), Replicate()), + } + placements_dp_cp = { + "token_to_rep_atom": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(multiplicity): + placements_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + # Distribute atom features with intersperse padding + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={ + "pred_atom_coords": multiplicity, + "true_atom_coords": multiplicity, + "true_coords_resolved_mask": multiplicity, + }, + ) + + # --- Distribute pred_pde (pair representation) --- + # pred_pde: [B*mult, N_token, N_token, num_bins] with placements (S(0), S(1), S(2)) + if manager.group_rank["world"] == 0: + pred_pde_features = { + "pred_pde": pred_pde_global_host.to(device=manager.device, dtype=torch.float32), + } + else: + pred_pde_features = None + pred_pde_placements = { + "pred_pde": (Shard(0), Shard(1), Shard(2)), + } + pred_pde_dtensor_dict = distribute_features( + pred_pde_features, + pred_pde_placements, + manager.group["world"], + manager.group_ranks["world"][0], + device_mesh, + ) + + # Extract distributed tensors + pred_atom_coords_dtensor = feats_atom["pred_atom_coords"] + true_atom_coords_dtensor = feats_atom["true_atom_coords"] + true_coords_resolved_mask_dtensor = feats_atom["true_coords_resolved_mask"] + + # Create feature dictionary + feats_dtensor = { + "token_to_rep_atom": feats_atom["token_to_rep_atom"], + } + + # Get pred_pde DTensor with gradient tracking + pred_pde_dtensor = pred_pde_dtensor_dict["pred_pde"] + pred_pde_dtensor_grad = pred_pde_dtensor.detach().requires_grad_(True) + + # Compute pde_loss + loss = pde_loss( + pred_pde=pred_pde_dtensor_grad, + pred_atom_coords=pred_atom_coords_dtensor, + true_atom_coords=true_atom_coords_dtensor, + true_coords_resolved_mask=true_coords_resolved_mask_dtensor, + feats=feats_dtensor, + comm=comm, + multiplicity=multiplicity, + ) + + # Verify loss value + loss_local = loss.to_local() + # Match dtype of DTensor output + expected_loss = expected_loss_host.to(device=loss_local.device, dtype=loss_local.dtype) + torch.testing.assert_close(loss_local, expected_loss) + + # Verify gradients + loss_local.backward() + grad_pred_pde = pred_pde_dtensor_grad.grad + + # Full gather the gradient to compare + grad_pred_pde_full = grad_pred_pde.full_tensor() + # Match dtype of DTensor gradient output + expected_grad = expected_grad_pred_pde_host.to(device=grad_pred_pde_full.device, dtype=grad_pred_pde_full.dtype) + torch.testing.assert_close(grad_pred_pde_full, expected_grad) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2], ids=lambda x: f"multiplicity:{x}") +def test_pde_loss(setup_env, multiplicity: int): + """Test that DTensor pde_loss matches serial reference. + + This test verifies: + 1. Forward pass: DTensor pde_loss matches serial pde_loss + 2. Backward pass: Gradients w.r.t. pred_pde match serial gradients + + The test uses realistic feature generation with proper block-diagonal structure. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + torch.manual_seed(42) + + dp_size = grid_group_sizes["dp"] + cp_size = grid_group_sizes["cp"][0] * grid_group_sizes["cp"][1] + + # Generate test data + B = dp_size # Batch size equals DP size + N_token = 32 + N_atom = 140 # Large enough to accommodate token atom counts + n_atoms_per_token_min = 1 + n_atoms_per_token_max = 4 + num_bins = 64 # Number of PDE bins + max_dist = 32.0 # Max distance for binning + dtype = torch.float32 # Use FP32 for testing + + # Make N_atom divisible by cp_size for even sharding + N_atom = ((N_atom + cp_size - 1) // cp_size) * cp_size + + # Use random_features to generate features with proper block-diagonal structure + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=(-0.1, 0.1), + selected_keys=[ + "token_to_rep_atom", + "atom_counts_per_token", + ], + ) + + token_to_rep_atom_global = feats["token_to_rep_atom"].to(dtype=dtype) # [B, N_token, N_atom] + atom_counts_per_token = feats["atom_counts_per_token"] + + N_atom_actual = token_to_rep_atom_global.shape[2] + + # Generate coordinates using uniform distribution + # Use [-10, 10] range so average pairwise distance is meaningful + pred_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + true_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + + # Create true_coords_resolved_mask: [B*mult, N_atom] + true_coords_resolved_mask_global = torch.randint( + 0, 2, (B * multiplicity, N_atom_actual), device=device_type, dtype=dtype + ) + # Ensure at least half atoms are resolved for meaningful test + true_coords_resolved_mask_global[:, : N_atom_actual // 2] = 1.0 + + # Generate pred_pde: (B*mult, N_token, N_token, num_bins) + pred_pde_global = torch.randn( + B * multiplicity, N_token, N_token, num_bins, device=device_type, dtype=torch.float32 + ).requires_grad_(True) + + # Compute serial reference (serial uses float32 internally) + feats_serial = { + "token_to_rep_atom": token_to_rep_atom_global.clone().float(), + } + + expected_loss, _rel_loss = serial_pde_loss( + pred_pde=pred_pde_global, + pred_atom_coords=pred_atom_coords_global.float(), + feats=feats_serial, + true_atom_coords=true_atom_coords_global.float(), + true_coords_resolved_mask=true_coords_resolved_mask_global.float(), + multiplicity=multiplicity, + max_dist=max_dist, + ) + + # Compute gradients for serial reference + expected_loss.backward() + expected_grad_pred_pde = pred_pde_global.grad.clone() + + # Verify that serial reference produces gradients (the gradient flows through log_softmax) + assert expected_grad_pred_pde is not None, "Serial pde_loss should produce gradients" + assert not torch.allclose( + expected_grad_pred_pde, torch.zeros_like(expected_grad_pred_pde) + ), "Gradients should be non-zero" + + # Prepare payload for parallel test + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_pde_global.detach().clone().cpu(), + pred_atom_coords_global.clone().cpu(), + true_atom_coords_global.clone().cpu(), + true_coords_resolved_mask_global.clone().cpu(), + token_to_rep_atom_global.clone().cpu(), + atom_counts_per_token.clone().cpu(), + expected_loss.detach().clone().cpu(), + expected_grad_pred_pde.clone().cpu(), + multiplicity, + ) + + # Launch parallel test + spawn_multiprocessing(parallel_assert_pde_loss, world_size, payload) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/distributed/model/loss/test_dtensor_confidence_plddt_loss.py b/tests/distributed/model/loss/test_dtensor_confidence_plddt_loss.py new file mode 100644 index 000000000..0d3777e73 --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_confidence_plddt_loss.py @@ -0,0 +1,928 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Tests for pair_mask factorization and DTensor lddt_resolved_token computation. + +This module tests: +1. Whether the pair_mask construction in serial plddt_loss can be factorized +2. Whether cdist_lddt with factorized masks matches torch.cdist + lddt_dist +3. Whether DTensor lddt_resolved_token() matches serial reference + +The pair_mask factorization: + pair_mask = atom_mask[:,:,None] & atom_mask[:,None,:] # outer product + pair_mask = pair_mask & ~eye # remove diagonal + pair_mask = einsum("bnm,bkm->bnk", pair_mask, r_set_to_rep_atom) + pair_mask = bmm(token_to_rep_atom, pair_mask) + +Can be factorized into: + mask_row = bmm(token_to_rep_atom, atom_mask) + mask_col = bmm(r_set_to_rep_atom, atom_mask) + factorized_mask = mask_row[:,:,None] & mask_col[:,None,:] & diagonal_mask + +Where diagonal_mask handles the atom-level self-pair exclusion. +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard + +from boltz.data import const +from boltz.distributed.comm import TransposeComm +from boltz.distributed.data.utils import distribute_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.confidencev2 import lddt_resolved_token, plddt_loss +from boltz.distributed.model.loss.triton.cdist_lddt import cdist_lddt +from boltz.model.loss.confidencev2 import lddt_dist +from boltz.model.loss.confidencev2 import plddt_loss as serial_plddt_loss +from boltz.testing.utils import ( + distribute_atom_features, + random_features, + spawn_multiprocessing, +) + + +def test_pair_mask_factorization(): + """Test pair_mask factorization with realistic features from random_features. + + This test checks if the factorized mask approach can reproduce + the serial pair_mask construction using proper block-diagonal + token_to_rep_atom and r_set_to_rep_atom matrices. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = torch.device("cuda") + torch.manual_seed(42) + dtype = torch.float32 # Use FP32 for testing + + # Test dimensions + B = 2 + N_token = 160 + n_atoms_per_token_min, n_atoms_per_token_max = 1, 3 + # N_atom should be >= N_token * max_atoms_per_token to fit all atoms + N_atom = N_token * (n_atoms_per_token_min + n_atoms_per_token_max) // 2 # 32 + + # Generate realistic features using random_features + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=device, + float_value_range=(-0.1, 0.1), + selected_keys=[ + "token_to_rep_atom", + "r_set_to_rep_atom", + ], + ) + + # Extract features and convert to boolean where appropriate + token_to_rep_atom = feats["token_to_rep_atom"].bool() # [B, N_token, N_atom] + r_set_to_rep_atom = feats["r_set_to_rep_atom"].bool() # [B, N_R, N_atom] + + N_atom_actual = token_to_rep_atom.shape[2] + + # Create random atom_mask [B, N_atom] - boolean + atom_mask = torch.randint(0, 2, (B, N_atom_actual), device=device, dtype=torch.bool) + + # ========== Serial pair_mask construction (using boolean ops) ========== + # Step 1: Outer product of atom_mask + pair_mask = atom_mask.unsqueeze(-1) & atom_mask.unsqueeze(-2) # [B, N_atom, N_atom] + + # Step 2: Remove diagonal (self-pairs in atom space) + diag_mask = ~torch.eye(N_atom_actual, device=device, dtype=torch.bool) + pair_mask = pair_mask & diag_mask[None, :, :] + + # Step 3: Project columns via r_set_to_rep_atom (need float for einsum) + pair_mask_float = pair_mask.to(dtype=dtype) + r_set_float = r_set_to_rep_atom.to(dtype=dtype) + pair_mask_float = torch.einsum("bnm,bkm->bnk", pair_mask_float, r_set_float) # [B, N_atom, N_R] + + # Step 4: Project rows via token_to_rep_atom + token_float = token_to_rep_atom.to(dtype=dtype) + pair_mask_serial = torch.bmm(token_float, pair_mask_float) # [B, N_token, N_R] + + # Convert back to boolean (values are 0 or 1 due to one-hot matrices) + pair_mask_serial = pair_mask_serial.bool() + + # ========== Factorized mask construction ========== + # mask_row[b, t] = True if rep_atom of token t is resolved + mask_row = torch.bmm(token_float, atom_mask.unsqueeze(-1).to(dtype=dtype)).squeeze(-1).bool() # [B, N_token] + + # mask_col[b, r] = True if rep_atom of R-set r is resolved + mask_col = torch.bmm(r_set_float, atom_mask.unsqueeze(-1).to(dtype=dtype)).squeeze(-1).bool() # [B, N_R] + + # Outer product of factorized masks + factorized_outer = mask_row.unsqueeze(-1) & mask_col.unsqueeze(-2) # [B, N_token, N_R] + + # Diagonal mask: exclude pairs where rep_atom_token == rep_atom_r_set + rep_atom_token = token_to_rep_atom.int().argmax(dim=-1) # [B, N_token] + rep_atom_r_set = r_set_to_rep_atom.int().argmax(dim=-1) # [B, N_R] + + # diagonal_mask[b, t, r] = True if rep_atom_token[b,t] != rep_atom_r_set[b,r] + diagonal_mask = rep_atom_token.unsqueeze(-1) != rep_atom_r_set.unsqueeze(-2) # [B, N_token, N_R] + + # Factorized mask with diagonal exclusion + pair_mask_factorized = factorized_outer & diagonal_mask # [B, N_token, N_R] + + # ========== Compare ========== + # The test: are they equal? + assert torch.equal(pair_mask_factorized, pair_mask_serial), "Factorized mask does NOT equal serial pair_mask" + + +@pytest.mark.parametrize("multiplicity", [1, 2], ids=lambda x: f"multiplicity:{x}") +def test_pair_mask_factorized_cdist_lddt(multiplicity): + """Test that cdist_lddt with factorized masks matches cdist + lddt_dist. + + This test verifies that: + 1. cdist_lddt forward output matches torch.cdist + lddt_dist + 2. Multiplicity is handled correctly for coordinates and masks + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = torch.device("cuda") + torch.manual_seed(42) + dtype = torch.float32 # Use FP32 for testing + + # Test dimensions + B = 2 + N_token = 32 + n_atoms_per_token_min, n_atoms_per_token_max = 1, 3 + N_atom = N_token * (n_atoms_per_token_min + n_atoms_per_token_max) // 2 + + # Generate realistic features using random_features + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=device, + float_value_range=(-0.1, 0.1), + selected_keys=[ + "token_to_rep_atom", + "r_set_to_rep_atom", + ], + ) + + token_to_rep_atom = feats["token_to_rep_atom"].to(dtype=dtype) # [B, N_token, N_atom] + r_set_to_rep_atom = feats["r_set_to_rep_atom"].to(dtype=dtype) # [B, N_R, N_atom] + + N_R = r_set_to_rep_atom.shape[1] + N_atom_actual = token_to_rep_atom.shape[2] + + # Create random atom_mask [B*mult, N_atom] - boolean (with multiplicity) + atom_mask = torch.randint(0, 2, (B * multiplicity, N_atom_actual), device=device, dtype=torch.bool) + # Ensure at least half atoms are resolved + atom_mask[:, : N_atom_actual // 2] = True + + # Create random coordinates for tokens (rows) and R-set (columns) with multiplicity + # Use uniform [-10, 10] to ensure meaningful distance distribution around cutoff + # Shape: [B*mult, N_token/N_R, 3] + pred_coords_row = torch.empty(B * multiplicity, N_token, 3, device=device, dtype=dtype).uniform_(-10.0, 10.0) + true_coords_row = torch.empty(B * multiplicity, N_token, 3, device=device, dtype=dtype).uniform_(-10.0, 10.0) + pred_coords_col = torch.empty(B * multiplicity, N_R, 3, device=device, dtype=dtype).uniform_(-10.0, 10.0) + true_coords_col = torch.empty(B * multiplicity, N_R, 3, device=device, dtype=dtype).uniform_(-10.0, 10.0) + + # Compute average distance to set cutoff at roughly the median distance + # This ensures meaningful coverage of pairs both within and outside cutoff + pred_d_test = torch.cdist(pred_coords_row, pred_coords_col) + true_d_test = torch.cdist(true_coords_row, true_coords_col) + avg_dist = (pred_d_test.mean().item() + true_d_test.mean().item()) / 2.0 + cutoff = avg_dist + + # ========== Serial approach: torch.cdist + lddt_dist with full pair_mask ========== + # Expand token_to_rep_atom and r_set_to_rep_atom for multiplicity + token_to_rep_atom_mult = token_to_rep_atom.repeat_interleave(multiplicity, dim=0) # [B*mult, N_token, N_atom] + r_set_to_rep_atom_mult = r_set_to_rep_atom.repeat_interleave(multiplicity, dim=0) # [B*mult, N_R, N_atom] + + # Construct pair_mask with multiplicity + atom_mask_float = atom_mask.to(dtype=dtype) + pair_mask = atom_mask_float.unsqueeze(-1) * atom_mask_float.unsqueeze(-2) # [B*mult, N_atom, N_atom] + diag_mask = 1.0 - torch.eye(N_atom_actual, device=device, dtype=dtype) + pair_mask = pair_mask * diag_mask[None, :, :] + pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, r_set_to_rep_atom_mult) # [B*mult, N_atom, N_R] + pair_mask_serial = torch.bmm(token_to_rep_atom_mult, pair_mask) # [B*mult, N_token, N_R] + + # Compute distances using torch.cdist + pred_d_serial = torch.cdist(pred_coords_row, pred_coords_col) # [B*mult, N_token, N_R] + true_d_serial = torch.cdist(true_coords_row, true_coords_col) # [B*mult, N_token, N_R] + + # Compute lddt using lddt_dist + lddt_serial, mask_no_match_serial = lddt_dist( + pred_d_serial, true_d_serial, pair_mask_serial, cutoff=cutoff, per_atom=True + ) + + # ========== Factorized approach: cdist_lddt ========== + # Factorized masks with multiplicity [B*mult, N_token] and [B*mult, N_R] + mask_row = torch.bmm(token_to_rep_atom_mult, atom_mask_float.unsqueeze(-1)).squeeze(-1) # [B*mult, N_token] + mask_col = torch.bmm(r_set_to_rep_atom_mult, atom_mask_float.unsqueeze(-1)).squeeze(-1) # [B*mult, N_R] + + # Representative atom indices [B, N_token] and [B, N_R] (no multiplicity) + rep_atom_token = token_to_rep_atom.argmax(dim=-1) # [B, N_token] + rep_atom_r_set = r_set_to_rep_atom.argmax(dim=-1) # [B, N_R] + + lddt_cdist, mask_no_match_cdist = cdist_lddt( + pred_coords_row=pred_coords_row, + pred_coords_col=pred_coords_col, + true_coords_row=true_coords_row, + true_coords_col=true_coords_col, + mask_row=mask_row, + mask_col=mask_col, + multiplicity=multiplicity, + atom_indices_row=rep_atom_token, + atom_indices_col=rep_atom_r_set, + cutoff=cutoff, + do_mask_diagonal=True, + per_atom=True, + ) + + # ========== Compare forward pass ========== + # cdist_lddt Triton kernel uses float32 internally for performance, + # so convert serial reference to match for comparison + torch.testing.assert_close(lddt_cdist, lddt_serial.to(lddt_cdist.dtype)) + torch.testing.assert_close(mask_no_match_cdist, mask_no_match_serial.to(mask_no_match_cdist.dtype)) + + +def parallel_assert_lddt_resolved_token( + rank: int, + payload: tuple, +): + """Parallel test function for DTensor lddt_resolved_token(). + + This function runs on each rank in the distributed setup and verifies that + the DTensor lddt_resolved_token implementation matches the serial reference. + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_atom_coords_global_host, + true_atom_coords_global_host, + true_coords_resolved_mask_global_host, + token_to_rep_atom_global_host, + r_set_to_rep_atom_global_host, + atom_to_token_global_host, + mol_type_global_host, + atom_counts_per_token_host, + expected_target_lddt_host, + expected_combined_mask_host, + multiplicity, + cutoff_value, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + dtype = pred_atom_coords_global_host.dtype + + # Create TransposeComm for redistribute_transpose + comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + # --- Distribute atom features using distribute_atom_features utility --- + size_batch = token_to_rep_atom_global_host.shape[0] + inputs_atom = { + "atom_counts_per_token": atom_counts_per_token_host.to(dtype=torch.int64), + "token_to_rep_atom": token_to_rep_atom_global_host.to(dtype=dtype), + "r_set_to_rep_atom": r_set_to_rep_atom_global_host.to(dtype=dtype), + "atom_to_token": atom_to_token_global_host.to(dtype=dtype), + } + + # Add per-multiplicity coordinates and masks + # Unflatten [B*mult, N_atom, 3] -> [B, mult, N_atom, 3] + pred_coords_unflat = pred_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + true_coords_unflat = true_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + resolved_mask_unflat = true_coords_resolved_mask_global_host.unflatten(0, (size_batch, multiplicity)) + + for i_mul in range(multiplicity): + inputs_atom[f"pred_atom_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_atom_coords_{i_mul}"] = true_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_coords_resolved_mask_{i_mul}"] = resolved_mask_unflat[:, i_mul].to(dtype=dtype) + + # Define placements for CP submesh and full mesh + # Note: distribute_atom_features expects Replicate() for second dim of atom features + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "token_to_rep_atom": (Shard(0), Replicate()), + "r_set_to_rep_atom": (Shard(0), Replicate()), + "atom_to_token": (Shard(0), Replicate()), + } + placements_dp_cp = { + "token_to_rep_atom": (Shard(0), Shard(1), Replicate()), + "r_set_to_rep_atom": (Shard(0), Shard(1), Replicate()), + "atom_to_token": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(multiplicity): + placements_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + # Distribute atom features with intersperse padding + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={ + "pred_atom_coords": multiplicity, + "true_atom_coords": multiplicity, + "true_coords_resolved_mask": multiplicity, + }, + ) + + # --- Distribute token features (mol_type) using distribute_features --- + # mol_type is a token feature [B, N_token], not an atom feature + # Only rank 0 in the world group provides the features, others pass None + if manager.group_rank["world"] == 0: + token_features = { + "mol_type": mol_type_global_host.to(device=manager.device, dtype=torch.int64), + } + else: + token_features = None + token_placements = { + "mol_type": (Shard(0), Shard(1), Replicate()), + } + token_feats_dtensor = distribute_features( + token_features, + token_placements, + manager.group["world"], + manager.group_ranks["world"][0], + device_mesh, + ) + + # Extract distributed tensors + pred_atom_coords_dtensor = feats_atom["pred_atom_coords"] + true_atom_coords_dtensor = feats_atom["true_atom_coords"] + true_coords_resolved_mask_dtensor = feats_atom["true_coords_resolved_mask"] + + # Create feature dictionary + feats_dtensor = { + "token_to_rep_atom": feats_atom["token_to_rep_atom"], + "r_set_to_rep_atom": feats_atom["r_set_to_rep_atom"], + "atom_to_token": feats_atom["atom_to_token"], + "mol_type": token_feats_dtensor["mol_type"], + } + + # Call distributed lddt_resolved_token() + # Returns (target_lddt, combined_mask) where combined_mask = token_resolved_mask * mask_no_match + target_lddt_dtensor, combined_mask_dtensor = lddt_resolved_token( + pred_atom_coords_dtensor, + true_atom_coords_dtensor, + true_coords_resolved_mask_dtensor, + feats_dtensor, + comm, + multiplicity=multiplicity, + cutoff=cutoff_value, + ) + + # Verify against serial reference + target_lddt_global = target_lddt_dtensor.full_tensor().cpu() + # Match dtype of DTensor output (preserves input coordinate dtype) + expected_target_lddt_global = expected_target_lddt_host.to(dtype=target_lddt_global.dtype) + + torch.testing.assert_close( + target_lddt_global, + expected_target_lddt_global, + ) + + combined_mask_global = combined_mask_dtensor.full_tensor().cpu() + # Match dtype of DTensor output (inherits from coordinate dtype) + expected_combined_mask_global = expected_combined_mask_host.to(dtype=combined_mask_global.dtype) + + torch.testing.assert_close( + combined_mask_global, + expected_combined_mask_global, + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2], ids=lambda x: f"multiplicity:{x}") +def test_lddt_resolved_token(setup_env, multiplicity): + """Test DTensor lddt_resolved_token() implementation against serial reference. + + This test verifies that the distributed lddt_resolved_token computation matches the + serial implementation by: + 1. Generating realistic features using random_features() + 2. Computing serial reference using exact code from plddt_loss + 3. Sharding inputs using distribute_atom_features() + 4. Calling distributed lddt_resolved_token() function + 5. Verifying output matches serial reference + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + torch.manual_seed(seed) + + # Test dimensions + dp_size = grid_group_sizes["dp"] + cp_size = grid_group_sizes["cp"][0] + + # Batch size must equal DP size for per-sample processing in pad_and_scatter + B = dp_size + n_tokens_per_shard = 16 + N_token = n_tokens_per_shard * cp_size + # Atoms: use variable atoms per token (1-3 atoms per token) + n_atoms_per_token_min, n_atoms_per_token_max = 1, 3 + # Estimate total atoms: avg 2 atoms/token * N_token, rounded to be divisible by cp_size + avg_atoms_per_token = (n_atoms_per_token_min + n_atoms_per_token_max) / 2 + N_atom = int(avg_atoms_per_token * N_token) + # Make N_atom divisible by cp_size for even sharding + N_atom = ((N_atom + cp_size - 1) // cp_size) * cp_size + dtype = torch.float32 # Use FP32 for testing + + # Use random_features to generate features with proper block-diagonal structure + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=(-0.1, 0.1), + selected_keys=[ + "token_to_rep_atom", + "r_set_to_rep_atom", + "atom_to_token", + "mol_type", + "atom_counts_per_token", + ], + ) + + token_to_rep_atom_global = feats["token_to_rep_atom"].to(dtype=dtype) # [B, N_token, N_atom] + r_set_to_rep_atom_global = feats["r_set_to_rep_atom"].to(dtype=dtype) # [B, N_R, N_atom] + atom_to_token_global = feats["atom_to_token"].to(dtype=dtype) # [B, N_atom, N_token] + mol_type_global = feats["mol_type"] # [B, N_token] + atom_counts_per_token = feats["atom_counts_per_token"] + + N_atom_actual = token_to_rep_atom_global.shape[2] + + # Generate coordinates using uniform distribution + # Use [-10, 10] range to ensure meaningful distance distribution around cutoff + # The cutoff will be computed as average distance to ensure coverage above and below + pred_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + true_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + + # Create true_coords_resolved_mask: [B*mult, N_atom] + true_coords_resolved_mask_global = torch.randint( + 0, 2, (B * multiplicity, N_atom_actual), device=device_type, dtype=dtype + ) + # Ensure at least half atoms are resolved for meaningful test + true_coords_resolved_mask_global[:, : N_atom_actual // 2] = 1.0 + + # === SERIAL REFERENCE (exact copy from plddt_loss lines 178-231) === + atom_mask = true_coords_resolved_mask_global + + R_set_to_rep_atom = r_set_to_rep_atom_global.repeat_interleave(multiplicity, 0).to(dtype=dtype) + + token_type = mol_type_global.repeat_interleave(multiplicity, 0) + is_nucleotide_token = (token_type == const.chain_type_ids["DNA"]).to(dtype=dtype) + ( + token_type == const.chain_type_ids["RNA"] + ).to(dtype=dtype) + + atom_to_token = atom_to_token_global.to(dtype=dtype).repeat_interleave(multiplicity, 0) + token_to_rep_atom = token_to_rep_atom_global.to(dtype=dtype).repeat_interleave(multiplicity, 0) + + true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords_global) + pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords_global) + + true_d = torch.cdist( + true_token_coords, + torch.bmm(R_set_to_rep_atom, true_atom_coords_global), + ) + pred_d = torch.cdist( + pred_token_coords, + torch.bmm(R_set_to_rep_atom, pred_atom_coords_global), + ) + + # pair_mask construction + pair_mask = atom_mask.unsqueeze(-1) * atom_mask.unsqueeze(-2) # [B, N_atom, N_atom] + pair_mask = pair_mask * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :] + pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, R_set_to_rep_atom) # [B, N_atom, N_R] + pair_mask = torch.bmm(token_to_rep_atom, pair_mask) # [B, N_token, N_R] + + is_nucleotide_R_element = torch.bmm( + R_set_to_rep_atom, torch.bmm(atom_to_token, is_nucleotide_token.unsqueeze(-1)) + ).squeeze(-1) + + # Compute average inter-atom distance for cutoff to ensure coverage below and above cutoff + # This gives better numerical stability in the lDDT computation + avg_pred_dist = pred_d.mean().item() + avg_true_dist = true_d.mean().item() + cutoff_value = (avg_pred_dist + avg_true_dist) / 2.0 + + cutoff = cutoff_value + cutoff_value * is_nucleotide_R_element.reshape(B * multiplicity, 1, -1).repeat( + 1, true_d.shape[1], 1 + ) + + # lddt_dist (per_atom=True) + expected_target_lddt, expected_mask_no_match = lddt_dist(pred_d, true_d, pair_mask, cutoff, per_atom=True) + + # Compute token_resolved_mask (whether each token has a resolved representative atom) + # This matches the computation in lddt_resolved_token + token_resolved_mask = torch.bmm( + token_to_rep_atom, atom_mask.unsqueeze(-1).to(dtype=token_to_rep_atom.dtype) + ).squeeze(-1) + expected_combined_mask = token_resolved_mask * expected_mask_no_match + + # Verify that lddt_dist does not support gradient computation. + # The lDDT metric uses step functions (thresholding at 0.5, 1, 2, 4 Å) which break the + # autograd graph. Even with requires_grad=True on inputs, backward() raises RuntimeError. + pred_d_grad = pred_d.detach().clone().requires_grad_(True) + true_d_grad = true_d.detach().clone().requires_grad_(True) + lddt_out, _ = lddt_dist(pred_d_grad, true_d_grad, pair_mask, cutoff, per_atom=True) + with pytest.raises(RuntimeError, match="does not require grad and does not have a grad_fn"): + lddt_out.sum().backward() + + # Prepare payload for parallel test + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_atom_coords_global.clone().cpu(), + true_atom_coords_global.clone().cpu(), + true_coords_resolved_mask_global.clone().cpu(), + token_to_rep_atom_global.clone().cpu(), + r_set_to_rep_atom_global.clone().cpu(), + atom_to_token_global.clone().cpu(), + mol_type_global.clone().cpu(), + atom_counts_per_token.clone().cpu(), + expected_target_lddt.detach().clone().cpu(), + expected_combined_mask.detach().clone().cpu(), + multiplicity, + cutoff_value, + ) + + # Launch parallel test + spawn_multiprocessing(parallel_assert_lddt_resolved_token, world_size, payload) + + +def parallel_assert_plddt_loss( + rank: int, + payload: tuple, +): + """Worker function that runs on each rank to test plddt_loss DTensor implementation. + + Uses the same setup pattern as parallel_assert_lddt_resolved_token. + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_lddt_global_host, + pred_atom_coords_global_host, + true_atom_coords_global_host, + true_coords_resolved_mask_global_host, + token_to_rep_atom_global_host, + r_set_to_rep_atom_global_host, + atom_to_token_global_host, + mol_type_global_host, + atom_counts_per_token_host, + expected_loss_host, + expected_grad_pred_lddt_host, + multiplicity, + ) = payload + + # Setup environment variables for this rank (same as test_lddt_resolved_token) + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + dtype = pred_atom_coords_global_host.dtype + + # Create TransposeComm for redistribute_transpose + comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + # --- Distribute atom features using distribute_atom_features utility --- + # (same pattern as parallel_assert_lddt_resolved_token) + size_batch = token_to_rep_atom_global_host.shape[0] + inputs_atom = { + "atom_counts_per_token": atom_counts_per_token_host.to(dtype=torch.int64), + "token_to_rep_atom": token_to_rep_atom_global_host.to(dtype=dtype), + "r_set_to_rep_atom": r_set_to_rep_atom_global_host.to(dtype=dtype), + "atom_to_token": atom_to_token_global_host.to(dtype=dtype), + } + + # Add per-multiplicity coordinates and masks + pred_coords_unflat = pred_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + true_coords_unflat = true_atom_coords_global_host.unflatten(0, (size_batch, multiplicity)) + resolved_mask_unflat = true_coords_resolved_mask_global_host.unflatten(0, (size_batch, multiplicity)) + + for i_mul in range(multiplicity): + inputs_atom[f"pred_atom_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_atom_coords_{i_mul}"] = true_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_coords_resolved_mask_{i_mul}"] = resolved_mask_unflat[:, i_mul].to(dtype=dtype) + + # Define placements for CP submesh and full mesh + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "token_to_rep_atom": (Shard(0), Replicate()), + "r_set_to_rep_atom": (Shard(0), Replicate()), + "atom_to_token": (Shard(0), Replicate()), + } + placements_dp_cp = { + "token_to_rep_atom": (Shard(0), Shard(1), Replicate()), + "r_set_to_rep_atom": (Shard(0), Shard(1), Replicate()), + "atom_to_token": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(multiplicity): + placements_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + # Distribute atom features with intersperse padding + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={ + "pred_atom_coords": multiplicity, + "true_atom_coords": multiplicity, + "true_coords_resolved_mask": multiplicity, + }, + ) + + # --- Distribute token features (mol_type and pred_lddt) using distribute_features --- + if manager.group_rank["world"] == 0: + token_features = { + "mol_type": mol_type_global_host.to(device=manager.device, dtype=torch.int64), + "pred_lddt": pred_lddt_global_host.to(device=manager.device, dtype=torch.float32), + } + else: + token_features = None + token_placements = { + "mol_type": (Shard(0), Shard(1), Replicate()), + "pred_lddt": (Shard(0), Shard(1), Replicate()), + } + token_feats_dtensor = distribute_features( + token_features, + token_placements, + manager.group["world"], + manager.group_ranks["world"][0], + device_mesh, + ) + + # Extract distributed tensors + pred_atom_coords_dtensor = feats_atom["pred_atom_coords"] + true_atom_coords_dtensor = feats_atom["true_atom_coords"] + true_coords_resolved_mask_dtensor = feats_atom["true_coords_resolved_mask"] + + # Create feature dictionary + feats_dtensor = { + "token_to_rep_atom": feats_atom["token_to_rep_atom"], + "r_set_to_rep_atom": feats_atom["r_set_to_rep_atom"], + "atom_to_token": feats_atom["atom_to_token"], + "mol_type": token_feats_dtensor["mol_type"], + } + + # Get pred_lddt DTensor with gradient tracking + pred_lddt_dtensor = token_feats_dtensor["pred_lddt"] + pred_lddt_dtensor_grad = pred_lddt_dtensor.detach().requires_grad_(True) + + # Compute plddt_loss + loss = plddt_loss( + pred_lddt=pred_lddt_dtensor_grad, + pred_atom_coords=pred_atom_coords_dtensor, + true_atom_coords=true_atom_coords_dtensor, + true_coords_resolved_mask=true_coords_resolved_mask_dtensor, + feats=feats_dtensor, + comm=comm, + multiplicity=multiplicity, + ) + + # Verify loss value + loss_local = loss.to_local() + # Match dtype of DTensor output (may inherit from coordinate dtype) + expected_loss = expected_loss_host.to(device=loss_local.device, dtype=loss_local.dtype) + torch.testing.assert_close(loss_local, expected_loss) + + # Verify gradients + loss_local.backward() + grad_pred_lddt = pred_lddt_dtensor_grad.grad + + # Full gather the gradient to compare + grad_pred_lddt_full = grad_pred_lddt.full_tensor() + # Match dtype of DTensor gradient output + expected_grad = expected_grad_pred_lddt_host.to(device=grad_pred_lddt_full.device, dtype=grad_pred_lddt_full.dtype) + torch.testing.assert_close(grad_pred_lddt_full, expected_grad) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2], ids=lambda x: f"multiplicity:{x}") +def test_plddt_loss(setup_env, multiplicity: int): + """Test that DTensor plddt_loss matches serial reference. + + This test verifies: + 1. Forward pass: DTensor plddt_loss matches serial plddt_loss + 2. Backward pass: Gradients w.r.t. pred_lddt match serial gradients + + The test uses the same feature generation as test_lddt_resolved_token. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + torch.manual_seed(42) + + dp_size = grid_group_sizes["dp"] + cp_size = grid_group_sizes["cp"][0] * grid_group_sizes["cp"][1] + + # Generate test data + B = dp_size # Batch size equals DP size + N_token = 32 + N_atom = 140 # Large enough to accommodate token atom counts + n_atoms_per_token_min = 1 + n_atoms_per_token_max = 4 + num_bins = 50 # Number of pLDDT bins + dtype = torch.float32 # Use FP32 for testing + + # Make N_atom divisible by cp_size for even sharding + N_atom = ((N_atom + cp_size - 1) // cp_size) * cp_size + + # Use random_features to generate features with proper block-diagonal structure + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=(-0.1, 0.1), + selected_keys=[ + "token_to_rep_atom", + "r_set_to_rep_atom", + "atom_to_token", + "mol_type", + "atom_counts_per_token", + ], + ) + + token_to_rep_atom_global = feats["token_to_rep_atom"].to(dtype=dtype) # [B, N_token, N_atom] + r_set_to_rep_atom_global = feats["r_set_to_rep_atom"].to(dtype=dtype) # [B, N_R, N_atom] + atom_to_token_global = feats["atom_to_token"].to(dtype=dtype) # [B, N_atom, N_token] + mol_type_global = feats["mol_type"] # [B, N_token] + atom_counts_per_token = feats["atom_counts_per_token"] + + N_atom_actual = token_to_rep_atom_global.shape[2] + + # Generate coordinates using uniform distribution + # Use [-10, 10] range so average pairwise distance is ~15 (matching cutoff in serial plddt_loss) + # This ensures meaningful coverage of pairs both within and outside cutoff + pred_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + true_atom_coords_global = torch.empty(B * multiplicity, N_atom_actual, 3, device=device_type, dtype=dtype).uniform_( + -10.0, 10.0 + ) + + # Create true_coords_resolved_mask: [B*mult, N_atom] + true_coords_resolved_mask_global = torch.randint( + 0, 2, (B * multiplicity, N_atom_actual), device=device_type, dtype=dtype + ) + # Ensure at least half atoms are resolved for meaningful test + true_coords_resolved_mask_global[:, : N_atom_actual // 2] = 1.0 + + # Generate pred_lddt: (B*mult, N_token, num_bins) + pred_lddt_global = torch.randn( + B * multiplicity, N_token, num_bins, device=device_type, dtype=torch.float32 + ).requires_grad_(True) + + # Compute serial reference (serial uses float32 internally) + feats_serial = { + "token_to_rep_atom": token_to_rep_atom_global.clone().float(), + "r_set_to_rep_atom": r_set_to_rep_atom_global.clone().float(), + "atom_to_token": atom_to_token_global.clone().float(), + "mol_type": mol_type_global.clone(), + } + + expected_loss, _rel_loss = serial_plddt_loss( + pred_lddt=pred_lddt_global, + pred_atom_coords=pred_atom_coords_global.float(), + feats=feats_serial, + true_atom_coords=true_atom_coords_global.float(), + true_coords_resolved_mask=true_coords_resolved_mask_global.float(), + token_level_confidence=True, + multiplicity=multiplicity, + ) + + # Compute gradients for serial reference + expected_loss.backward() + expected_grad_pred_lddt = pred_lddt_global.grad.clone() + + # Verify that serial reference produces gradients (the gradient flows through log_softmax) + assert expected_grad_pred_lddt is not None, "Serial plddt_loss should produce gradients" + assert not torch.allclose( + expected_grad_pred_lddt, torch.zeros_like(expected_grad_pred_lddt) + ), "Gradients should be non-zero" + + # Prepare payload for parallel test + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_lddt_global.detach().clone().cpu(), + pred_atom_coords_global.clone().cpu(), + true_atom_coords_global.clone().cpu(), + true_coords_resolved_mask_global.clone().cpu(), + token_to_rep_atom_global.clone().cpu(), + r_set_to_rep_atom_global.clone().cpu(), + atom_to_token_global.clone().cpu(), + mol_type_global.clone().cpu(), + atom_counts_per_token.clone().cpu(), + expected_loss.detach().clone().cpu(), + expected_grad_pred_lddt.clone().cpu(), + multiplicity, + ) + + # Launch parallel test + spawn_multiprocessing(parallel_assert_plddt_loss, world_size, payload) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/distributed/model/loss/test_dtensor_confidence_resolved_loss.py b/tests/distributed/model/loss/test_dtensor_confidence_resolved_loss.py new file mode 100644 index 000000000..44c998d4a --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_confidence_resolved_loss.py @@ -0,0 +1,545 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Tests for DTensor resolved_loss implementation. + +This module tests the distributed implementation of resolved_loss which computes +binary cross-entropy loss for predicting whether atoms are resolved. The tests +verify numerical correctness against the serial implementation and proper gradient +computation. + +The key challenge is handling the block-diagonal structure of token_to_rep_atom +with intersperse padding, where each CP rank only sees its local token-atom +correspondences. This requires using pad_and_scatter_atom_features_dtensor +instead of simple distribute_tensor for atom features. +""" + +import unittest + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.confidencev2 import ( + resolved_loss, + resolved_negative_log_likelihood, +) +from boltz.model.loss.confidencev2 import resolved_loss as serial_resolved_loss +from boltz.testing.utils import ( + distribute_atom_features, + init_tensors_uniform, + random_features, + spawn_multiprocessing, +) + + +def parallel_assert_resolved_loss( + rank: int, + payload: tuple, +): + """Parallel test function for resolved_loss. + + This function runs on each rank in the distributed setup and verifies that + the DTensor implementation matches the serial reference. + + Uses pad_and_scatter_atom_features_dtensor for proper atom feature sharding + with intersperse padding. + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_resolved_global_host, + token_to_rep_atom_global_host, + atom_counts_per_token_host, + token_pad_mask_global_host, + true_coords_resolved_mask_global_host, + expected_loss_host, + expected_pred_grad_host, + multiplicity, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + dtype = pred_resolved_global_host.dtype + + # Distribute pred_resolved: (B*mult, N_token, 2) with (Shard(0), Shard(1), Replicate()) + pred_resolved_dtensor = distribute_tensor( + pred_resolved_global_host.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ).requires_grad_(True) + + # Distribute token_pad_mask: (B, N_token) with (Shard(0), Shard(1), Replicate()) + token_pad_mask_dtensor = distribute_tensor( + token_pad_mask_global_host.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + + # --- Distribute atom features using distribute_atom_features utility --- + # Prepare inputs dict with all atom features (including per-multiplicity keys) + size_batch = token_to_rep_atom_global_host.shape[0] + inputs_atom = { + "atom_counts_per_token": atom_counts_per_token_host.to(dtype=torch.int64), + "token_to_rep_atom": token_to_rep_atom_global_host.to(dtype=dtype), + } + # Add per-multiplicity resolved masks: unflatten [B*mult, N_atom] -> [B, mult, N_atom] + resolved_mask_unflat = true_coords_resolved_mask_global_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"true_coords_resolved_mask_{i_mul}"] = resolved_mask_unflat[:, i_mul].to(dtype=dtype) + + # Define placements for CP submesh and full mesh + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "token_to_rep_atom": (Shard(0), Replicate()), + } + placements_dp_cp = { + "token_to_rep_atom": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(multiplicity): + placements_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + # Distribute atom features with intersperse padding + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={"true_coords_resolved_mask": multiplicity}, + ) + + true_coords_resolved_mask_dtensor = feats_atom["true_coords_resolved_mask"] + token_to_rep_atom_dtensor = feats_atom["token_to_rep_atom"] + + # Create feature dictionary + feats_dtensor = { + "token_to_rep_atom": token_to_rep_atom_dtensor, + "token_pad_mask": token_pad_mask_dtensor, + } + + # Create copies to verify inputs aren't modified + pred_resolved_copy = pred_resolved_dtensor.to_local().detach().clone() + + # Forward pass + loss_dtensor = resolved_loss( + pred_resolved_dtensor, + feats_dtensor, + true_coords_resolved_mask_dtensor, + multiplicity=multiplicity, + ) + + # Verify input wasn't modified (values only, not requires_grad) + torch.testing.assert_close( + pred_resolved_copy, + pred_resolved_dtensor.to_local().detach(), + ) + + # Verify forward pass results + expected_loss_dtensor = distribute_tensor( + expected_loss_host.to(manager.device), + device_mesh=device_mesh, + placements=(Replicate(), Replicate(), Replicate()), + src_data_rank=None, + ) + + assert ( + loss_dtensor.shape == expected_loss_dtensor.shape + ), f"Loss shape mismatch: expected {expected_loss_dtensor.shape}, got {loss_dtensor.shape}" + torch.testing.assert_close( + loss_dtensor.to_local(), + expected_loss_dtensor.to_local(), + ) + + # Backward pass + loss_dtensor.backward() + + # Verify gradient + assert pred_resolved_dtensor.grad is not None, "Gradient not computed for pred_resolved" + assert ( + pred_resolved_dtensor.grad.shape == pred_resolved_global_host.shape + ), f"Grad shape mismatch: expected {pred_resolved_global_host.shape}, got {pred_resolved_dtensor.grad.shape}" + + # Gather and compare gradients + grad_global_result = pred_resolved_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close( + grad_global_result, + expected_pred_grad_host, + ) + + # Verify full tensor matches expected + loss_global_result = loss_dtensor.full_tensor().cpu() + torch.testing.assert_close(loss_global_result, expected_loss_host) + + DistributedManager.cleanup() + monkeypatch.undo() + + +def parallel_assert_resolved_nll( + rank: int, + payload: tuple, +): + """Parallel test function for resolved_negative_log_likelihood. + + This tests the shardwise NLL computation in isolation with multiplicity support. + Uses pad_and_scatter_atom_features_dtensor for proper atom feature sharding. + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_resolved_global_host, + token_to_rep_atom_global_host, + atom_counts_per_token_host, + true_coords_resolved_mask_global_host, + expected_errors_host, + expected_d_errors_host, + expected_pred_grad_host, + multiplicity, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + dtype = pred_resolved_global_host.dtype + + # Distribute pred_resolved: (B*mult, N_token, 2) with (Shard(0), Shard(1), Replicate()) + pred_resolved_dtensor = distribute_tensor( + pred_resolved_global_host.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ).requires_grad_(True) + + # --- Distribute atom features using distribute_atom_features utility --- + size_batch = token_to_rep_atom_global_host.shape[0] + inputs_atom = { + "atom_counts_per_token": atom_counts_per_token_host.to(dtype=torch.int64), + "token_to_rep_atom": token_to_rep_atom_global_host.to(dtype=dtype), + } + # Add per-multiplicity resolved masks: unflatten [B*mult, N_atom] -> [B, mult, N_atom] + resolved_mask_unflat = true_coords_resolved_mask_global_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"true_coords_resolved_mask_{i_mul}"] = resolved_mask_unflat[:, i_mul].to(dtype=dtype) + + # Define placements for CP submesh and full mesh + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "token_to_rep_atom": (Shard(0), Replicate()), + } + placements_dp_cp = { + "token_to_rep_atom": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(multiplicity): + placements_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + # Distribute atom features with intersperse padding + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={"true_coords_resolved_mask": multiplicity}, + ) + + true_coords_resolved_mask_dtensor = feats_atom["true_coords_resolved_mask"] + token_to_rep_atom_dtensor = feats_atom["token_to_rep_atom"] + + # Forward pass + errors_dtensor = resolved_negative_log_likelihood( + pred_resolved_dtensor, + token_to_rep_atom_dtensor, + true_coords_resolved_mask_dtensor, + ) + + # Verify forward pass + expected_errors_dtensor = distribute_tensor( + expected_errors_host.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + + torch.testing.assert_close( + errors_dtensor.to_local(), + expected_errors_dtensor.to_local(), + ) + + # Backward pass with custom gradient + d_errors = distribute_tensor( + expected_d_errors_host.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + errors_dtensor.backward(d_errors) + + # Verify gradient + assert pred_resolved_dtensor.grad is not None, "Gradient not computed" + grad_global_result = pred_resolved_dtensor.grad.full_tensor().cpu() + torch.testing.assert_close( + grad_global_result, + expected_pred_grad_host, + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2]) +def test_resolved_loss(setup_env, multiplicity): + """Test resolved_loss DTensor implementation against serial reference.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + torch.manual_seed(seed) + + # Test dimensions + dp_size = grid_group_sizes["dp"] + cp_size = grid_group_sizes["cp"][0] + + # Batch size must equal DP size for per-sample processing in pad_and_scatter + B = dp_size + n_tokens_per_shard = 16 + N_token = n_tokens_per_shard * cp_size + # Atoms: use variable atoms per token (1-3 atoms per token) + n_atoms_per_token_min, n_atoms_per_token_max = 1, 3 + # Estimate total atoms: avg 2 atoms/token * N_token, rounded to be divisible by cp_size + avg_atoms_per_token = (n_atoms_per_token_min + n_atoms_per_token_max) / 2 + N_atom = int(avg_atoms_per_token * N_token) + # Make N_atom divisible by cp_size for even sharding + N_atom = ((N_atom + cp_size - 1) // cp_size) * cp_size + dtype = torch.float64 + + # Use random_features to generate token_to_rep_atom with proper block-diagonal structure + # where each token randomly picks one of its owned atoms as representative + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, # Not used for this test + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=(-0.1, 0.1), + selected_keys=["token_to_rep_atom", "token_pad_mask", "atom_pad_mask", "atom_counts_per_token"], + ) + + token_to_rep_atom_global = feats["token_to_rep_atom"].to(dtype=dtype) + token_pad_mask_global = feats["token_pad_mask"].to(dtype=dtype) + atom_counts_per_token = feats["atom_counts_per_token"] + + # Add some padding at the end of each shard for testing + for shard in range(cp_size): + end_idx = (shard + 1) * n_tokens_per_shard + token_pad_mask_global[:, end_idx - 2 : end_idx] = 0 + + # Create pred_resolved: (B*mult, N_token, 2) + pred_resolved_global = torch.empty(B * multiplicity, N_token, 2, device=device_type, dtype=dtype) + init_tensors_uniform([pred_resolved_global], low=-0.5, high=0.5) + pred_resolved_global.requires_grad_(True) + + # Create true_coords_resolved_mask: (B*mult, N_atom) + true_coords_resolved_mask_global = torch.randint(0, 2, (B * multiplicity, N_atom), device=device_type, dtype=dtype) + + # Create feature dictionary for serial computation + feats_global = { + "token_to_rep_atom": token_to_rep_atom_global, + "token_pad_mask": token_pad_mask_global, + } + + # Compute serial reference + expected_loss = serial_resolved_loss( + pred_resolved_global, + feats_global, + true_coords_resolved_mask_global, + token_level_confidence=True, + multiplicity=multiplicity, + ) + expected_loss.backward() + expected_loss_host = expected_loss.detach().clone().cpu() + expected_pred_grad_host = pred_resolved_global.grad.detach().clone().cpu() + + # Prepare payload for parallel test + # Note: atom_counts_per_token is needed for pad_and_scatter_atom_features_dtensor + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_resolved_global.detach().clone().cpu(), + token_to_rep_atom_global.clone().cpu(), + atom_counts_per_token.clone().cpu(), + token_pad_mask_global.clone().cpu(), + true_coords_resolved_mask_global.clone().cpu(), + expected_loss_host, + expected_pred_grad_host, + multiplicity, + ) + + # Launch parallel test + spawn_multiprocessing(parallel_assert_resolved_loss, world_size, payload) + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}", +) +def test_resolved_negative_log_likelihood(setup_env): + """Test resolved_negative_log_likelihood in isolation with multiplicity=2.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 123 + torch.manual_seed(seed) + + # Test dimensions + dp_size = grid_group_sizes["dp"] + cp_size = grid_group_sizes["cp"][0] + multiplicity = 2 # Hardcoded multiplicity for this test + + # Batch size must equal DP size for per-sample processing + B = dp_size + n_tokens_per_shard = 8 + N_token = n_tokens_per_shard * cp_size + # Atoms: use variable atoms per token (1-3 atoms per token) + n_atoms_per_token_min, n_atoms_per_token_max = 1, 3 + avg_atoms_per_token = (n_atoms_per_token_min + n_atoms_per_token_max) / 2 + N_atom = int(avg_atoms_per_token * N_token) + # Make N_atom divisible by cp_size for even sharding + N_atom = ((N_atom + cp_size - 1) // cp_size) * cp_size + dtype = torch.float64 + + # Use random_features to generate token_to_rep_atom with proper structure + feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=(-0.1, 0.1), + selected_keys=["token_to_rep_atom", "atom_counts_per_token"], + ) + + # token_to_rep_atom is NOT multiplexed: shape (B, N_token, N_atom) + token_to_rep_atom_global = feats["token_to_rep_atom"].to(dtype=dtype) + atom_counts_per_token = feats["atom_counts_per_token"] + + # Create pred_resolved with multiplicity: (B*mult, N_token, 2) + pred_resolved_global = torch.empty(B * multiplicity, N_token, 2, device=device_type, dtype=dtype) + init_tensors_uniform([pred_resolved_global], low=-0.5, high=0.5) + pred_resolved_global.requires_grad_(True) + + # true_coords_resolved_mask with multiplicity: (B*mult, N_atom) + true_coords_resolved_mask_global = torch.randint(0, 2, (B * multiplicity, N_atom), device=device_type, dtype=dtype) + + # Compute serial reference for NLL using einsum (matches the DTensor implementation) + # token_to_rep_atom: (B, N_token, N_atom) -> "btj" + # resolved_mask reshaped: (B, mult, N_atom) -> "bmj" + # ref_mask: (B, mult, N_token) -> "bmt" then flatten to (B*mult, N_token) + resolved_mask_reshaped = true_coords_resolved_mask_global.view(B, multiplicity, N_atom) + ref_mask = torch.einsum("btj,bmj->bmt", token_to_rep_atom_global, resolved_mask_reshaped) + ref_mask = ref_mask.flatten(0, 1) # (B*mult, N_token) + + log_softmax_resolved = torch.nn.functional.log_softmax(pred_resolved_global, dim=-1) + expected_errors = -ref_mask * log_softmax_resolved[:, :, 0] - (1 - ref_mask) * log_softmax_resolved[:, :, 1] + + # Create gradient for backward pass + expected_d_errors = torch.empty_like(expected_errors) + init_tensors_uniform([expected_d_errors], low=-0.5, high=0.5) + + # Backward with custom gradient + expected_errors.backward(expected_d_errors) + expected_errors_host = expected_errors.detach().clone().cpu() + expected_d_errors_host = expected_d_errors.detach().clone().cpu() + expected_pred_grad_host = pred_resolved_global.grad.detach().clone().cpu() + + # Prepare payload with atom_counts_per_token for pad_and_scatter + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_resolved_global.detach().clone().cpu(), + token_to_rep_atom_global.clone().cpu(), + atom_counts_per_token.clone().cpu(), + true_coords_resolved_mask_global.clone().cpu(), + expected_errors_host, + expected_d_errors_host, + expected_pred_grad_host, + multiplicity, + ) + + # Launch parallel test + spawn_multiprocessing(parallel_assert_resolved_nll, world_size, payload) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/model/loss/test_dtensor_distogram.py b/tests/distributed/model/loss/test_dtensor_distogram.py new file mode 100644 index 000000000..0e556ee3a --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_distogram.py @@ -0,0 +1,371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""DTensor-based tests for Boltz-2 distogram loss. + +Tests the DTensor CP implementation against serial loss references: +- distogramv2 (Boltz-2): 5D tensors with K conformers and D distograms +- distogram v1 (Boltz-1): 4D tensors unsqueezed to 5D with K=1, D=1 + +Maps to: src/boltz/distributed/model/loss/distogram.py +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.distogram import ( + distogram_loss as distogram_loss_dtensor, +) +from boltz.model.loss.distogram import distogram_loss as distogram_loss_v1 +from boltz.model.loss.distogramv2 import distogram_loss as distogram_loss_v2 +from boltz.testing.utils import ( + assert_tensors_identical, + chunk_along_dim, + init_tensors_uniform, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + + +def parallel_assert_distogram_loss_dtensor( + rank, + payload, +): + """Worker function for DTensor distogram loss testing.""" + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_global_host, + target_global_host, + mask_global_host, + global_loss_expected_host, + batch_loss_expected_host, + d_global_loss_host, + d_pred_expected_host, + aggregate_distogram, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + # Distribute tensors as DTensors + pred_dtensor = distribute_tensor( + pred_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Shard(2)), + ).requires_grad_(True) + + target_dtensor = distribute_tensor( + target_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Shard(2)), + ) + + mask_dtensor = distribute_tensor( + mask_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Replicate()), + ) + + output_dtensor = {"pdistogram": pred_dtensor} + feats_dtensor = { + "disto_target": target_dtensor, + "token_disto_mask": mask_dtensor, + } + + # Create copies to verify inputs aren't modified + output_dtensor_copy = { + key: tensor.detach().clone().requires_grad_(tensor.requires_grad) for key, tensor in output_dtensor.items() + } + feats_dtensor_copy = { + key: tensor.detach().clone().requires_grad_(tensor.requires_grad) for key, tensor in feats_dtensor.items() + } + + # Forward pass + global_loss_result, batch_loss_result = distogram_loss_dtensor( + output_dtensor, feats_dtensor, comm, aggregate_distogram=aggregate_distogram + ) + + # Verify placements have correct ndim for the 3D mesh (dp, cp0, cp1) + mesh_ndim = manager.device_mesh_subgroups.ndim + assert len(global_loss_result.placements) == mesh_ndim, ( + f"global_loss placements should have {mesh_ndim} elements for {mesh_ndim}D mesh, " + f"got {len(global_loss_result.placements)}: {global_loss_result.placements}" + ) + assert len(batch_loss_result.placements) == mesh_ndim, ( + f"batch_loss placements should have {mesh_ndim} elements for {mesh_ndim}D mesh, " + f"got {len(batch_loss_result.placements)}: {batch_loss_result.placements}" + ) + + # Verify inputs weren't modified (binary identity) + assert_tensors_identical( + output_dtensor_copy["pdistogram"].to_local(), + output_dtensor["pdistogram"].to_local(), + check_grad=False, + check_grad_fn=False, + ) + + for key in feats_dtensor: + assert_tensors_identical( + feats_dtensor_copy[key].to_local(), + feats_dtensor[key].to_local(), + check_grad=False, + check_grad_fn=False, + ) + + # Use full_tensor() for global_loss because placements may be Partial + torch.testing.assert_close( + global_loss_result.full_tensor(), + global_loss_expected_host.to(manager.device), + ) + + # batch_loss is [B] sharded on DP dim - chunk reference to match local shard + dp_rank = manager.group_rank["dp"] + dp_size = grid_group_sizes["dp"] + batch_loss_expected_local = chunk_along_dim(batch_loss_expected_host, dim=0, chunk_i=dp_rank, chunks=dp_size) + torch.testing.assert_close( + batch_loss_result.to_local(), + batch_loss_expected_local.to(manager.device), + ) + + # Backward pass + d_global_loss_dtensor = distribute_tensor( + d_global_loss_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=(Replicate(), Replicate(), Replicate()), + ) + + global_loss_result.backward(d_global_loss_dtensor) + + # Verify gradient on pred + assert output_dtensor["pdistogram"].grad is not None, "pred gradient is None - trivial equality guard failed" + assert d_pred_expected_host is not None, "Reference pred gradient is None - test setup error" + + # Shard the reference gradient to match this rank's local portion. + # Chunk along DP dim (batch), then CP dims (spatial). + layout_map = manager.layout_subgroups["cp"] + i, j = layout_map.unravel(manager.group_rank["cp"]) + + d_pred_expected_local = chunk_along_dim(d_pred_expected_host, dim=0, chunk_i=dp_rank, chunks=dp_size) + d_pred_expected_local = chunk_along_dim(d_pred_expected_local, dim=1, chunk_i=i, chunks=layout_map.shape[0]) + d_pred_expected_local = chunk_along_dim(d_pred_expected_local, dim=2, chunk_i=j, chunks=layout_map.shape[1]) + + torch.testing.assert_close( + output_dtensor["pdistogram"].grad.to_local().cpu(), + d_pred_expected_local, + msg=lambda m: f"Pred gradient mismatch on rank {rank}\n{m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, loss_config", + [ + # CUDA dp=1 cp=(1,1): serial-equivalent sanity check (1 GPU) + (((1, (1, 1)), True, "cuda", "ENV"), (1, 1, True)), + # CUDA dp=2 cp=(1,1): DP-only path (2 GPUs) + (((2, (1, 1)), True, "cuda", "ENV"), (1, 1, True)), + # CUDA dp=2 cp=(2,2): full DP+CP with the harder non-aggregate path + (((2, (2, 2)), True, "cuda", "ENV"), (3, 2, False)), + # CPU dp=2 cp=(1,1): dp>1 regression guard with aggregate + (((2, (1, 1)), True, "cpu", "ENV"), (1, 1, True)), + # CPU dp=2 cp=(3,3): DP + non-power-of-two CP with non-aggregate + (((2, (3, 3)), True, "cpu", "ENV"), (3, 2, False)), + # CPU dp=1 cp=(2,2): K>1 conformers with aggregate (sum+normalize K conformers) + (((1, (2, 2)), True, "cpu", "ENV"), (3, 1, True)), + ], + indirect=("setup_env",), + ids=[ + "cuda-dp1-cp1x1-K1D1-agg", + "cuda-dp2-cp1x1-K1D1-agg", + "cuda-dp2-cp2x2-K3D2-noagg", + "cpu-dp2-cp1x1-K1D1-agg", + "cpu-dp2-cp3x3-K3D2-noagg", + "cpu-dp1-cp2x2-K3D1-agg", + ], +) +def test_dtensor_distogram_loss(setup_env, loss_config): + """Test DTensor distogram loss against serial reference.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + K, D, aggregate_distogram = loss_config + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + # Create test tensors + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 10 + num_bins = 16 + + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(42) + + min_init_val = -0.5 + max_init_val = 0.5 + + pred_global = torch.empty((B, N, N, D, num_bins), requires_grad=True) + init_tensors_uniform([pred_global], low=min_init_val, high=max_init_val) + + # Create targets + if aggregate_distogram and K == 1: + target_idx = torch.randint(0, num_bins, (B, N, N)) + target_global = torch.nn.functional.one_hot(target_idx, num_classes=num_bins).float() + target_global = target_global.unsqueeze(3) + else: + target_global = torch.empty((B, N, N, K, num_bins)) + init_tensors_uniform([target_global], low=0.01, high=1.0) + target_global = target_global / target_global.sum(dim=-1, keepdim=True).clamp(min=1e-8) + target_global.requires_grad_(False) + + # Create mask + mask_global = torch.randint(0, 2, (B, N), dtype=torch.bool) + + # Run serial forward pass as reference + output_global = {"pdistogram": pred_global} + feats_global = {"disto_target": target_global, "token_disto_mask": mask_global} + + global_loss_expected, batch_loss_expected = distogram_loss_v2( + output_global, feats_global, aggregate_distogram=aggregate_distogram + ) + + # Create upstream gradient + d_global_loss = torch.empty(global_loss_expected.shape, dtype=global_loss_expected.dtype) + init_tensors_uniform([d_global_loss], low=min_init_val, high=max_init_val) + + global_loss_expected.backward(d_global_loss) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_global.detach().clone().cpu(), + target_global.detach().clone().cpu(), + mask_global.detach().clone().cpu(), + global_loss_expected.detach().clone().cpu(), + batch_loss_expected.detach().clone().cpu(), + d_global_loss.detach().clone().cpu(), + pred_global.grad.detach().clone().cpu(), + aggregate_distogram, + ) + + spawn_multiprocessing(parallel_assert_distogram_loss_dtensor, world_size, payload) + + +@pytest.mark.parametrize( + "setup_env", + [ + # CPU dp=2 cp=(1,1): minimal distributed setup to verify v1 equivalence + ((2, (1, 1)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp2-cp1x1"], +) +def test_dtensor_distogram_loss_v1_compat(setup_env): + """Test that the DTensor loss with D=1, K=1 matches the v1 serial loss. + + The v1 serial loss uses 4D tensors [B, N, N, bins]. To use the v2/CP + implementation as a v1 loss, unsqueeze dim 3 of pred and target to get + [B, N, N, 1, bins] with D=1 and K=1. With aggregate_distogram=True, + min-over-D and mean-over-K are identity ops, so the result must match v1. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 10 + num_bins = 16 + + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(42) + + min_init_val = -0.5 + max_init_val = 0.5 + + # v1 tensors are 4D: [B, N, N, bins] + pred_4d = torch.empty((B, N, N, num_bins), requires_grad=True) + init_tensors_uniform([pred_4d], low=min_init_val, high=max_init_val) + + target_idx = torch.randint(0, num_bins, (B, N, N)) + target_4d = torch.nn.functional.one_hot(target_idx, num_classes=num_bins).float() + + mask_global = torch.randint(0, 2, (B, N), dtype=torch.bool) + + # Run v1 serial loss as reference + output_v1 = {"pdistogram": pred_4d} + feats_v1 = {"disto_target": target_4d, "token_disto_mask": mask_global} + global_loss_expected, batch_loss_expected = distogram_loss_v1(output_v1, feats_v1) + + d_global_loss = torch.empty(global_loss_expected.shape, dtype=global_loss_expected.dtype) + init_tensors_uniform([d_global_loss], low=min_init_val, high=max_init_val) + global_loss_expected.backward(d_global_loss) + + # Unsqueeze to 5D for the CP implementation: [B,N,N,bins] → [B,N,N,1,bins] + pred_5d = pred_4d.detach().clone().unsqueeze(3) + target_5d = target_4d.detach().clone().unsqueeze(3) + # Gradient also gains the unsqueeze dim + d_pred_5d = pred_4d.grad.detach().clone().unsqueeze(3) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + pred_5d.cpu(), + target_5d.cpu(), + mask_global.detach().clone().cpu(), + global_loss_expected.detach().clone().cpu(), + batch_loss_expected.detach().clone().cpu(), + d_global_loss.detach().clone().cpu(), + d_pred_5d.cpu(), + True, # aggregate_distogram + ) + + spawn_multiprocessing(parallel_assert_distogram_loss_dtensor, world_size, payload) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/model/loss/test_dtensor_get_true_coordinates.py b/tests/distributed/model/loss/test_dtensor_get_true_coordinates.py new file mode 100644 index 000000000..5fcd91719 --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_get_true_coordinates.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from __future__ import annotations + +import types + +import pytest +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard + +from boltz.data.mol import minimum_lddt_symmetry_coords as serial_minimum_lddt_symmetry_coords +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.models.boltz2 import Boltz2 as DistributedBoltz2 +from boltz.testing.utils import ( + distribute_atom_features, + get_feature_placements, + random_features, + spawn_multiprocessing, +) + + +def _build_symmetry_features_for_batch( + all_coords_batch: torch.Tensor, + all_resolved_mask_batch: torch.Tensor, + crop_to_all_atom_map_batch: torch.Tensor, + chain_swaps_batch: list | None = None, +): + """Build symmetry features for a batch of samples (Boltz-2 semantics).""" + batch_size = all_coords_batch.shape[0] + if chain_swaps_batch is None: + chain_swaps_batch = [[[]] for _ in range(batch_size)] + amino_acids_symmetries_batch = [[] for _ in range(batch_size)] + ligand_symmetries_batch = [[] for _ in range(batch_size)] + + return { + "all_coords": all_coords_batch, + "all_resolved_mask": all_resolved_mask_batch, + "crop_to_all_atom_map": crop_to_all_atom_map_batch, + "chain_swaps": chain_swaps_batch, + "amino_acids_symmetries": amino_acids_symmetries_batch, + "ligand_symmetries": ligand_symmetries_batch, + } + + +def _make_two_chain_swaps(n_atoms: int) -> list: + """Build chain_swaps for one sample with two equal-length chains that can be swapped.""" + half = n_atoms // 2 + identity = [] + swap_ab = [ + (0, half, half, 2 * half, 0, 1), + (half, 2 * half, 0, half, 1, 0), + ] + return [identity, swap_ab] + + +_atom_keys = {"atom_pad_mask", "coords", "atom_resolved_mask"} +_placements = get_feature_placements(atom_keys=_atom_keys, token_keys=set()) +_placements_atom_features = _placements["atom_features"] +_placements_cp_atom_features = _placements["cp_atom_features"] + +_placements_sample_coords = (Shard(0), Shard(1), Replicate()) +_placements_cp_sample_coords = (Shard(0), Replicate()) + +_placements_token_index = (Shard(0), Replicate(), Replicate()) +_placements_cp_token_index = (Shard(0), Replicate()) + + +def parallel_assert_get_true_coordinates(rank, payload): + """Test get_true_coordinates: symmetry path (parity) + non-symmetry path (types/shapes). + + Symmetry path: DP sharding of symmetry features, compares DTensor output against + serial minimum_lddt_symmetry_coords reference for numerical parity. + + Non-symmetry path: verifies outputs are DTensors (not plain tensors), shapes are + correct for expanded and unexpanded modes. + """ + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + input_feats_global_host, + sample_coords_global_host, + feats_symmetry_global_host, + expected_true_coords_per_mult_sample, + expected_true_mask_per_mult_sample, + use_nontrivial_swaps, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device = manager.device + dtype = torch.float32 + + size_batch = input_feats_global_host["atom_pad_mask"].shape[0] + rank_dp = manager.group_rank["dp"] + + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_host.items() + if k in _placements_cp_atom_features + } + sample_coords_unflat = sample_coords_global_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"sample_coords_{i_mul}"] = sample_coords_unflat[:, i_mul].to(dtype=dtype) + + placements_cp = _placements_cp_atom_features | { + f"sample_coords_{i_mul}": _placements_cp_sample_coords for i_mul in range(multiplicity) + } + placements_dp_cp = _placements_atom_features | { + f"sample_coords_{i_mul}": _placements_sample_coords for i_mul in range(multiplicity) + } + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities={"sample_coords": multiplicity}, + ) + + coords_dtensor = feats_atom["sample_coords"] + atom_pad_mask_dtensor = feats_atom["atom_pad_mask"] + n_atoms_padded = atom_pad_mask_dtensor.shape[1] + assert coords_dtensor.shape == ( + size_batch * multiplicity, + n_atoms_padded, + 3, + ), f"coords must be shape ({size_batch * multiplicity}, {n_atoms_padded}, 3) but got {coords_dtensor.shape}" + + # token_index: needed by get_true_coordinates to determine local_batch_size + from torch.distributed.tensor import distribute_tensor + + token_index_global = input_feats_global_host["token_index"].to(device) + token_index_dtensor = distribute_tensor( + token_index_global, + device_mesh=manager.device_mesh_subgroups, + placements=_placements_token_index, + ) + + num_dp_ranks = grid_group_sizes["dp"] + batch_size_global = size_batch + coords_global_batch = size_batch * multiplicity + local_batch_size = batch_size_global // num_dp_ranks + local_start = rank_dp * local_batch_size + local_end = local_start + local_batch_size + + batch_local = { + "token_index": token_index_dtensor, + "all_coords": feats_symmetry_global_host["all_coords"][local_start:local_end].to(device), + "all_resolved_mask": feats_symmetry_global_host["all_resolved_mask"][local_start:local_end].to(device), + "crop_to_all_atom_map": feats_symmetry_global_host["crop_to_all_atom_map"][local_start:local_end].to(device), + "chain_swaps": feats_symmetry_global_host["chain_swaps"][local_start:local_end], + "amino_acids_symmetries": feats_symmetry_global_host["amino_acids_symmetries"][local_start:local_end], + "ligand_symmetries": feats_symmetry_global_host["ligand_symmetries"][local_start:local_end], + "atom_pad_mask": atom_pad_mask_dtensor, + } + out_dtensor = {"sample_atom_coords": coords_dtensor} + + dummy = types.SimpleNamespace() + + result = DistributedBoltz2.get_true_coordinates( + dummy, + batch=batch_local, + out=out_dtensor, + diffusion_samples=multiplicity, + symmetry_correction=True, + ) + + true_coords_dtensor = result["true_coords"] + true_mask_dtensor = result["true_coords_resolved_mask"] + + assert ( + true_coords_dtensor.shape[0] == coords_global_batch + ), f"true_coords_dtensor.shape[0] should be {coords_global_batch}, got {true_coords_dtensor.shape[0]}" + assert ( + true_mask_dtensor.shape[0] == coords_global_batch + ), f"true_mask_dtensor.shape[0] should be {coords_global_batch}, got {true_mask_dtensor.shape[0]}" + + actual_coords_full = true_coords_dtensor.full_tensor().cpu() + # With symmetric correction, true_coords = shardwise_unsqueeze(true_coords, dim=1) is used to fix the shape mismatch, which gives us a 4D tensor. + if actual_coords_full.ndim == 4: + actual_coords_full = actual_coords_full.squeeze(1) + actual_mask_full = true_mask_dtensor.full_tensor().cpu() + atom_pad_mask_full = atom_pad_mask_dtensor.full_tensor().cpu() + + for mult_idx in range(coords_global_batch): + expected_coords = expected_true_coords_per_mult_sample[mult_idx] + expected_mask = expected_true_mask_per_mult_sample[mult_idx] + + batch_idx = mult_idx // multiplicity + real_atom_mask = atom_pad_mask_full[batch_idx].bool() + + actual_coords_no_pad = actual_coords_full[mult_idx, real_atom_mask, :] + actual_mask_no_pad = actual_mask_full[mult_idx, real_atom_mask] + + torch.testing.assert_close( + actual_coords_no_pad, + expected_coords, + msg=f"Sample {mult_idx}: true_coords mismatch", + ) + torch.testing.assert_close( + actual_mask_no_pad, + expected_mask.squeeze(0) if expected_mask.ndim > 1 else expected_mask, + msg=f"Sample {mult_idx}: true_mask mismatch", + ) + + assert result["rmsds"] == 0, f"rmsds should be 0, got {result['rmsds']}" + assert result["best_rmsd_recall"] == 0, f"best_rmsd_recall should be 0, got {result['best_rmsd_recall']}" + + # ---- Non-symmetry path: type, shape, and expand_to_diffusion_samples ---- + batch_local["coords"] = feats_atom["coords"] + batch_local["atom_resolved_mask"] = feats_atom["atom_resolved_mask"] + + result_nosym = DistributedBoltz2.get_true_coordinates( + dummy, + batch=batch_local, + out=out_dtensor, + diffusion_samples=multiplicity, + symmetry_correction=False, + expand_to_diffusion_samples=True, + ) + tc_nosym = result_nosym["true_coords"] + tm_nosym = result_nosym["true_coords_resolved_mask"] + + assert isinstance(tc_nosym, DTensor), f"Rank {rank}: non-sym true_coords should be DTensor, got {type(tc_nosym)}" + assert isinstance(tm_nosym, DTensor), f"Rank {rank}: non-sym true_mask should be DTensor, got {type(tm_nosym)}" + assert tc_nosym.shape == ( + coords_global_batch, + n_atoms_padded, + 3, + ), f"Rank {rank}: expanded true_coords shape {tc_nosym.shape} != ({coords_global_batch}, {n_atoms_padded}, 3)" + assert tm_nosym.shape == ( + coords_global_batch, + n_atoms_padded, + ), f"Rank {rank}: expanded true_mask shape {tm_nosym.shape} != ({coords_global_batch}, {n_atoms_padded})" + + result_unexpanded = DistributedBoltz2.get_true_coordinates( + dummy, + batch=batch_local, + out=out_dtensor, + diffusion_samples=multiplicity, + symmetry_correction=False, + expand_to_diffusion_samples=False, + ) + tc_unexpanded = result_unexpanded["true_coords"] + assert isinstance(tc_unexpanded, DTensor), f"Rank {rank}: unexpanded true_coords should be DTensor" + assert ( + tc_unexpanded.shape[0] == batch_size_global + ), f"Rank {rank}: unexpanded batch dim {tc_unexpanded.shape[0]} should be {batch_size_global}" + assert ( + tc_unexpanded.shape[0] < tc_nosym.shape[0] + ), f"Rank {rank}: unexpanded batch {tc_unexpanded.shape[0]} should be < expanded {tc_nosym.shape[0]}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("use_nontrivial_swaps", [False, True], ids=["identity_swaps", "nontrivial_swaps"]) +@pytest.mark.parametrize( + "setup_env", + [((2, (2, 2)), True, "cuda", "ENV")], # dp=2, cp=(2,2), world_size=8 + indirect=("setup_env",), +) +def test_dtensor_get_true_coordinates(setup_env, use_nontrivial_swaps): + """Test get_true_coordinates: symmetry correction parity + non-symmetry shape/type checks. + + Symmetry path: dp=2, multiplicity=2, compares against serial + minimum_lddt_symmetry_coords for numerical parity. Parametrized over + identity and nontrivial chain swaps. + + Non-symmetry path: verifies outputs are DTensors with correct shapes, + and that expand_to_diffusion_samples=True/False produces correct batch dims. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + torch.manual_seed(123) + + size_ring = grid_group_sizes["cp"][0] + num_dp_ranks = grid_group_sizes["dp"] + batch_size_per_dp_rank = 1 + batch_size_global = batch_size_per_dp_rank * num_dp_ranks + multiplicity = 2 + + n_atoms_per_token = 3 + n_tokens = size_ring * 4 + n_atoms = n_atoms_per_token * n_tokens + + feats_from_random = random_features( + size_batch=batch_size_global, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, n_atoms_per_token), + device=torch.device("cpu"), + float_value_range=(-1.0, 1.0), + selected_keys=["atom_pad_mask", "token_index", "atom_counts_per_token", "coords", "atom_resolved_mask"], + ) + + atom_pad_mask_global = feats_from_random["atom_pad_mask"] + token_index_global = feats_from_random["token_index"] + atom_counts_per_token_global = feats_from_random["atom_counts_per_token"] + + sample_coords_global = torch.randn((batch_size_global * multiplicity, n_atoms, 3), dtype=torch.float32) + + input_feats_global = { + "atom_pad_mask": atom_pad_mask_global, + "token_index": token_index_global, + "atom_counts_per_token": atom_counts_per_token_global, + "coords": feats_from_random["coords"], + "atom_resolved_mask": feats_from_random["atom_resolved_mask"], + } + + coords_for_symmetry = sample_coords_global[::multiplicity] + all_coords_global = coords_for_symmetry.clone() + all_resolved_mask_global = torch.ones((batch_size_global, n_atoms), dtype=torch.bool) + crop_to_all_atom_map_global = ( + torch.arange(n_atoms, dtype=torch.long).unsqueeze(0).expand(batch_size_global, -1).contiguous() + ) + + chain_swaps_batch = None + if use_nontrivial_swaps: + chain_swaps_batch = [_make_two_chain_swaps(n_atoms) for _ in range(batch_size_global)] + + feats_symmetry_global = _build_symmetry_features_for_batch( + all_coords_global, + all_resolved_mask_global, + crop_to_all_atom_map_global, + chain_swaps_batch=chain_swaps_batch, + ) + + expected_true_coords_per_mult_sample = [] + expected_true_mask_per_mult_sample = [] + + for idx in range(batch_size_global): + for rep in range(multiplicity): + i = idx * multiplicity + rep + expected_coords, expected_mask = serial_minimum_lddt_symmetry_coords( + coords=sample_coords_global[i : i + 1], + feats=feats_symmetry_global, + index_batch=idx, + ) + expected_true_coords_per_mult_sample.append(expected_coords.squeeze(0).detach().clone().cpu()) + expected_true_mask_per_mult_sample.append(expected_mask.detach().clone().cpu()) + + input_feats_global_host = {k: v.detach().clone().cpu() for k, v in input_feats_global.items()} + sample_coords_global_host = sample_coords_global.detach().clone().cpu() + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + input_feats_global_host, + sample_coords_global_host, + feats_symmetry_global, + expected_true_coords_per_mult_sample, + expected_true_mask_per_mult_sample, + use_nontrivial_swaps, + ) + + spawn_multiprocessing(parallel_assert_get_true_coordinates, world_size, payload) diff --git a/tests/distributed/model/loss/test_dtensor_pae_loss.py b/tests/distributed/model/loss/test_dtensor_pae_loss.py new file mode 100644 index 000000000..9a6154fe8 --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_pae_loss.py @@ -0,0 +1,412 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Tests for DTensor pae_loss implementation.""" + +import unittest + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import One2OneComm +from boltz.distributed.data.utils import distribute_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.confidencev2 import pae_loss +from boltz.distributed.utils import get_group_rank_from_axial_shift +from boltz.model.loss.confidencev2 import pae_loss as serial_pae_loss +from boltz.testing.utils import ( + distribute_atom_features, + init_tensors_uniform, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + + +def create_heterogeneous_pae_features( + B: int, + N_token: int, + N_atom: int, + multiplicity: int, + device: torch.device, + dtype: torch.dtype, + base_feats: dict, +) -> tuple[dict, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Create heterogeneous features for PAE loss testing. + + Generates raw unpadded data with natural heterogeneity from base_feats. + The heterogeneity comes from: + - Different atom_counts_per_token distributions (from random_features) + - Mixed polymer/nonpolymer mol_type (from random_features) + - Different resolution masks per batch + + Returns: + feats_global: Feature dictionary (raw, unpadded) + pred_pae: Predicted PAE logits + pred_atom_coords: Predicted coordinates (raw, unpadded) + true_atom_coords: True coordinates (raw, unpadded) + true_coords_resolved_mask: Resolution mask (raw, unpadded) + """ + num_bins = 64 + + feats_global = { + "frames_idx": base_feats["frames_idx"].to(device=device), + "frame_resolved_mask": base_feats["frame_resolved_mask"].to(device=device, dtype=dtype), + "asym_id": base_feats["asym_id"].to(device=device), + "atom_to_token": base_feats["atom_to_token"].to(device=device, dtype=dtype), + "atom_pad_mask": base_feats["atom_pad_mask"].to(device=device, dtype=dtype), + "mol_type": base_feats["mol_type"].to(device=device), + "token_pad_mask": base_feats["token_pad_mask"].to(device=device, dtype=dtype), + "atom_resolved_mask": base_feats["atom_resolved_mask"].to(device=device, dtype=dtype), + "is_nonpolymer_with_frame": base_feats["is_nonpolymer_with_frame"].to(device=device), + "atom_counts_per_token": base_feats["atom_counts_per_token"].to(device=device), + } + + # Main tensors - raw unpadded + pred_pae = torch.empty(B, multiplicity, N_token, N_token, num_bins, device=device, dtype=dtype) + init_tensors_uniform([pred_pae], low=-0.5, high=0.5) + + pred_atom_coords = torch.empty(B * multiplicity, N_atom, 3, device=device, dtype=dtype) + true_atom_coords = torch.empty(B * multiplicity, N_atom, 3, device=device, dtype=dtype) + init_tensors_uniform([pred_atom_coords, true_atom_coords], low=-10.0, high=10.0) + + # Heterogeneous resolution masks per batch, repeated for each multiplicity copy. + # pae_loss indexes with arange(0, B*mult, mult) to pick one mask per sample. + true_coords_resolved_mask = torch.ones(B * multiplicity, N_atom, device=device, dtype=dtype) + for b in range(B): + unresolved_fraction = 0.2 + 0.1 * (b % 3) + n_unresolved = int(N_atom * unresolved_fraction) + unresolved_indices = torch.randperm(N_atom, device=device)[:n_unresolved] + for m in range(multiplicity): + true_coords_resolved_mask[b * multiplicity + m, unresolved_indices] = 0 + + return feats_global, pred_pae, pred_atom_coords, true_atom_coords, true_coords_resolved_mask + + +def parallel_assert_pae_loss( + rank: int, + payload: tuple, +): + """Parallel test function for pae_loss. + + This function runs on each rank in the distributed setup and verifies that + the DTensor implementation matches the serial reference. + """ + test_config, inputs_global, feats_global, expected = payload + + # Unpack test configuration + grid_group_sizes = test_config["grid_group_sizes"] + device_type = test_config["device_type"] + backend = test_config["backend"] + env_per_rank = test_config["env_per_rank"] + multiplicity = test_config["multiplicity"] + max_dist = test_config["max_dist"] + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + group_layout = manager.layout_subgroups["cp"] + dtype = inputs_global["pred_pae"].dtype + + # Setup communication for coordinate transpose + rank_coords = group_layout.unravel(manager.group_rank["cp"]) + comm = One2OneComm( + manager.group["cp"], + rank_send_to=get_group_rank_from_axial_shift(rank_coords, 0, -1, group_layout), + rank_recv_from=get_group_rank_from_axial_shift(rank_coords, 0, 1, group_layout), + ) + + # --- Distribute pred_pae (B*mult, N_token, N_token, bins) --- + pred_pae_dtensor = distribute_tensor( + inputs_global["pred_pae"].to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Shard(2)), + ).requires_grad_(True) + + # --- Distribute atom features with intersperse padding --- + # Use distribute_atom_features for all atom-indexed tensors including coords + size_batch = feats_global["atom_counts_per_token"].shape[0] + pred_coords_unflat = inputs_global["pred_atom_coords"].unflatten(0, (size_batch, multiplicity)) + true_coords_unflat = inputs_global["true_atom_coords"].unflatten(0, (size_batch, multiplicity)) + true_mask_unflat = inputs_global["true_coords_resolved_mask"].unflatten(0, (size_batch, multiplicity)) + + inputs_atom = { + "atom_counts_per_token": feats_global["atom_counts_per_token"].to(dtype=torch.int64), + "atom_to_token": feats_global["atom_to_token"].to(dtype=dtype), + "atom_pad_mask": feats_global["atom_pad_mask"].to(dtype=dtype), + "atom_resolved_mask": feats_global["atom_resolved_mask"].to(dtype=dtype), + "frames_idx": feats_global["frames_idx"], + } + for i_mul in range(multiplicity): + inputs_atom[f"pred_atom_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_atom_coords_{i_mul}"] = true_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_coords_resolved_mask_{i_mul}"] = true_mask_unflat[:, i_mul].to(dtype=dtype) + + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "atom_to_token": (Shard(0), Replicate()), + "atom_pad_mask": (Shard(0), Replicate()), + "atom_resolved_mask": (Shard(0), Replicate()), + "frames_idx": (Shard(1), Replicate()), + } + placements_dp_cp = { + "atom_to_token": (Shard(0), Shard(1), Replicate()), + "atom_pad_mask": (Shard(0), Shard(1), Replicate()), + "atom_resolved_mask": (Shard(0), Shard(1), Replicate()), + "frames_idx": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(multiplicity): + placements_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Replicate()) + placements_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"pred_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_atom_coords_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + placements_dp_cp[f"true_coords_resolved_mask_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={ + "pred_atom_coords": multiplicity, + "true_atom_coords": multiplicity, + "true_coords_resolved_mask": multiplicity, + }, + ) + + pred_atom_coords_dtensor = feats_atom["pred_atom_coords"].requires_grad_(True) + true_atom_coords_dtensor = feats_atom["true_atom_coords"] + true_coords_resolved_mask_dtensor = feats_atom["true_coords_resolved_mask"] + + # --- Distribute token features (not atom-indexed) --- + token_placements = (Shard(0), Shard(1), Replicate()) + if manager.group_rank["world"] == 0: + token_feats_global = { + "frame_resolved_mask": feats_global["frame_resolved_mask"].to(manager.device), + "asym_id": feats_global["asym_id"].to(manager.device), + "mol_type": feats_global["mol_type"].to(manager.device), + "token_pad_mask": feats_global["token_pad_mask"].to(manager.device), + "is_nonpolymer_with_frame": feats_global["is_nonpolymer_with_frame"].to(manager.device), + } + else: + token_feats_global = None + + token_placements_map = { + "frame_resolved_mask": token_placements, + "asym_id": token_placements, + "mol_type": token_placements, + "token_pad_mask": token_placements, + "is_nonpolymer_with_frame": token_placements, + } + feats_token_dtensor = distribute_features( + token_feats_global, + token_placements_map, + manager.group["world"], + manager.group_ranks["world"][0], + device_mesh, + ) + + feats_dtensor = { + "frames_idx": feats_atom["frames_idx"], + **feats_token_dtensor, + # Atom features from distribute_atom_features + "atom_to_token": feats_atom["atom_to_token"], + "atom_pad_mask": feats_atom["atom_pad_mask"], + "atom_resolved_mask": feats_atom["atom_resolved_mask"], + } + + # --- Forward pass --- + loss_dtensor = pae_loss( + pred_pae_dtensor, + pred_atom_coords_dtensor, + true_atom_coords_dtensor, + true_coords_resolved_mask_dtensor, + feats_dtensor, + comm=comm, + dist_manager=manager, + group_layout=group_layout, + multiplicity=multiplicity, + max_dist=max_dist, + ) + + # --- Verify forward pass results --- + loss_actual = loss_dtensor.full_tensor().cpu() + loss_expected = expected["loss"] + torch.testing.assert_close(loss_actual, loss_expected) + + # --- Backward pass --- + loss_dtensor.backward() + + # Verify pred_pae gradient + assert pred_pae_dtensor.grad is not None, "Gradient not computed for pred_pae" + torch.testing.assert_close( + pred_pae_dtensor.grad.full_tensor().cpu(), + expected["pred_pae_grad"], + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (3, 3)), True, "cpu", "ENV"), + ((2, (2, 2)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2]) +def test_pae_loss(setup_env, multiplicity): + """Test pae_loss DTensor implementation against serial reference.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed_by_rank(0) + + # Test dimensions + dp_size = grid_group_sizes["dp"] + cp_size = grid_group_sizes["cp"][0] + + B = dp_size # Batch size must equal DP size for per-sample processing + n_tokens_per_shard = 16 + N_token = n_tokens_per_shard * cp_size + n_atoms_per_token_min, n_atoms_per_token_max = 1, 18 + avg_atoms_per_token = (n_atoms_per_token_min + n_atoms_per_token_max) / 2 + N_atom = int(avg_atoms_per_token * N_token * 1.25) # 25% buffer + N_atom = ((N_atom + cp_size - 1) // cp_size) * cp_size + max_dist = 32.0 + dtype = torch.float32 + device = torch.device(device_type) + + # Generate base features with random atom-token structure + base_feats = random_features( + size_batch=B, + n_tokens=N_token, + n_atoms=N_atom, + n_msa=1, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=device, + float_value_range=(-0.1, 0.1), + selected_keys=[ + "token_pad_mask", + "atom_pad_mask", + "atom_counts_per_token", + "atom_to_token", + "asym_id", + "mol_type", + "atom_resolved_mask", + "frames_idx", + "frame_resolved_mask", + "is_nonpolymer_with_frame", + ], + ) + + # Ensure frame_resolved_mask is not killed by random atom_resolved_mask: + # set atom_resolved_mask to all-ones so frame_resolved_mask reflects only + # frame geometry (collinear mask), not random atom resolution. + # true_coords_resolved_mask (created below) still exercises the resolved path. + base_feats["atom_resolved_mask"] = torch.ones_like(base_feats["atom_resolved_mask"]) + base_feats["frame_resolved_mask"] = torch.ones_like(base_feats["frame_resolved_mask"]) + + # Create heterogeneous features - raw unpadded data + feats_global, pred_pae, pred_atom_coords, true_atom_coords, true_coords_resolved_mask = ( + create_heterogeneous_pae_features( + B=B, + N_token=N_token, + N_atom=N_atom, + multiplicity=multiplicity, + device=device, + dtype=dtype, + base_feats=base_feats, + ) + ) + pred_pae.requires_grad_(True) + pred_atom_coords.requires_grad_(True) + + # Compute serial reference on raw unpadded data + # Serial pae_loss expects true_coords_resolved_mask with shape (B, N_atom) + expected_loss, _rel_loss = serial_pae_loss( + pred_pae=pred_pae, + pred_atom_coords=pred_atom_coords, + feats=feats_global, + true_atom_coords=true_atom_coords, + true_coords_resolved_mask=true_coords_resolved_mask, + multiplicity=multiplicity, + max_dist=max_dist, + ) + expected_loss.backward() + assert pred_pae.grad is not None and pred_pae.grad.abs().sum() > 0, "Serial grad is zero" + # Pack payload as dicts + test_config = { + "grid_group_sizes": grid_group_sizes, + "device_type": device_type, + "backend": backend, + "env_per_rank": env_per_rank, + "multiplicity": multiplicity, + "max_dist": max_dist, + } + + # Reshape pred_pae from serial (B, mult, N, N, bins) to distributed (B*mult, N, N, bins) + pred_pae_flat = pred_pae.detach().clone().flatten(0, 1).cpu() + pred_pae_grad_flat = pred_pae.grad.detach().clone().flatten(0, 1).cpu() + + inputs_global = { + "pred_pae": pred_pae_flat, + "pred_atom_coords": pred_atom_coords.detach().clone().cpu(), + "true_atom_coords": true_atom_coords.clone().cpu(), + "true_coords_resolved_mask": true_coords_resolved_mask.clone().cpu(), + } + + feats_global_cpu = {k: v.clone().cpu() for k, v in feats_global.items()} + + expected = { + "loss": expected_loss.detach().clone().cpu(), + "pred_pae_grad": pred_pae_grad_flat, + } + + payload = (test_config, inputs_global, feats_global_cpu, expected) + + # Launch parallel test + spawn_multiprocessing(parallel_assert_pae_loss, world_size, payload) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/model/loss/test_dtensor_smooth_lddt_loss.py b/tests/distributed/model/loss/test_dtensor_smooth_lddt_loss.py new file mode 100644 index 000000000..9605d32f3 --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_smooth_lddt_loss.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.diffusion import smooth_lddt_loss +from boltz.model.loss.diffusion import smooth_lddt_loss as serial_smooth_lddt_loss_v1 +from boltz.model.loss.diffusionv2 import smooth_lddt_loss as serial_smooth_lddt_loss_v2 +from boltz.testing.utils import spawn_multiprocessing + + +def parallel_assert_smooth_lddt_loss( + rank: int, + payload: tuple, +): + ( + multiplicity, + nucleic_acid_cutoff, + other_cutoff, + pred_coords, + true_coords, + is_nucleotide, + coords_mask, + expected_loss_host, + expected_pred_coords_grad_host, + grid_group_sizes, + device_type, + backend, + env_map, + v2, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + layout_map = manager.layout_subgroups["cp"] + device_mesh = manager.device_mesh_subgroups + comm = TransposeComm(manager.group["cp"], layout_map) + + pred_coords = distribute_tensor( + pred_coords.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + true_coords = distribute_tensor( + true_coords.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + is_nucleotide = distribute_tensor( + is_nucleotide.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + coords_mask = distribute_tensor( + coords_mask.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + + pred_coords.requires_grad_(True) + + loss = smooth_lddt_loss( + pred_coords, + true_coords, + is_nucleotide, + coords_mask, + multiplicity=multiplicity, + comm=comm, + nucleic_acid_cutoff=nucleic_acid_cutoff, + other_cutoff=other_cutoff, + v2=v2, + ) + + assert ( + loss.shape == expected_loss_host.shape + ), f"Loss shape mismatch: expected {expected_loss_host.shape}, got {loss.shape}" + assert ( + loss.stride() == expected_loss_host.stride() + ), f"Loss stride mismatch: expected {expected_loss_host.stride()}, got {loss.stride()}" + + loss.backward() + + torch.testing.assert_close(loss.full_tensor().cpu(), expected_loss_host) + + assert ( + pred_coords.grad.shape == expected_pred_coords_grad_host.shape + ), f"Pred coords grad shape mismatch: expected {expected_pred_coords_grad_host.shape}, got {pred_coords.grad.shape}" + assert ( + pred_coords.grad.stride() == expected_pred_coords_grad_host.stride() + ), f"Pred coords grad stride mismatch: expected {expected_pred_coords_grad_host.stride()}, got {pred_coords.grad.stride()}" + torch.testing.assert_close( + pred_coords.grad.full_tensor().cpu(), + expected_pred_coords_grad_host, + ) + + # Clean up + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize("multiplicity", [1, 2]) +@pytest.mark.parametrize("v2", [True, False], ids=["v2", "v1"]) +def test_smooth_lddt_loss( + setup_env: tuple, + multiplicity: int, + v2: bool, + nucleic_acid_cutoff: float = 30.0, + other_cutoff: float = 15.0, + dtype: torch.dtype = torch.float64, +): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + seed = 42 + rng = torch.Generator(device_type) + rng.manual_seed(seed) + + B = 2 * grid_group_sizes["dp"] + N_per_shard = 32 + N = N_per_shard * grid_group_sizes["cp"][0] + + # Make coordinates large enough that some distances exceed cutoff (default of 30) + pred_coords = ( + torch.randn(B * multiplicity, N, 3, generator=rng, device=device_type, dtype=dtype) * nucleic_acid_cutoff + ) + true_coords = ( + torch.randn(B * multiplicity, N, 3, generator=rng, device=device_type, dtype=dtype) * nucleic_acid_cutoff + ) + pred_coords.requires_grad_(True) + + # Multiplicity is called within smooth_lddt_loss for features + is_nucleotide = torch.randint(0, 2, (B, N), generator=rng, device=device_type, dtype=dtype) + coords_mask = torch.randint(0, 2, (B, N), generator=rng, device=device_type, dtype=dtype) + + # mask the last 2 atoms in each shard to emulate inserted virtual atoms in the middle of the atom sequence + for cp_rank in range(grid_group_sizes["cp"][0]): + end_idx = (cp_rank + 1) * N_per_shard + start_idx = end_idx - 2 + coords_mask[:, start_idx:end_idx] = 0 + + serial_smooth_lddt_loss = serial_smooth_lddt_loss_v2 if v2 else serial_smooth_lddt_loss_v1 + reference_loss = serial_smooth_lddt_loss( + pred_coords, + true_coords, + is_nucleotide, + coords_mask, + multiplicity=multiplicity, + nucleic_acid_cutoff=nucleic_acid_cutoff, + other_cutoff=other_cutoff, + ) + reference_loss.backward() + + # Call subprocesses for parallel testing + payload = ( + multiplicity, + nucleic_acid_cutoff, + other_cutoff, + pred_coords.detach().clone().cpu(), + true_coords.clone().cpu(), + is_nucleotide.clone().cpu(), + coords_mask.clone().cpu(), + reference_loss.detach().clone().cpu(), + pred_coords.grad.detach().clone().cpu(), + grid_group_sizes, + device_type, + backend, + env_per_rank, + v2, + ) + spawn_multiprocessing( + parallel_assert_smooth_lddt_loss, + world_size, + payload, + ) diff --git a/tests/distributed/model/loss/test_dtensor_weighted_minimum_rmsd_single.py b/tests/distributed/model/loss/test_dtensor_weighted_minimum_rmsd_single.py new file mode 100644 index 000000000..8920e7cf2 --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_weighted_minimum_rmsd_single.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor weighted_minimum_rmsd_single. + +Verifies that the distributed (DTensor) version of weighted_minimum_rmsd_single +produces results matching the serial version from boltz.model.loss.validation. +""" + +from __future__ import annotations + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.validation import ( + weighted_minimum_rmsd_single as dtensor_weighted_minimum_rmsd_single, +) +from boltz.model.loss.validation import weighted_minimum_rmsd_single as serial_weighted_minimum_rmsd_single +from boltz.testing.utils import distribute_atom_features, get_feature_placements, random_features, spawn_multiprocessing + +_atom_keys = {"atom_pad_mask", "atom_to_token"} +_token_keys = {"mol_type"} +_placements = get_feature_placements(atom_keys=_atom_keys, token_keys=_token_keys) +_placements_atom_features = _placements["atom_features"] +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_token_features = _placements["token_features"] + +_placements_coords = (Shard(0), Shard(1), Replicate()) +_placements_cp_coords = (Shard(0), Replicate()) + + +def parallel_assert_weighted_minimum_rmsd_single(rank, payload): + """Test distributed weighted_minimum_rmsd_single against serial version.""" + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + input_feats_global_host, + pred_coords_global_host, + true_coords_global_host, + expected_rmsd, + expected_aligned, + expected_weights, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device = manager.device + dtype = torch.float32 + device_mesh = manager.device_mesh_subgroups + + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_host.items() + if k in _placements_cp_atom_features + } + inputs_atom["pred_coords"] = pred_coords_global_host.to(dtype=dtype) + inputs_atom["true_coords"] = true_coords_global_host.to(dtype=dtype) + + placements_cp = _placements_cp_atom_features | { + "pred_coords": _placements_cp_coords, + "true_coords": _placements_cp_coords, + } + placements_dp_cp = _placements_atom_features | { + "pred_coords": _placements_coords, + "true_coords": _placements_coords, + } + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + ) + + pred_coords_dtensor = feats_atom["pred_coords"] + true_coords_dtensor = feats_atom["true_coords"] + atom_pad_mask_dtensor = feats_atom["atom_pad_mask"] + atom_to_token_dtensor = feats_atom["atom_to_token"] + + mol_type_dtensor = distribute_tensor( + input_feats_global_host["mol_type"].to(device=device, dtype=dtype), + device_mesh=device_mesh, + placements=_placements_token_features["mol_type"], + ) + + atom_mask_float = atom_pad_mask_dtensor + + rmsd, aligned, weights = dtensor_weighted_minimum_rmsd_single( + pred_atom_coords=pred_coords_dtensor, + atom_coords=true_coords_dtensor, + atom_mask=atom_mask_float, + atom_to_token=atom_to_token_dtensor, + mol_type=mol_type_dtensor, + ) + + rmsd_full = rmsd.full_tensor().cpu() + aligned_full = aligned.full_tensor().cpu() + + atom_pad_mask_full = atom_pad_mask_dtensor.full_tensor().cpu() + + torch.testing.assert_close( + rmsd_full, + expected_rmsd, + msg="RMSD mismatch between distributed and serial", + ) + + batch_size = pred_coords_global_host.shape[0] + for b in range(batch_size): + real_atom_mask = atom_pad_mask_full[b].bool() + aligned_no_pad = aligned_full[b, real_atom_mask, :] + expected_aligned_b = expected_aligned[b] + + torch.testing.assert_close( + aligned_no_pad, + expected_aligned_b, + msg=f"Batch {b}: aligned coords mismatch", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), # dp=1, cp=(2,2), world_size=4 + ((2, (2, 2)), True, "cuda", "ENV"), # dp=2, cp=(2,2), world_size=8 + ], + indirect=("setup_env",), +) +def test_dtensor_weighted_minimum_rmsd_single(setup_env): + """Test distributed weighted_minimum_rmsd_single against serial version. + + Uses random features with proper atom-to-token mapping. Compares rmsd, + aligned coordinates, and weights with the serial implementation. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + torch.manual_seed(42) + + size_ring = grid_group_sizes["cp"][0] + num_dp_ranks = grid_group_sizes["dp"] + batch_size = num_dp_ranks + n_atoms_per_token = 3 + n_tokens = size_ring * 4 + n_atoms = n_atoms_per_token * n_tokens + + feats_from_random = random_features( + size_batch=batch_size, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, n_atoms_per_token), + device=torch.device("cpu"), + float_value_range=(-1.0, 1.0), + selected_keys=["atom_to_token", "mol_type", "atom_pad_mask", "atom_counts_per_token"], + ) + + pred_coords = torch.randn((batch_size, n_atoms, 3), dtype=torch.float32) + true_coords = torch.randn((batch_size, n_atoms, 3), dtype=torch.float32) + + atom_to_token = feats_from_random["atom_to_token"].float() + mol_type = feats_from_random["mol_type"] + atom_pad_mask = feats_from_random["atom_pad_mask"].float() + + serial_device = torch.device("cuda:0") + expected_rmsd, expected_aligned, expected_weights = serial_weighted_minimum_rmsd_single( + pred_atom_coords=pred_coords.to(serial_device), + atom_coords=true_coords.to(serial_device), + atom_mask=atom_pad_mask.to(serial_device), + atom_to_token=atom_to_token.to(serial_device), + mol_type=mol_type.to(serial_device).float(), + ) + expected_rmsd = expected_rmsd.detach().cpu() + expected_aligned = expected_aligned.detach().cpu() + expected_weights = expected_weights.detach().cpu() + + input_feats_global = { + "atom_to_token": atom_to_token, + "mol_type": mol_type, + "atom_pad_mask": atom_pad_mask, + "atom_counts_per_token": feats_from_random["atom_counts_per_token"], + } + input_feats_global_host = {k: v.detach().clone().cpu() for k, v in input_feats_global.items()} + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + input_feats_global_host, + pred_coords.detach().clone().cpu(), + true_coords.detach().clone().cpu(), + expected_rmsd, + expected_aligned, + expected_weights, + ) + + spawn_multiprocessing(parallel_assert_weighted_minimum_rmsd_single, world_size, payload) diff --git a/tests/distributed/model/loss/test_dtensor_weighted_rigid_align.py b/tests/distributed/model/loss/test_dtensor_weighted_rigid_align.py new file mode 100644 index 000000000..5c6d69ece --- /dev/null +++ b/tests/distributed/model/loss/test_dtensor_weighted_rigid_align.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor weighted_rigid_align. + +Adapted from Boltz-1x CP tests. Verifies that the DTensor weighted_rigid_align +produces identical results to the serial version, and that outputs are +binary-identical across replicate ranks. +""" + +from math import isqrt +from typing import Optional + +import pytest +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.diffusion import weighted_rigid_align as dtensor_weighted_rigid_align +from boltz.model.loss.diffusionv2 import weighted_rigid_align as serial_weighted_rigid_align +from boltz.testing.utils import assert_all_identical, assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def compute_serial_expectation( + true_coords_global: torch.Tensor, + pred_coords_global: torch.Tensor, + weights_global: torch.Tensor, + mask_global: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + """Compute expected result using serial weighted_rigid_align.""" + true_coords_device = true_coords_global.to(device) + pred_coords_device = pred_coords_global.to(device) + weights_device = weights_global.to(device) + mask_device = mask_global.to(device) + + result = serial_weighted_rigid_align(true_coords_device, pred_coords_device, weights_device, mask_device) + return result.detach().clone() + + +def compute_dtensor_weighted_rigid_align_with_validation( + true_coords_global: torch.Tensor, + pred_coords_global: torch.Tensor, + weights_global: torch.Tensor, + mask_global: torch.Tensor, + device_mesh: DeviceMesh, + label_test_case: str, +) -> DTensor: + """Compute DTensor weighted_rigid_align with input validation checks.""" + coords_placements = (Shard(0), Shard(1), Replicate()) + + true_coords_dtensor = distribute_tensor(true_coords_global.detach().clone(), device_mesh, coords_placements) + pred_coords_dtensor = distribute_tensor(pred_coords_global.detach().clone(), device_mesh, coords_placements) + weights_dtensor = distribute_tensor(weights_global.detach().clone(), device_mesh, coords_placements) + mask_dtensor = distribute_tensor(mask_global.detach().clone(), device_mesh, coords_placements) + + # Create copies for validation + true_coords_copy = true_coords_dtensor.detach().clone() + pred_coords_copy = pred_coords_dtensor.detach().clone() + weights_copy = weights_dtensor.detach().clone() + mask_copy = mask_dtensor.detach().clone() + + result_dtensor = dtensor_weighted_rigid_align( + true_coords_dtensor, pred_coords_dtensor, weights_dtensor, mask_dtensor + ) + + # Verify no change to inputs + assert_tensors_identical( + true_coords_dtensor.to_local(), true_coords_copy.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + pred_coords_dtensor.to_local(), pred_coords_copy.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(weights_dtensor.to_local(), weights_copy.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(mask_dtensor.to_local(), mask_copy.to_local(), check_grad=False, check_grad_fn=False) + + # Verify output placements + assert ( + result_dtensor.placements == coords_placements + ), f"{label_test_case} output placements mismatch with input placements" + + # Verify binary identical output across the Replicate device_mesh axis + assert_all_identical(result_dtensor.to_local(), device_mesh.get_group(2)) + + return result_dtensor + + +def parallel_assert_dtensor_weighted_rigid_align( + rank: int, + grid_group_sizes: dict[str, int], + device_type: str, + backend: str, + env_map: Optional[dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0, seed=42) + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Random input test case + size_batch = 2 + n_atoms_per_rank = 6 + n_atoms_padding_per_rank = 2 + n_atoms = size_ring * n_atoms_per_rank + + true_coords_global = torch.randn((size_batch, n_atoms, 3), dtype=torch.float32) * 2.0 + 10.0 + pred_coords_global = torch.randn((size_batch, n_atoms, 3), dtype=torch.float32) * 5.0 + 15.0 + + # Mask with padding per rank shard + mask_global = torch.zeros((size_batch, size_ring, n_atoms_per_rank), dtype=torch.float32) + mask_global[:, :, :(-n_atoms_padding_per_rank)] = 1.0 + mask_global = mask_global.reshape(size_batch, n_atoms) + weights_global = mask_global.clone() + + label = "random_input" + + # Serial expectation + expected_result = compute_serial_expectation( + true_coords_global, pred_coords_global, weights_global, mask_global, manager.device + ) + + # DTensor result with validation + result_dtensor = compute_dtensor_weighted_rigid_align_with_validation( + true_coords_global, + pred_coords_global, + weights_global, + mask_global, + manager.device_mesh_subgroups, + label, + ) + + # Compare local shards + expected_dtensor = distribute_tensor(expected_result, manager.device_mesh_subgroups, result_dtensor.placements) + torch.testing.assert_close( + result_dtensor.to_local(), + expected_dtensor.to_local(), + msg=lambda m: f"{label} local shard mismatch: {m}", + ) + + # Compare global tensors + result_global = result_dtensor.full_tensor() + torch.testing.assert_close( + result_global, + expected_result, + msg=lambda m: f"{label} global result mismatch: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +def _build_degenerate_inputs(case: str, device: torch.device): + """Build inputs that trigger the degenerate-case warnings in DTensor weighted_rigid_align.""" + if case == "scalar_degenerate": + # total_num_points <= dim (2 <= 3) → scalar warning + true_coords = torch.randn(1, 2, 3, dtype=torch.float32, device=device) * 2.0 + 10.0 + pred_coords = torch.randn(1, 2, 3, dtype=torch.float32, device=device) * 5.0 + 15.0 + mask = torch.ones(1, 2, dtype=torch.float32, device=device) + weights = mask.clone() + elif case == "per_batch_degenerate": + # batch size 2, 4 points; mask so batch index 1 has only 2 valid (< dim+1=4) → per-batch warning + true_coords = torch.randn(2, 4, 3, dtype=torch.float32, device=device) * 2.0 + 10.0 + pred_coords = torch.randn(2, 4, 3, dtype=torch.float32, device=device) * 5.0 + 15.0 + mask = torch.ones(2, 4, dtype=torch.float32, device=device) + mask[1, 2:] = 0.0 # batch 1: only 2 valid points + weights = mask.clone() + elif case == "svd_low_rank": + # 4 points but collinear → covariance rank-deficient → SVD low-rank warning + true_coords = torch.zeros(1, 4, 3, dtype=torch.float32, device=device) + true_coords[:, :, 0] = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device) # on x-axis + pred_coords = true_coords.clone() + 0.1 * torch.randn(1, 4, 3, device=device) + mask = torch.ones(1, 4, dtype=torch.float32, device=device) + weights = mask.clone() + else: + raise ValueError(f"Unknown case: {case}") + return true_coords, pred_coords, weights, mask + + +def _expected_warning_substrings(case: str): + """Substrings that must appear in the warning message for the given degenerate case.""" + if case == "scalar_degenerate": + return ["The size of one of the point clouds is <= dim+1."] + if case == "per_batch_degenerate": + return ["[rank_coord:", "Batch indices (subset):"] + if case == "svd_low_rank": + return ["[rank_coord:", "Excessively low rank"] + raise ValueError(f"Unknown case: {case}") + + +DEGENERATE_CASES = ("scalar_degenerate", "per_batch_degenerate", "svd_low_rank") + + +def parallel_assert_dtensor_weighted_rigid_align_degenerate( + rank: int, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: Optional[dict[str, str]], +): + """Worker: run DTensor weighted_rigid_align on degenerate inputs and assert expected warnings (all cases in one spawn).""" + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + for case in DEGENERATE_CASES: + true_coords_global, pred_coords_global, weights_global, mask_global = _build_degenerate_inputs( + case, manager.device + ) + coords_placements = (Shard(0), Shard(1), Replicate()) + true_coords_dt = distribute_tensor(true_coords_global, manager.device_mesh_subgroups, coords_placements) + pred_coords_dt = distribute_tensor(pred_coords_global, manager.device_mesh_subgroups, coords_placements) + weights_dt = distribute_tensor(weights_global, manager.device_mesh_subgroups, coords_placements) + mask_dt = distribute_tensor(mask_global, manager.device_mesh_subgroups, coords_placements) + + with pytest.warns(UserWarning) as record: + dtensor_weighted_rigid_align(true_coords_dt, pred_coords_dt, weights_dt, mask_dt) + + combined_message = " ".join(str(w.message) for w in record) + for substring in _expected_warning_substrings(case): + assert ( + substring in combined_message + ), f"Case {case}: expected substring {substring!r} in warning message(s), got: {combined_message!r}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_weighted_rigid_align(setup_env): + """Test DTensor weighted_rigid_align vs serial equivalence.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_weighted_rigid_align, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), True, "cuda", "ENV")], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}", +) +def test_dtensor_weighted_rigid_align_degenerate_warnings(setup_env): + """Test that DTensor weighted_rigid_align raises expected warnings for degenerate inputs (dp=1, cp=(1,1)).""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + spawn_multiprocessing( + parallel_assert_dtensor_weighted_rigid_align_degenerate, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/model/loss/test_get_lddt_metrics.py b/tests/distributed/model/loss/test_get_lddt_metrics.py new file mode 100644 index 000000000..ee7fd71c9 --- /dev/null +++ b/tests/distributed/model/loss/test_get_lddt_metrics.py @@ -0,0 +1,453 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from __future__ import annotations + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +import boltz.distributed.model.loss.validation as _validation_module +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.validation import get_lddt_metrics +from boltz.model.validation.validator import Validator +from boltz.testing.utils import distribute_atom_features, get_feature_placements, random_features, spawn_multiprocessing + +_atom_keys = {"atom_pad_mask", "atom_to_token", "atom_counts_per_token"} +_token_keys = {"mol_type", "asym_id", "token_pad_mask"} +_placements = get_feature_placements(atom_keys=_atom_keys, token_keys=_token_keys) +_placements_atom_features = _placements["atom_features"] +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_token_features = _placements["token_features"] + +_placements_pred_coords = {"pred_coords": (Shard(0), Shard(1), Replicate())} +_placements_cp_pred_coords = {"pred_coords": (Shard(0), Replicate())} + + +def parallel_assert_get_lddt_metrics(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + n_samples, + K, + expand_to_diffusion_samples, + feats_global_host, + pred_coords_global_host, + true_coords_base_host, + true_coords_global_host, + true_coords_resolved_mask_base_host, + true_coords_resolved_mask_host, + ref_lddt, + ref_total, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device = manager.device + dtype = torch.float32 + + size_batch = feats_global_host["atom_pad_mask"].shape[0] + rank_dp = manager.group_rank["dp"] + num_dp_ranks = grid_group_sizes["dp"] + local_batch_size = size_batch // num_dp_ranks + local_start = rank_dp * local_batch_size + local_end = local_start + local_batch_size + + def _all_gather_single_repr(single_dtensor): + single_dtensor = single_dtensor.redistribute( + placements=[Shard(0), Replicate(), Replicate()], + ) + return single_dtensor.to_local() + + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + pred_coords_unflat = pred_coords_global_host.unflatten(0, (size_batch, n_samples)) + for i_mul in range(n_samples): + inputs_atom[f"pred_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=dtype) + + placements_cp = dict(_placements_cp_atom_features) + placements_dp_cp = dict(_placements_atom_features) + for i_mul in range(n_samples): + placements_cp[f"pred_coords_{i_mul}"] = _placements_cp_pred_coords["pred_coords"] + placements_dp_cp[f"pred_coords_{i_mul}"] = _placements_pred_coords["pred_coords"] + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities={"pred_coords": n_samples}, + ) + + atom_to_token_dtensor = feats_atom["atom_to_token"] + + mol_type_dtensor = distribute_tensor( + feats_global_host["mol_type"].to(device=device, dtype=torch.int64), + device_mesh=manager.device_mesh_subgroups, + placements=_placements_token_features["mol_type"], + ) + asym_id_dtensor = distribute_tensor( + feats_global_host["asym_id"].to(device=device, dtype=torch.int64), + device_mesh=manager.device_mesh_subgroups, + placements=_placements_token_features["asym_id"], + ) + + mol_type_local = _all_gather_single_repr(mol_type_dtensor) + asym_id_local = _all_gather_single_repr(asym_id_dtensor) + pred_coords_local = _all_gather_single_repr(feats_atom["pred_coords"]) + atom_pad_mask_local = _all_gather_single_repr(feats_atom["atom_pad_mask"]).bool() + + local_mul_start = local_start * n_samples + local_mul_end = local_end * n_samples + if expand_to_diffusion_samples: + true_coords_local_unpadded = true_coords_global_host[local_mul_start:local_mul_end].to( + device=device, dtype=dtype + ) + mask_local_unpadded = true_coords_resolved_mask_host[local_mul_start:local_mul_end].to( + device=device, dtype=dtype + ) + active_unpadded_mask = mask_local_unpadded[0].bool() + else: + true_coords_local_unpadded = ( + true_coords_base_host[local_start:local_end].squeeze(0).to(device=device, dtype=dtype) + ) + mask_local_unpadded = ( + true_coords_resolved_mask_base_host[local_start:local_end].squeeze(0).to(device=device, dtype=dtype) + ) + active_unpadded_mask = mask_local_unpadded.bool() + + atom_mask_row = atom_pad_mask_local[0].bool() + n_atoms_padded = atom_mask_row.shape[0] + n_atoms_active = int(active_unpadded_mask.sum().item()) + if int(atom_mask_row.sum().item()) != n_atoms_active: + raise ValueError( + "atom_pad_mask/padded atom-space mismatch: " + f"sum(atom_pad_mask)={int(atom_mask_row.sum().item())}, n_atoms_active_unpadded={n_atoms_active}" + ) + if expand_to_diffusion_samples: + true_coords_local_active = true_coords_local_unpadded[:, :, active_unpadded_mask, :] + mask_local_active = mask_local_unpadded[:, active_unpadded_mask] + + true_coords_local = torch.zeros( + true_coords_local_unpadded.shape[0], + true_coords_local_unpadded.shape[1], + n_atoms_padded, + true_coords_local_unpadded.shape[3], + device=device, + dtype=dtype, + ) + true_coords_local[:, :, atom_mask_row, :] = true_coords_local_active + + mask_local = torch.zeros( + mask_local_unpadded.shape[0], + n_atoms_padded, + device=device, + dtype=dtype, + ) + mask_local[:, atom_mask_row] = mask_local_active + else: + true_coords_local_active = true_coords_local_unpadded[:, active_unpadded_mask, :] + mask_local_active = mask_local_unpadded[active_unpadded_mask] + + true_coords_local = torch.zeros( + true_coords_local_unpadded.shape[0], + n_atoms_padded, + true_coords_local_unpadded.shape[2], + device=device, + dtype=dtype, + ) + true_coords_local[:, atom_mask_row, :] = true_coords_local_active + + mask_local = torch.zeros( + n_atoms_padded, + device=device, + dtype=dtype, + ) + mask_local[atom_mask_row] = mask_local_active + + lddt_dict, total_dict = get_lddt_metrics( + atom_to_token_dtensor=atom_to_token_dtensor, + num_conformers=K, + n_samples=n_samples, + true_coords=true_coords_local, + true_coords_resolved_mask=mask_local, + mol_type=mol_type_local, + asym_id=asym_id_local, + sample_atom_coords=pred_coords_local, + expand_to_diffusion_samples=expand_to_diffusion_samples, + ) + + ref_slice = slice(local_mul_start, local_mul_end) + for key in ref_lddt: + torch.testing.assert_close( + lddt_dict[key].cpu(), + ref_lddt[key][ref_slice], + ) + torch.testing.assert_close( + total_dict[key].cpu(), + ref_total[key][ref_slice], + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize("expand_to_diffusion_samples", [True, False]) +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (3, 3)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), +) +def test_get_lddt_metrics(setup_env, expand_to_diffusion_samples): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type != "cuda": + pytest.skip("cdist_lddt requires CUDA") + + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + torch.manual_seed(0) + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + size_batch = grid_group_sizes["dp"] + size_cp = grid_group_sizes["cp"][0] + n_tokens = size_cp * 20 + n_atoms = n_tokens * 20 + n_samples = 2 + K = 2 + + feats_global_host = random_features( + size_batch=size_batch, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, 20), + device=torch.device("cpu"), + float_value_range=(0.0, 1.0), + selected_keys=[ + "atom_to_token", + "atom_pad_mask", + "atom_counts_per_token", + "mol_type", + "asym_id", + "token_pad_mask", + ], + rng=rng, + ) + token_pad_mask = torch.ones((size_batch, n_tokens), dtype=torch.bool) + token_pad_mask[:, ::3] = False + token_pad_mask[:, :2] = True + feats_global_host["token_pad_mask"] = token_pad_mask.to(dtype=feats_global_host["token_pad_mask"].dtype) + + atom_to_token = feats_global_host["atom_to_token"] + atom_pad_mask = torch.zeros((size_batch, n_atoms), dtype=torch.bool) + for batch_idx in range(size_batch): + token_mask = token_pad_mask[batch_idx] + atom_mask = atom_to_token[batch_idx][:, token_mask].any(dim=1) + atom_pad_mask[batch_idx] = atom_mask + atom_to_token[batch_idx, ~atom_mask, :] = 0 + atom_to_token[batch_idx, :, ~token_mask] = 0 + + feats_global_host["atom_pad_mask"] = atom_pad_mask.to(dtype=feats_global_host["atom_pad_mask"].dtype) + pred_coords_global_host = torch.randn(size_batch * n_samples, n_atoms, 3, dtype=torch.float32) + true_coords_base_host = torch.randn(size_batch, K, n_atoms, 3, dtype=torch.float32) + true_coords_global_host = true_coords_base_host.repeat_interleave(n_samples, dim=0) + true_coords_resolved_mask_base_host = feats_global_host["atom_pad_mask"].to(torch.float32) + true_coords_resolved_mask_host = true_coords_resolved_mask_base_host.repeat_interleave(n_samples, dim=0) + + ref_lddt: dict[str, torch.Tensor] = {} + ref_total: dict[str, torch.Tensor] = {} + feats_serial = { + "atom_to_token": feats_global_host["atom_to_token"].to(dtype=torch.float32), + "mol_type": feats_global_host["mol_type"], + "asym_id": feats_global_host["asym_id"], + "coords": true_coords_base_host, + } + + if expand_to_diffusion_samples: + ref_lddt, ref_total = Validator.get_lddt_metrics( + None, + model=None, + batch=feats_serial, + out={"sample_atom_coords": pred_coords_global_host}, + idx_dataset=0, + n_samples=n_samples, + true_coords_resolved_mask=true_coords_resolved_mask_host, + true_coords=true_coords_global_host, + expand_to_diffusion_samples=True, + ) + else: + pred_coords_by_batch = pred_coords_global_host.unflatten(0, (size_batch, n_samples)) + for batch_idx in range(size_batch): + feats_serial_single = { + "atom_to_token": feats_serial["atom_to_token"][batch_idx : batch_idx + 1], + "mol_type": feats_serial["mol_type"][batch_idx : batch_idx + 1], + "asym_id": feats_serial["asym_id"][batch_idx : batch_idx + 1], + "coords": feats_serial["coords"][batch_idx : batch_idx + 1], + } + lddt_single, total_single = Validator.get_lddt_metrics( + None, + model=None, + batch=feats_serial_single, + out={"sample_atom_coords": pred_coords_by_batch[batch_idx]}, + idx_dataset=0, + n_samples=n_samples, + true_coords_resolved_mask=true_coords_resolved_mask_base_host[batch_idx], + true_coords=true_coords_base_host[batch_idx], + expand_to_diffusion_samples=False, + ) + if not ref_lddt: + for key in lddt_single: + ref_lddt[key] = torch.zeros(size_batch * n_samples, K, dtype=lddt_single[key].dtype) + ref_total[key] = torch.zeros(size_batch * n_samples, K, dtype=total_single[key].dtype) + row_slice = slice(batch_idx * n_samples, (batch_idx + 1) * n_samples) + for key in lddt_single: + ref_lddt[key][row_slice] = lddt_single[key] + ref_total[key][row_slice] = total_single[key] + + spawn_multiprocessing( + parallel_assert_get_lddt_metrics, + world_size, + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + n_samples, + K, + expand_to_diffusion_samples, + feats_global_host, + pred_coords_global_host, + true_coords_base_host, + true_coords_global_host, + true_coords_resolved_mask_base_host, + true_coords_resolved_mask_host, + dict(ref_lddt), + dict(ref_total), + ), + ) + + +@pytest.mark.parametrize( + "mutation, expand, match", + [ + ("batch_size", True, "local batch size 1"), + ("sample_batch", True, "sample_atom_coords batch must equal"), + ("sample_ndim", True, "sample_atom_coords must be rank 3"), + ("true_coords_ndim_expanded", True, "true_coords must be rank 4"), + ("true_coords_ndim_not_expanded", False, "true_coords must be rank 3"), + ("mask_rank_expanded", True, "true_coords_resolved_mask must be rank 2"), + ("mask_rank_not_expanded", False, "true_coords_resolved_mask must be rank 1"), + ("true_coords_K_expanded", True, "true_coords conformer count"), + ("true_coords_K_not_expanded", False, "true_coords conformer count"), + ("true_coords_batch_expanded", True, "true_coords batch dim"), + ("mol_type_tokens", True, "mol_type N_tokens"), + ("asym_id_tokens", True, "asym_id N_tokens"), + ("sample_atoms", True, "sample_atom_coords N_atoms"), + ("mask_atoms_expanded", True, "true_coords_resolved_mask N_atoms"), + ("mask_atoms_not_expanded", False, "true_coords_resolved_mask N_atoms"), + ("true_coords_atoms_expanded", True, "true_coords N_atoms"), + ("true_coords_atoms_not_expanded", False, "true_coords N_atoms"), + ], +) +def test_get_lddt_metrics_shape_errors(monkeypatch, mutation, expand, match): + B, N_tokens, N_atoms, K, n_samples = 1, 4, 8, 2, 2 + + atom_to_token = torch.zeros(B, N_atoms, N_tokens) + mol_type = torch.zeros(B, N_tokens, dtype=torch.long) + asym_id = torch.zeros(B, N_tokens, dtype=torch.long) + sample_atom_coords = torch.zeros(B * n_samples, N_atoms, 3) + + if expand: + true_coords = torch.zeros(B * n_samples, K, N_atoms, 3) + mask = torch.zeros(B * n_samples, N_atoms) + else: + true_coords = torch.zeros(K, N_atoms, 3) + mask = torch.zeros(N_atoms) + + if mutation == "batch_size": + atom_to_token = torch.zeros(2, N_atoms, N_tokens) + elif mutation == "sample_batch": + sample_atom_coords = torch.zeros(B * n_samples + 1, N_atoms, 3) + elif mutation == "sample_ndim": + sample_atom_coords = torch.zeros(B * n_samples * N_atoms, 3) + elif mutation == "true_coords_ndim_expanded": + true_coords = torch.zeros(K, N_atoms, 3) + elif mutation == "true_coords_ndim_not_expanded": + true_coords = torch.zeros(B * n_samples, K, N_atoms, 3) + elif mutation == "mask_rank_expanded": + mask = torch.zeros(N_atoms) + elif mutation == "mask_rank_not_expanded": + mask = torch.zeros(B * n_samples, N_atoms) + elif mutation == "true_coords_K_expanded": + true_coords = torch.zeros(B * n_samples, K + 1, N_atoms, 3) + elif mutation == "true_coords_K_not_expanded": + true_coords = torch.zeros(K + 1, N_atoms, 3) + elif mutation == "true_coords_batch_expanded": + true_coords = torch.zeros(B * n_samples + 1, K, N_atoms, 3) + elif mutation == "mol_type_tokens": + mol_type = torch.zeros(B, N_tokens + 1, dtype=torch.long) + elif mutation == "asym_id_tokens": + asym_id = torch.zeros(B, N_tokens + 1, dtype=torch.long) + elif mutation == "sample_atoms": + sample_atom_coords = torch.zeros(B * n_samples, N_atoms + 1, 3) + elif mutation == "mask_atoms_expanded": + mask = torch.zeros(B * n_samples, N_atoms + 1) + elif mutation == "mask_atoms_not_expanded": + mask = torch.zeros(N_atoms + 1) + elif mutation == "true_coords_atoms_expanded": + true_coords = torch.zeros(B * n_samples, K, N_atoms + 1, 3) + elif mutation == "true_coords_atoms_not_expanded": + true_coords = torch.zeros(K, N_atoms + 1, 3) + + monkeypatch.setattr(_validation_module, "reconstruct_atom_to_token_global", lambda _: atom_to_token) + + with pytest.raises(ValueError, match=match): + get_lddt_metrics( + atom_to_token_dtensor=None, + num_conformers=K, + n_samples=n_samples, + true_coords=true_coords, + true_coords_resolved_mask=mask, + mol_type=mol_type, + asym_id=asym_id, + sample_atom_coords=sample_atom_coords, + expand_to_diffusion_samples=expand, + ) diff --git a/tests/distributed/model/loss/test_smooth_lddt_loss_triton.py b/tests/distributed/model/loss/test_smooth_lddt_loss_triton.py new file mode 100644 index 000000000..8f819f386 --- /dev/null +++ b/tests/distributed/model/loss/test_smooth_lddt_loss_triton.py @@ -0,0 +1,603 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import os +import re +import subprocess +from pathlib import Path + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager + +try: + from boltz.distributed.model.loss.triton.smooth_lddt_loss import ( + grid_launch_config, + smooth_lddt_loss_bwd_kernel, + smooth_lddt_loss_fwd_kernel, + ) + + has_smooth_lddt_loss_triton_kernels = True +except ImportError: + has_smooth_lddt_loss_triton_kernels = False + +from boltz.distributed.model.loss.diffusion import ( + _smooth_lddt_loss_backward_local, + _smooth_lddt_loss_forward_local, + _smooth_lddt_loss_local_triton_backward, + _smooth_lddt_loss_local_triton_forward, + smooth_lddt_loss_triton, +) +from boltz.distributed.model.modules.utils import PRECISION_TO_DTYPE, Precision, setup_tf32_env +from boltz.model.loss.diffusion import smooth_lddt_loss as smooth_lddt_loss_ref_impl_v1 +from boltz.model.loss.diffusionv2 import smooth_lddt_loss as smooth_lddt_loss_ref_impl_v2 +from boltz.testing.utils import spawn_multiprocessing + + +def assert_smooth_lddt_loss_equivalence(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + multiplicity, + pred_coords_global, + true_coords_global, + is_nucleotide_global, + coords_mask_global, + v2, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Setup comm + transpose_comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + # Prepare inputs + # pred_coords: (Shard(0), Shard(1), Replicate()) + # true_coords: (Shard(0), Shard(1), Replicate()) + # is_nucleotide: (Shard(0), Shard(1), Replicate()) + # coords_mask: (Shard(0), Shard(1), Replicate()) + + placements = (Shard(0), Shard(1), Replicate()) + + pred_coords_dtensor = distribute_tensor( + pred_coords_global.to(device=manager.device, dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ).requires_grad_(True) + + true_coords_dtensor = distribute_tensor( + true_coords_global.to(device=manager.device, dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ).requires_grad_(False) # True coords usually don't need grad + + is_nucleotide_dtensor = distribute_tensor( + is_nucleotide_global.to(device=manager.device, dtype=dtype), # cast to float/int + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + coords_mask_dtensor = distribute_tensor( + coords_mask_global.to(device=manager.device, dtype=dtype), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ) + + # Extract global tensors for reference implementation (running serially) + # This avoids potential bugs in distributed reference implementation and compares against + # the mathematical ground truth (global execution). + pred_global = pred_coords_dtensor.full_tensor().detach().clone().requires_grad_(True) + true_global = true_coords_dtensor.full_tensor().detach() + is_nuc_global = is_nucleotide_dtensor.full_tensor().detach() + mask_global = coords_mask_dtensor.full_tensor().detach() + + # Run Reference Function (Global/Serial) + smooth_lddt_loss_ref_impl = smooth_lddt_loss_ref_impl_v2 if v2 else smooth_lddt_loss_ref_impl_v1 + loss_ref = smooth_lddt_loss_ref_impl( + pred_coords=pred_global, + true_coords=true_global, + is_nucleotide=is_nuc_global, + coords_mask=mask_global, + multiplicity=multiplicity, + ) + + # Run Custom Function (DTensor version) - Triton + # We use the original dtensor (v2 copy not needed if we use fresh global for ref) + # But pred_coords_dtensor tracks grad. + if dtype == torch.bfloat16: + # Expect error for bf16 + with pytest.raises(ValueError, match=f"Triton kernel for smooth LDDT loss does not support {dtype}"): + smooth_lddt_loss_triton( + pred_coords=pred_coords_dtensor, + true_coords=true_coords_dtensor, + is_nucleotide=is_nucleotide_dtensor, + coords_mask=coords_mask_dtensor, + comm=transpose_comm, + multiplicity=multiplicity, + v2=v2, + ) + return + + loss_custom = smooth_lddt_loss_triton( + pred_coords=pred_coords_dtensor, + true_coords=true_coords_dtensor, + is_nucleotide=is_nucleotide_dtensor, + coords_mask=coords_mask_dtensor, + comm=transpose_comm, + multiplicity=multiplicity, + v2=v2, + ) + + # Compare Forward + # loss_ref is global scalar. + # loss_custom is DTensor (scalar). + torch.testing.assert_close(loss_ref, loss_custom.full_tensor()) + + # Backward Reference + loss_ref.backward() + grad_ref = pred_global.grad + + # Backward Custom + loss_custom.backward() + grad_custom = pred_coords_dtensor.grad + + # Compare Backward + # Gather custom gradients to global to compare with reference + torch.testing.assert_close(grad_ref, grad_custom.full_tensor()) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_smooth_lddt_loss_triton_kernels, reason="Triton kernels not available") +@pytest.mark.parametrize("is_self_comm", [True, False]) +@pytest.mark.parametrize("dtype_str", ["fp32", "bf16", "tf32", "fp64"]) +def test_smooth_lddt_loss_local_triton_equivalence(is_self_comm, dtype_str): + match dtype_str: + case "fp32": + dtype = torch.float32 + precision = Precision.FP32 + case "bf16": + dtype = torch.bfloat16 + precision = Precision.BF16 + case "tf32": + dtype = torch.float32 + precision = Precision.TF32 + case "fp64": + dtype = torch.float64 + precision = Precision.FP64 + case _: + raise ValueError(f"Unsupported dtype: {dtype_str}") + + with setup_tf32_env(precision): + # Setup inputs + B_local = 1 + multiplicity = 16 + N_atom_local = 100 + device = torch.device("cuda") + + init_val_range = 30.0 + pred_coords_local = ( + torch.randn(B_local * multiplicity, N_atom_local, 3, device=device, dtype=dtype) * init_val_range + ) + true_coords_local = ( + torch.randn(B_local * multiplicity, N_atom_local, 3, device=device, dtype=dtype) * init_val_range + ) + pred_coords_t_local = ( + torch.randn(B_local * multiplicity, N_atom_local, 3, device=device, dtype=dtype) * init_val_range + ) + true_coords_t_local = ( + torch.randn(B_local * multiplicity, N_atom_local, 3, device=device, dtype=dtype) * init_val_range + ) + + is_nucleotide_local = torch.randint(0, 2, (B_local, N_atom_local), device=device).bool() + coords_mask_local = torch.randint(0, 2, (B_local, N_atom_local), device=device).to(dtype=dtype) + coords_mask_t_local = torch.randint(0, 2, (B_local, N_atom_local), device=device).to(dtype=dtype) + + nucleic_acid_cutoff = 5.0 + other_cutoff = 3.0 + + # Clone inputs for Triton forward pass to check for in-place modifications + pred_coords_local_triton = pred_coords_local.clone() + true_coords_local_triton = true_coords_local.clone() + pred_coords_t_local_triton = pred_coords_t_local.clone() + true_coords_t_local_triton = true_coords_t_local.clone() + is_nucleotide_local_triton = is_nucleotide_local.clone() + coords_mask_local_triton = coords_mask_local.clone() + coords_mask_t_local_triton = coords_mask_t_local.clone() + + # --- Forward Pass --- + # Run PyTorch version + num_ref, den_ref = _smooth_lddt_loss_forward_local( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + # Run Triton version + if dtype == torch.bfloat16: + with pytest.raises(ValueError, match=f"Triton kernel for smooth LDDT loss does not support {dtype}"): + _smooth_lddt_loss_local_triton_forward( + pred_coords_local_triton, + true_coords_local_triton, + pred_coords_t_local_triton, + true_coords_t_local_triton, + is_nucleotide_local_triton, + coords_mask_local_triton, + coords_mask_t_local_triton, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + return + + num_triton, den_triton = _smooth_lddt_loss_local_triton_forward( + pred_coords_local_triton, + true_coords_local_triton, + pred_coords_t_local_triton, + true_coords_t_local_triton, + is_nucleotide_local_triton, + coords_mask_local_triton, + coords_mask_t_local_triton, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + # Check equivalence + # PyTorch sum() over bf16 promotes to fp32, while Triton kernel keeps bf16. + # We cast ref to match triton output for comparison. + num_ref = num_ref.to(dtype=num_triton.dtype) + den_ref = den_ref.to(dtype=den_triton.dtype) + + torch.testing.assert_close(num_triton, num_ref) + torch.testing.assert_close(den_triton, den_ref) + + # Check that inputs were not modified in-place + torch.testing.assert_close(pred_coords_local_triton, pred_coords_local) + torch.testing.assert_close(true_coords_local_triton, true_coords_local) + torch.testing.assert_close(pred_coords_t_local_triton, pred_coords_t_local) + torch.testing.assert_close(true_coords_t_local_triton, true_coords_t_local) + torch.testing.assert_close(is_nucleotide_local_triton, is_nucleotide_local) + torch.testing.assert_close(coords_mask_local_triton, coords_mask_local) + torch.testing.assert_close(coords_mask_t_local_triton, coords_mask_t_local) + + # --- Backward Pass --- + + # Dummy gradients for num and den (scalars per batch element) + grad_num_reduced = torch.randn(B_local * multiplicity, device=device, dtype=dtype) + grad_den_reduced = torch.randn(B_local * multiplicity, device=device, dtype=dtype) + + # Clone gradients for Triton backward pass + grad_num_reduced_triton = grad_num_reduced.clone() + grad_den_reduced_triton = grad_den_reduced.clone() + + # Run PyTorch backward + grad_pred_local_ref, grad_pred_t_local_ref = _smooth_lddt_loss_backward_local( + grad_num_reduced, + grad_den_reduced, + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + # Run Triton backward + grad_pred_local_triton, grad_pred_t_local_triton = _smooth_lddt_loss_local_triton_backward( + grad_num_reduced_triton, + grad_den_reduced_triton, + pred_coords_local_triton, + true_coords_local_triton, + pred_coords_t_local_triton, + true_coords_t_local_triton, + is_nucleotide_local_triton, + coords_mask_local_triton, + coords_mask_t_local_triton, + is_self_comm, + nucleic_acid_cutoff, + other_cutoff, + multiplicity, + ) + + # Check equivalence + grad_pred_local_ref = grad_pred_local_ref.to(dtype=grad_pred_local_triton.dtype) + grad_pred_t_local_ref = grad_pred_t_local_ref.to(dtype=grad_pred_t_local_triton.dtype) + + torch.testing.assert_close(grad_pred_local_triton, grad_pred_local_ref, atol=1e-6, rtol=1e-4) + torch.testing.assert_close(grad_pred_t_local_triton, grad_pred_t_local_ref, atol=1e-6, rtol=1e-4) + + # Check that gradients were not modified in-place + torch.testing.assert_close(grad_num_reduced_triton, grad_num_reduced) + torch.testing.assert_close(grad_den_reduced_triton, grad_den_reduced) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_smooth_lddt_loss_triton_kernels, reason="Triton kernels not available") +@pytest.mark.parametrize( + "setup_env", + ( + params_test := [ + ((2, (2, 2)), True, "cuda", "ENV"), + ] + ), + indirect=("setup_env",), + ids=[ + f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}" + for x in params_test + ], +) +@pytest.mark.parametrize("v2", [True, False], ids=["v2", "v1"]) +def test_smooth_lddt_loss_equivalence( + setup_env, + v2: bool, + multiplicity: int = 16, + dtype: torch.dtype = torch.float64, +): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + if device_type == "cuda" and torch.cuda.device_count() < world_size: + pytest.skip(f"Not enough GPUs. Required: {world_size}, Available: {torch.cuda.device_count()}") + + # Setup dummy data + B = 1 * grid_group_sizes["dp"] + N = 1000 * grid_group_sizes["cp"][0] + + B_expanded = B * multiplicity + + init_val_range = 30.0 + + pred_coords_global = torch.randn(B_expanded, N, 3) * init_val_range + true_coords_global = torch.randn(B_expanded, N, 3) * init_val_range + is_nucleotide_global = torch.randint(0, 2, (B, N)).float() + coords_mask_global = torch.randint(0, 2, (B, N)).float() + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + multiplicity, + pred_coords_global, + true_coords_global, + is_nucleotide_global, + coords_mask_global, + v2, + ) + + spawn_multiprocessing(assert_smooth_lddt_loss_equivalence, world_size, payload) + + +def assert_no_register_spilling(path_to_ptx_file: Path): + ptx_code = path_to_ptx_file.read_text() + + # get the ".target sm_{arch}a" directive from the ptx code + sm_arch_match = re.search(r"\.target (sm_\w+)", ptx_code) + if not sm_arch_match: + raise RuntimeError(f"No .target directive found in {path_to_ptx_file}") + sm_arch = sm_arch_match.group(1) + + # Run ptxas + # -v: Verbose (prints register/spill stats) + # --gpu-name=sm_{arch}: Matches target hardware + ptxas_path = os.environ["TRITON_PTXAS_PATH"] + + cmd = [ptxas_path, "-v", f"--gpu-name={sm_arch}", str(path_to_ptx_file)] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + output = result.stderr + # Check for spill stores and loads + # Expected: "0 bytes spill stores, 0 bytes spill loads" + if "0 bytes spill stores, 0 bytes spill loads" not in output: + raise RuntimeError(f"Register spilling detected in {path_to_ptx_file}:\n{output}") + # otherwise, this test will fail with something like this: + # RuntimeError: Register spilling detected in /tmp/pytest-of-*/**/test_no_register_spilling_fwd0/**/.ptx: + # ptxas info : 28 bytes gmem + # ptxas info : Compiling entry function 'smooth_lddt_loss_fwd_kernel' for '{sm_arch}' + # ptxas info : Function properties for smooth_lddt_loss_fwd_kernel + # 14976 bytes stack frame, 64736 bytes spill stores, 81760 bytes spill loads + # ptxas info : Used 32 registers, used 1 barriers, 14976 bytes cumulative stack size + # ptxas info : Compile time = 34792.172 ms + except subprocess.CalledProcessError as e: + raise RuntimeError(f"ptxas failed with error:\n{e.stderr}") from e + + +# The spilling test is expensive to run and triton doesn't always follow the recompilation rules +# so we only run the test for a 1 case fwd and bwd +@pytest.mark.skipif(not has_smooth_lddt_loss_triton_kernels, reason="Triton kernels not available") +@pytest.mark.parametrize("precision", [Precision.FP32], ids=lambda x: f"{x}") +@pytest.mark.parametrize("B", [1], ids=lambda x: f"B:{x}") +@pytest.mark.parametrize("M", [16], ids=lambda x: f"M:{x}") +@pytest.mark.parametrize("N", [4608], ids=lambda x: f"N:{x}") +@pytest.mark.parametrize("fwd_or_bwd", ["fwd", "bwd"]) +def test_no_register_spilling(tmp_path, monkeypatch, precision, fwd_or_bwd, B, M, N): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # setup the env to dump the ptx code + # NOTE: the cuobjdump version must be recent enough to support the running GPU architecture + # either wise the cuobjdump call invoked by triton kernel dumping will fail. The user + # can set, e.g., TRITON_CUOBJDUMP_PATH=$CONDA_PREFIX/bin/cuobjdump or other cuobjdump available + # that supports the running GPU architecture + monkeypatch.setenv("TRITON_KERNEL_DUMP", "1") + monkeypatch.setenv("TRITON_DUMP_DIR", str(tmp_path)) + # Ensure cache dir is unique to avoid hitting cached kernels without dump + monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1") + monkeypatch.setenv("TRITON_CACHE_DIR", str(tmp_path / "cache")) + # NOTE: the ptxas version must be recent enough to support the running GPU architecture + # either wise the ptxas call later will fail + monkeypatch.setenv("TRITON_PTXAS_PATH", os.environ.get("TRITON_PTXAS_PATH", "ptxas")) + + # invoke the kernel to get the ptx code + D = 3 + + device = torch.device("cuda") + dtype = PRECISION_TO_DTYPE[precision] + + pred_coords_local = torch.randn(B * M, N, D, device=device, dtype=dtype) + true_coords_local = torch.randn(B * M, N, D, device=device, dtype=dtype) + pred_coords_t_local = torch.randn(B * M, N, D, device=device, dtype=dtype) + true_coords_t_local = torch.randn(B * M, N, D, device=device, dtype=dtype) + + is_nucleotide_local = torch.randint(0, 2, (B * M, N), device=device, dtype=torch.bool) + coords_mask_local = torch.randint(0, 2, (B * M, N), device=device, dtype=dtype) + coords_mask_t_local = torch.randint(0, 2, (B * M, N), device=device, dtype=dtype) + + num_result = torch.zeros(B * M, device=device, dtype=dtype) + den_result = torch.zeros(B * M, device=device, dtype=dtype) + + nucleic_acid_cutoff = 5.0 + other_cutoff = 3.0 + is_self_comm = False + + if fwd_or_bwd == "bwd": + grad_num = torch.randn_like(num_result) + grad_den = torch.randn_like(den_result) + + grad_pred_coords_local_result = torch.zeros_like(pred_coords_local) + grad_pred_coords_t_local_result = torch.zeros_like(pred_coords_t_local) + smooth_lddt_loss_bwd_kernel[grid_launch_config]( + grad_num, + grad_den, + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + grad_pred_coords_local_result, + grad_pred_coords_t_local_result, + pred_coords_local.stride(0), + pred_coords_local.stride(1), + pred_coords_local.stride(2), + true_coords_local.stride(0), + true_coords_local.stride(1), + true_coords_local.stride(2), + pred_coords_t_local.stride(0), + pred_coords_t_local.stride(1), + pred_coords_t_local.stride(2), + true_coords_t_local.stride(0), + true_coords_t_local.stride(1), + true_coords_t_local.stride(2), + is_nucleotide_local.stride(0), + is_nucleotide_local.stride(1), + coords_mask_local.stride(0), + coords_mask_local.stride(1), + coords_mask_t_local.stride(0), + coords_mask_t_local.stride(1), + grad_pred_coords_local_result.stride(0), + grad_pred_coords_local_result.stride(1), + grad_pred_coords_local_result.stride(2), + grad_pred_coords_t_local_result.stride(0), + grad_pred_coords_t_local_result.stride(1), + grad_pred_coords_t_local_result.stride(2), + nucleic_acid_cutoff, + other_cutoff, + is_self_comm, + pred_coords_local.shape[0], + pred_coords_local.shape[1], + coords_mask_local.shape[0], + ) + else: + smooth_lddt_loss_fwd_kernel[grid_launch_config]( + pred_coords_local, + true_coords_local, + pred_coords_t_local, + true_coords_t_local, + is_nucleotide_local, + coords_mask_local, + coords_mask_t_local, + num_result, + den_result, + pred_coords_local.stride(0), + pred_coords_local.stride(1), + pred_coords_local.stride(2), + true_coords_local.stride(0), + true_coords_local.stride(1), + true_coords_local.stride(2), + pred_coords_t_local.stride(0), + pred_coords_t_local.stride(1), + pred_coords_t_local.stride(2), + true_coords_t_local.stride(0), + true_coords_t_local.stride(1), + true_coords_t_local.stride(2), + is_nucleotide_local.stride(0), + is_nucleotide_local.stride(1), + coords_mask_local.stride(0), + coords_mask_local.stride(1), + coords_mask_t_local.stride(0), + coords_mask_t_local.stride(1), + nucleic_acid_cutoff, + other_cutoff, + is_self_comm, + pred_coords_local.shape[0], + pred_coords_local.shape[1], + coords_mask_local.shape[0], + ) + + # parse the ptx code to check for register spilling + ptx_files = list(tmp_path.glob(f"**/smooth_lddt_loss_{fwd_or_bwd}_kernel.ptx")) + + if not ptx_files: + raise RuntimeError(f"No PTX file found in {tmp_path}/**/smooth_lddt_loss_{fwd_or_bwd}_kernel.ptx") + + path_to_ptx_file = ptx_files[0] + + assert_no_register_spilling(path_to_ptx_file) diff --git a/tests/distributed/model/models/__init__.py b/tests/distributed/model/models/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/tests/distributed/model/models/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/tests/distributed/model/models/test_dtensor_boltz2.py b/tests/distributed/model/models/test_dtensor_boltz2.py new file mode 100644 index 000000000..571c3babd --- /dev/null +++ b/tests/distributed/model/models/test_dtensor_boltz2.py @@ -0,0 +1,3389 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for the Boltz2 distributed model wrapper. + +Verification checks: + V1: Construction – all serial parameters are present in the distributed wrapper + V2: Placeholder modules raise NotImplementedError on forward + V3: Ready submodules produce DTensor outputs with correct shapes + V4: configure_optimizers returns valid optimizer configuration + V5: configure_callbacks returns DistributedEMA when EMA is enabled + V6: on_after_backward redistributes DTensor gradients to Replicate + V7: Multi-rank construction and parameter identity across ranks + V8: on_load_checkpoint adjusts checkpoint hyperparameters + V9: Non-vacuous guards – distributed modules have DTensor params, + placeholders have plain params + V10: bf16 mixed precision – ready submodules produce bf16 outputs under + autocast, gradients are reduced in >=fp32 (dp=1, cp=1x1) + V11: EmbeddingParamsReplicated wrapping for token_bonds_type + V12: BFactorModule wrapping when predict_bfactor=True + V13: Forward/backward parity – distributed Boltz2 forward matches serial + (restored from dev-v2 for debug comparison with V16) + V14: predict_step parity – distributed predict_step matches serial + V14b: predict_step confidence output – 2-GPU smoke test verifying no + DTensor values leak into the predict output dict (which would crash + the BoltzWriter callback) + V16: Serial vs distributed training_step parity – compares loss, gradients, + post-optimizer parameters, and CSVLogger-captured logged metrics between + serial Boltz2.training_step and distributed Boltz2Distributed.training_step + with all 5 randomness sources controlled (recycling, noise, augmentation, + sampling, dropout). Parametrized across dp2-cp1x1 (DP only), dp1-cp2x2 + (CP only), and dp2-cp2x2 (DP + CP) to verify correctness under all + sharding modes. + V15: setup – validator wiring, predict no-op, datamodule=None + V17: Regression tests for get_true_coordinates and loss shape consistency – + verifies non-symmetry path returns DTensors, symmetry_correction raises + NotImplementedError, and loss zeros have scalar shape () + V18: validation_step and on_validation_epoch_end – real forward pass with + metric accumulation, aggregation, and logging +""" + +import math +import random as stdlib_random +import shutil +import tempfile +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +from pytorch_lightning.loggers import CSVLogger +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +import boltz.distributed.model.modules.diffusion as distributed_diffusion_module +import boltz.model.modules.diffusionv2 as serial_diffusion_v2_module +from boltz.data import const +from boltz.data.module.trainingv2 import Boltz2TrainingDataModule as Boltz2TrainingDataModuleSerial +from boltz.data.module.trainingv2 import collate +from boltz.distributed.data.module.trainingv2 import Boltz2TrainingDataModule as BoltzTrainingDataModuleDTensor +from boltz.distributed.data.utils import ( + ATOM_FEATURES_V2, + LIGAND_GEOMETRY_FEATURES, + distribute_features, + map_subgroup_mesh_to_cpu, +) +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, elementwise_op, scalar_tensor_op +from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed +from boltz.distributed.model.models.boltz2 import _PlaceholderModule +from boltz.distributed.model.validation.rcsb import DistributedRCSBValidator +from boltz.distributed.testing.utils import setup_mock_training_datamodule_config +from boltz.model.layers.attention import AttentionPairBias as AttentionPairBiasV1 +from boltz.model.layers.pairformer import PairformerLayer as SerialPairformerLayer +from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 +from boltz.model.validation.rcsb import RCSBValidator +from boltz.testing.utils import ( + SetModuleInfValues, + concat_data, + create_boltz2_model_init_params, + distribute_atom_features, + get_feature_placements, + init_module_params_glorot, + init_tensors_uniform, + pad_to_length, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + + +class _DictNamespace: + """A picklable namespace with both attribute access and .get() support.""" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def get(self, key, default=None): + return self.__dict__.get(key, default) + + +class _LogCapture: + """Captures LightningModule.log() calls into a dict, backed by a CSVLogger. + + Replaces the usual ``lambda *a, **kw: None`` monkeypatch so that the + training_step logging code path is exercised rather than silenced. + After the step, :meth:`flush` persists the captured metrics to the CSV file. + """ + + def __init__(self, csv_logger: CSVLogger): + self.metrics: dict[str, float] = {} + self._csv_logger = csv_logger + + def __call__(self, name, value, **kwargs): + if isinstance(value, torch.Tensor): + value = value.detach().cpu().item() + self.metrics[name] = value + + def flush(self, step: int = 0) -> None: + """Write captured metrics to the backing CSVLogger.""" + self._csv_logger.log_metrics(self.metrics, step=step) + self._csv_logger.save() + + +def _setup_training_data_7z64_8b2e(out_dir: Path, base_data_dir: Path) -> Path: + """Merge 7z64 and 8b2e processed data into a single training directory with records. + + Uses the two smallest samples (8b2e=1062 atoms, 7z64=2278 atoms) to reduce + GPU memory consumption in FP64 parity tests. + """ + names = ["7z64", "8b2e"] + source_dirs = [base_data_dir / f"processed_{name}" for name in names] + merged = concat_data(out_dir, *source_dirs) + records_dir = merged / "records" + records_dir.mkdir(parents=True, exist_ok=True) + copied: set[str] = set() + for src in source_dirs: + for rf in (src / "records").glob("*.json"): + if rf.name in copied: + raise ValueError(f"Duplicate record file {rf.name}") + shutil.copy(rf, records_dir / rf.name) + copied.add(rf.name) + return merged + + +def _deterministic_getitem_monkeypatch(monkeypatch, dataset, base_seed=42): + """Wrap TrainingDataset.__getitem__ to seed all RNGs per idx. + + Ensures idx=0 picks sample 0 and idx=1 picks sample 1 by intercepting + np.random.choice calls for dataset and sample selection. Seeds np.random, + torch, and Python random with base_seed + idx before each call so that all + downstream RNG usage (cropper, featurizer, center_random_augmentation) is + deterministic for a given sample index. + """ + original_getitem = type(dataset).__getitem__ + + def _wrapped_getitem(self, idx): + np.random.seed(base_seed + idx) + torch.manual_seed(base_seed + idx) + stdlib_random.seed(base_seed + idx) + + _original_np_choice = np.random.choice + _call_count = [0] + _num_samples = len(self.samples[0]) + + def _deterministic_choice(a, p=None, **kwargs): + _call_count[0] += 1 + result = _original_np_choice(a, p=p, **kwargs) + if _call_count[0] == 1: + return 0 + elif _call_count[0] == 2: + return idx % _num_samples + return result + + np.random.choice = _deterministic_choice + try: + return original_getitem(self, idx) + finally: + np.random.choice = _original_np_choice + + monkeypatch.setattr(type(dataset), "__getitem__", _wrapped_getitem) + + +def _make_training_args(**overrides): + """Create minimal training_args for Boltz2.""" + defaults = { + "recycling_steps": 1, + "sampling_steps": 2, + "diffusion_multiplicity": 1, + "diffusion_samples": 1, + "diffusion_loss_weight": 1.0, + "distogram_loss_weight": 0.3, + "confidence_loss_weight": 0.0, + "bfactor_loss_weight": 0.0, + "symmetry_correction": False, + "adam_beta_1": 0.9, + "adam_beta_2": 0.95, + "adam_eps": 1e-8, + "base_lr": 1e-3, + "max_lr": 1e-3, + "lr_scheduler": "af3", + "lr_warmup_no_steps": 10, + "lr_start_decay_after_n_steps": 100, + "lr_decay_every_n_steps": 50000, + "lr_decay_factor": 0.95, + "weight_decay": 0.0, + } + defaults.update(overrides) + return _DictNamespace(**defaults) + + +def _make_validation_args(**overrides): + defaults = { + "recycling_steps": 0, + "sampling_steps": 2, + "diffusion_samples": 1, + "symmetry_correction": False, + "run_confidence_sequentially": False, + } + defaults.update(overrides) + return _DictNamespace(**defaults) + + +TOKEN_S = 32 +TOKEN_Z = 16 + +_BOLTZ2_SELECTED_KEYS = [ + "atom_pad_mask", + "atom_to_token", + "pair_mask", + "token_pad_mask", + "ref_pos", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "ref_space_uid", + "res_type", + "profile", + "deletion_mean", + "pocket_feature", + "atom_resolved_mask", + "mol_type", + "msa", + "has_deletion", + "deletion_value", + "msa_paired", + "msa_mask", + "token_bonds", + "type_bonds", + "token_pair_pad_mask", + "asym_id", + "residue_index", + "entity_id", + "token_index", + "sym_id", + "cyclic_period", + "coords", + "disto_target", + "token_disto_mask", + "atom_counts_per_token", + "token_to_rep_atom", + "frames_idx", + "contact_conditioning", + "contact_threshold", + "method_feature", + "modified", + "bfactor", + "plddt", +] + + +def _create_minimal_serial_boltz2( + confidence_prediction=False, + affinity_prediction=False, + ema=False, + bond_type_feature=False, + predict_bfactor=False, + validate_structure=False, + validators=None, + num_val_datasets=1, +): + """Create a minimal serial Boltz2 model for testing.""" + training_args = _make_training_args() + validation_args = _make_validation_args() + + pairformer_args = {"num_blocks": 1, "num_heads": 2, "dropout": 0.0} + + model = SerialBoltz2( + atom_s=16, + atom_z=8, + token_s=TOKEN_S, + token_z=TOKEN_Z, + num_bins=8, + training_args=training_args, + validation_args=validation_args, + embedder_args={ + "atom_encoder_depth": 1, + "atom_encoder_heads": 2, + "activation_checkpointing": False, + }, + msa_args={"msa_s": 16, "msa_blocks": 1, "msa_dropout": 0.0, "z_dropout": 0.0}, + pairformer_args=pairformer_args, + score_model_args={ + "sigma_data": 16.0, + "dim_fourier": 32, + "atom_encoder_depth": 1, + "atom_encoder_heads": 2, + "token_transformer_depth": 1, + "token_transformer_heads": 2, + "atom_decoder_depth": 1, + "atom_decoder_heads": 2, + "activation_checkpointing": False, + "conditioning_transition_layers": 1, + }, + diffusion_process_args={"num_sampling_steps": 2}, + diffusion_loss_args={}, + confidence_prediction=confidence_prediction, + confidence_model_args=( + {"pairformer_args": pairformer_args, "confidence_args": {}} if confidence_prediction else None + ), + affinity_prediction=affinity_prediction, + predict_args={"recycling_steps": 0, "sampling_steps": 2, "diffusion_samples": 1, "max_parallel_samples": 1}, + validate_structure=validate_structure, + structure_prediction_training=True, + ema=ema, + use_templates=False, + predict_bfactor=predict_bfactor, + bond_type_feature=bond_type_feature, + validators=validators if validate_structure else None, + num_val_datasets=num_val_datasets if validate_structure else 1, + ) + + return model + + +def _prepare_serial_model(ema=False, bond_type_feature=False, predict_bfactor=False): + """Create a serial model and return its state dict and hparams.""" + model = _create_minimal_serial_boltz2(ema=ema, bond_type_feature=bond_type_feature, predict_bfactor=predict_bfactor) + return model.state_dict(), dict(model.hparams) + + +def _init_distributed(rank, grid_group_sizes, device_type, backend, env_map): + """Common boilerplate: set env vars and initialize DistributedManager.""" + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + return monkeypatch, DistributedManager() + + +def _build_dist_model(serial_hparams, serial_state_dict, dist_manager): + """Create serial model, load state, wrap as distributed, move to device.""" + serial_model = SerialBoltz2(**serial_hparams) + serial_model.load_state_dict(serial_state_dict, strict=True) + serial_model = serial_model.to(dist_manager.device) + dist_model = Boltz2Distributed(serial_model, dist_manager) + return dist_model.to(dist_manager.device) + + +# ====================================================================== # +# Single-process unit tests (no distributed setup needed) # +# ====================================================================== # + + +def test_wrapper_preserves_all_serial_parameters(): + """V1: All serial parameters should appear in the distributed wrapper.""" + serial_model = _create_minimal_serial_boltz2() + serial_param_names = set(serial_model.state_dict().keys()) + assert len(serial_param_names) > 0, "Serial model has no parameters" + + expected_prefixes = { + "s_init", + "z_init_1", + "z_init_2", + "s_norm", + "z_norm", + "s_recycle", + "z_recycle", + "token_bonds", + "msa_module", + "pairformer_module", + "distogram_module", + "input_embedder", + "rel_pos", + "contact_conditioning", + "diffusion_conditioning", + "structure_module", + } + found_prefixes = {n.split(".")[0] for n in serial_param_names} + missing = expected_prefixes - found_prefixes + assert not missing, f"Expected parameter prefixes missing from serial model: {missing}" + + +def test_wrapper_hparams_saved(): + """V1b: Hyper-parameters should be preserved in save_hyperparameters.""" + serial_model = _create_minimal_serial_boltz2() + assert hasattr(serial_model, "hparams"), "Serial model should have hparams" + assert "atom_s" in serial_model.hparams + assert "token_s" in serial_model.hparams + + +def test_placeholder_raises_on_forward(): + """V2: Placeholder module should raise NotImplementedError.""" + placeholder = _PlaceholderModule(torch.nn.Linear(4, 8), "TestModule") + with pytest.raises(NotImplementedError, match="TestModule"): + placeholder(torch.randn(2, 4)) + + +def test_placeholder_preserves_parameters(): + """V2b: Placeholder should expose the serial module's parameters, frozen.""" + placeholder = _PlaceholderModule(torch.nn.Linear(4, 8), "TestModule") + params = dict(placeholder.named_parameters()) + assert "_serial.weight" in params + assert "_serial.bias" in params + assert params["_serial.weight"].shape == (8, 4) + for name, p in placeholder.named_parameters(): + assert not p.requires_grad, f"Placeholder param '{name}' should be frozen" + + +def test_serial_model_ema_callbacks(): + """V5 precondition: serial EMA callback configuration.""" + assert _create_minimal_serial_boltz2(ema=True).use_ema is True + assert len(_create_minimal_serial_boltz2(ema=True).configure_callbacks()) == 1 + assert _create_minimal_serial_boltz2(ema=False).use_ema is False + assert len(_create_minimal_serial_boltz2(ema=False).configure_callbacks()) == 0 + + +def test_confidence_model_can_be_created(): + """Confidence serial model can be created with proper args.""" + assert _create_minimal_serial_boltz2(confidence_prediction=True).confidence_prediction is True + + +def test_bfactor_in_atom_features_v2(): + """Data pipeline: bfactor must be in ATOM_FEATURES_V2 for distributed sharding.""" + assert "bfactor" in ATOM_FEATURES_V2 + + +def test_pairformer_v1_v2_detection_preconditions(): + """V1/V2 detection: serial V1 attention has norm_s, V2 does not. + + The distributed PairformerLayer infers v1/v2 from + ``hasattr(layer.attention, 'norm_s')``. This test verifies the + serial-side preconditions that the heuristic depends on. + """ + token_s, token_z, num_heads = 32, 16, 2 + + v1_layer = SerialPairformerLayer(token_s, token_z, num_heads, v2=False) + v2_layer = SerialPairformerLayer(token_s, token_z, num_heads, v2=True) + + assert isinstance(v1_layer.attention, AttentionPairBiasV1) + assert hasattr(v1_layer.attention, "norm_s"), "V1 attention must have norm_s" + assert not hasattr(v2_layer.attention, "norm_s"), "V2 attention must not have norm_s" + + # The heuristic: not hasattr(layer.attention, "norm_s") → v2 + assert (not hasattr(v1_layer.attention, "norm_s")) is False # → V1 + assert (not hasattr(v2_layer.attention, "norm_s")) is True # → V2 + + +def test_checkpoint_lr_is_overwritten(): + """V8: on_load_checkpoint should overwrite lr and weight_decay in checkpoint.""" + serial_model = _create_minimal_serial_boltz2() + checkpoint = { + "optimizer_states": [{"param_groups": [{"lr": 0.5, "weight_decay": 0.99}, {"lr": 0.5, "weight_decay": 0.99}]}], + "lr_schedulers": [{"max_lr": 0.5, "base_lrs": [0.5, 0.5], "_last_lr": [0.5, 0.5]}], + "hyper_parameters": { + "training_args": {"max_lr": 0.5, "diffusion_multiplicity": 99, "recycling_steps": 99, "weight_decay": 0.99} + }, + } + serial_model.on_load_checkpoint(checkpoint) + + for pg in checkpoint["optimizer_states"][0]["param_groups"]: + assert pg["lr"] == 1e-3 and pg["weight_decay"] == 0.0 + + sched = checkpoint["lr_schedulers"][0] + assert sched["max_lr"] == 1e-3 + assert all(lr == 1e-3 for lr in sched["base_lrs"]) + + hp = checkpoint["hyper_parameters"]["training_args"] + assert hp["max_lr"] == 1e-3 and hp["diffusion_multiplicity"] == 1 and hp["weight_decay"] == 0.0 + + +# ====================================================================== # +# Comprehensive multi-rank worker (V3-V7, V9, V11) # +# ====================================================================== # + + +# V13 (forward/backward parity) was removed — V16 (training_step parity) now +# includes dp1-cp2x2 and dp2-cp2x2 parametrizations that provide strictly +# stronger guarantees (loss + gradient + post-optimizer parity) at the same +# topologies. This matches the Boltz-1 test structure where +# test_boltz1_model_parallel_training_step with cp2x2 is the primary parity test. + + +# ====================================================================== # +# V10: bf16 mixed precision (CUDA-only) # +# ====================================================================== # + + +def _worker_bf16_mixed_precision( + rank: int, + serial_state_dict: dict, + serial_hparams: dict, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: dict[str, str] | None = None, +): + """Worker: exercise ready submodules under torch.autocast(bf16). + + Verifies forward outputs are bf16, gradients are reduced in >=fp32, + and clear_autocast_cache() does not error. + """ + monkeypatch, dm = _init_distributed(rank, grid_group_sizes, device_type, backend, env_map) + + dist_model = _build_dist_model(serial_hparams, serial_state_dict, dm) + dist_model.train() + + dp = grid_group_sizes["dp"] + B = max(1, dp) + N_global = 8 * grid_group_sizes["cp"][0] + single_pl = (Shard(0), Shard(1), Replicate()) + pair_pl = (Shard(0), Shard(1), Shard(2)) + + # Forward under autocast + with torch.autocast("cuda", dtype=torch.bfloat16): + s_in = distribute_tensor( + torch.randn(B, N_global, TOKEN_S, device=dm.device), dm.device_mesh_subgroups, single_pl + ) + s_out = dist_model.s_init(s_in) + assert s_out.to_local().dtype == torch.bfloat16 + + s_norm_out = dist_model.s_norm(s_in) + assert s_norm_out.to_local().dtype == torch.float32 # autocast LayerNorm policy + + z1 = dist_model.z_init_1(s_in) + assert z1.to_local().dtype == torch.bfloat16 + + z_in = distribute_tensor( + torch.randn(B, N_global, N_global, TOKEN_Z, device=dm.device), dm.device_mesh_subgroups, pair_pl + ) + disto = dist_model.distogram_module(z_in) + assert disto.to_local().dtype == torch.bfloat16 + + # Backward: gradient dtype and placement + s_grad_in = distribute_tensor( + torch.randn(B, N_global, TOKEN_S, device=dm.device), dm.device_mesh_subgroups, single_pl + ) + with torch.autocast("cuda", dtype=torch.bfloat16): + out = dist_model.s_init(s_grad_in) + out.to_local().sum().backward() + + w = dist_model.s_init.weight + assert w.grad is not None and isinstance(w.grad, DTensor) + for p in w.grad.placements: + assert isinstance(p, Replicate), f"Weight grad should be Replicate, got {p}" + assert w.grad.to_local().dtype == w.to_local().dtype # fp32 via promote_types + assert w.grad.to_local().abs().sum() > 0 + + # clear_autocast_cache branch + recycling path + dist_model.zero_grad() + with torch.autocast("cuda", dtype=torch.bfloat16): + torch.clear_autocast_cache() + s_zeros = distribute_tensor( + torch.zeros(B, N_global, TOKEN_S, device=dm.device, dtype=torch.bfloat16), + dm.device_mesh_subgroups, + single_pl, + ) + s_recycled = elementwise_op( + dist_model.s_init(s_grad_in), + dist_model.s_recycle(dist_model.s_norm(s_zeros)), + ElementwiseOp.SUM, + ) + assert s_recycled.to_local().dtype == torch.bfloat16 + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + # Minimal dp=1, cp=1x1: focused autocast dtype check. + # Full BF16 training+DP+CP is covered by test_boltz2_finetune_from_checkpoint + # in test_dtensor_boltz2_train.py and test_boltz2_run_predict in + # test_dtensor_predict.py. + ((1, (1, 1)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp1-cp1x1"], +) +def test_boltz2_bf16_mixed_precision(setup_env): + """V10: Ready submodules produce bf16 outputs under autocast, grads reduced in fp32. + + Focused autocast check on individual submodules (s_init, z_init_1, + distogram_module, s_norm) with dp=1, cp=1x1. Full BF16-mixed training + across DP and CP topologies is exercised by + ``test_boltz2_finetune_from_checkpoint`` and ``test_boltz2_run_predict``. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + sd, hp = _prepare_serial_model(ema=False) + spawn_multiprocessing( + _worker_bf16_mixed_precision, + world_size, + sd, + hp, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +# ====================================================================== # +# V14: predict_step parity (inference sampling) # +# ====================================================================== # + + +def parallel_assert_boltz2_model_predict_step( + rank, + grid_group_sizes, + device_type, + backend, + dtype, + boltz2_model_params, + module_state_dict, + predict_args, + diffusion_samples, + num_sampling_steps, + input_feats_global_fp64_host, + init_noise_global_host, + step_noise_list_global_host, + serial_coords_host, + serial_masks_host, + env_per_rank=None, +): + """V14 multi-rank worker: verify distributed predict_step matches serial.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + reference_module = SerialBoltz2(**boltz2_model_params) + reference_module = reference_module.to(dtype=dtype) + reference_module.load_state_dict(module_state_dict) + reference_module.structure_module.coordinate_augmentation = False + reference_module.apply(SetModuleInfValues()) + reference_module = reference_module.to(device=manager.device) + module = Boltz2Distributed(reference_module, manager) + module.eval() + + host_tensor_keys = {k for k, v in input_feats_global_fp64_host.items() if isinstance(v, torch.Tensor)} + _placements = get_feature_placements( + token_keys=host_tensor_keys, + msa_keys=host_tensor_keys, + atom_keys={ + "ref_pos", + "atom_resolved_mask", + "ref_element", + "ref_charge", + "ref_atom_name_chars", + "ref_space_uid", + "coords", + "atom_pad_mask", + "atom_to_token", + "pair_mask", + "atom_counts_per_token", + "token_to_rep_atom", + "bfactor", + "plddt", + }, + model_io_keys={"noise"}, + model_io_fp32_keys=set(), + ) + _placements_token_features = _placements["token_features"] + _placements_msa_features = _placements["msa_features"] + _placements_cp_atom_features = _placements["cp_atom_features"] + _placements_atom_features = _placements["atom_features"] + _placements_cp_model_io = _placements["cp_model_io"] + _placements_model_io = _placements["model_io"] + + # ------------------------------------------------------------------ + # Distribute token + MSA features (rank 0 broadcasts) + # ------------------------------------------------------------------ + if manager.group_rank["world"] == 0: + input_feats_token_msa_global = { + k: v.to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements_token_features or k in _placements_msa_features + } + else: + input_feats_token_msa_global = None + + feats_token_msa = distribute_features( + input_feats_token_msa_global, + _placements_token_features | _placements_msa_features, + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + ) + + # ------------------------------------------------------------------ + # Distribute atom features + sampling noise via distribute_atom_features. + # Noise tensors use the _noise_{i_noise}_{i_mul} naming convention + # from test_atom_diffusion_sample so that intersperse padding naturally + # places zeros at padding positions. + # ------------------------------------------------------------------ + size_batch = input_feats_global_fp64_host["atom_pad_mask"].shape[0] + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements_cp_atom_features + } + + all_noise = [init_noise_global_host] + list(step_noise_list_global_host) + for i_noise, noise_host in enumerate(all_noise): + unflat = noise_host.unflatten(0, (size_batch, diffusion_samples)) + for i_mul in range(diffusion_samples): + inputs_atom[f"_noise_{i_noise}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + noise_cp_placements = {} + noise_placements = {} + for i_noise in range(len(all_noise)): + for i_mul in range(diffusion_samples): + key = f"_noise_{i_noise}_{i_mul}" + noise_cp_placements[key] = _placements_cp_model_io["noise"] + noise_placements[key] = _placements_model_io["noise"] + + feats_and_noise = distribute_atom_features( + inputs=inputs_atom, + placements_cp=_placements_cp_atom_features | noise_cp_placements, + placements_dp_cp=_placements_atom_features | noise_placements, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={f"_noise_{i}": diffusion_samples for i in range(len(all_noise))}, + ) + + noise_dts = [] + for i_noise in range(len(all_noise)): + noise_dts.append(feats_and_noise.pop(f"_noise_{i_noise}")) + init_noise_dt = noise_dts[0] + step_noise_dts = noise_dts[1:] + + feats_dt = {**feats_token_msa, **feats_and_noise} + + # ------------------------------------------------------------------ + # Monkeypatch distributed sample() for determinism + # ------------------------------------------------------------------ + _orig_center_random_augmentation = distributed_diffusion_module.center_random_augmentation + + def _centering_only_augmentation(atom_coords, atom_mask, **kwargs): + kwargs["augmentation"] = False + kwargs["centering"] = True + return _orig_center_random_augmentation(atom_coords, atom_mask, **kwargs) + + _dt_randn_calls = [] + _dt_randn_sequence = [init_noise_dt] + step_noise_dts + + def _fixed_create_distributed_randn(shape, device_mesh, placements, dtype=torch.float32, scale=1.0): + idx = len(_dt_randn_calls) + _dt_randn_calls.append(idx) + noise_dt = _dt_randn_sequence[idx] + if scale != 1.0: + noise_dt = scalar_tensor_op(scale, noise_dt, ElementwiseOp.PROD) + return noise_dt + + monkeypatch.setattr(distributed_diffusion_module, "center_random_augmentation", _centering_only_augmentation) + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", _fixed_create_distributed_randn) + + # ------------------------------------------------------------------ + # Run distributed predict_step + # ------------------------------------------------------------------ + module.predict_args = predict_args + with torch.no_grad(): + pred_dict = module.predict_step(feats_dt, batch_idx=0) + + assert pred_dict["exception"] is False + + # ------------------------------------------------------------------ + # Compare on gather rank 0 of CP axis 0. + # The distributed predict_step gathers coords/masks via + # torch.distributed.gather + concat(dim=1). With dp>1 each DP rank + # gathers only its own batch element(s), so we slice the serial + # reference by DP rank. + # ------------------------------------------------------------------ + tag_group_gather = 0 + if manager.subgroups_rank["cp"][tag_group_gather] == 0: + gathered_mask = pred_dict["masks"] + gathered_coords = pred_dict["coords"] + + mask_expanded = gathered_mask.repeat_interleave(diffusion_samples, 0).bool() + dt_real = gathered_coords[mask_expanded] + + n_dp = grid_group_sizes["dp"] + dp_rank = manager.group_rank["dp"] + B_local = size_batch // n_dp + M = diffusion_samples + serial_coords_slice = serial_coords_host[dp_rank * B_local * M : (dp_rank + 1) * B_local * M] + serial_mask_slice = serial_masks_host[dp_rank * B_local : (dp_rank + 1) * B_local] + + serial_coords_device = serial_coords_slice.to(device=manager.device, dtype=dtype) + serial_mask_expanded = serial_mask_slice.to(device=manager.device).repeat_interleave(M, 0).bool() + serial_real = serial_coords_device[serial_mask_expanded] + + torch.testing.assert_close(dt_real, serial_real) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + # dp=2, CP=(2,2) on CUDA — exercises DP slicing + CP gather + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda val: f"{val[2]}-dp:{val[0][0]}-cp{'x'.join(map(str, val[0][1]))}", +) +def test_boltz2_predict_step(setup_env): + """V14: predict_step parity between distributed and serial Boltz2. + + Tests that the distributed Boltz2.predict_step produces numerically + identical sampled coordinates compared to the serial implementation + in eval mode with no backward pass. + + Side-by-side comparison findings (serial vs distributed): + + Bug 1 — Random augmentation always applied in serial sample(): + Serial diffusionv2.py:370-376 calls compute_random_augmentation() + unconditionally. Distributed diffusion.py:982-994 gates + center_random_augmentation() on self.coordinate_augmentation. + Mitigation: monkeypatch serial compute_random_augmentation to return + identity rotation and zero translation (same as test_atom_diffusion_sample). + + Bug 2 — alignment_reverse_diff FP32 downcast in serial: + Serial diffusionv2.py:564-573 forces .float() (FP32) for + weighted_rigid_align. Distributed diffusion.py:1076-1082 passes + DTensors as-is (FP64 in test). + Mitigation: set alignment_reverse_diff=False. + + Bug 3 — Serial sample() augmentation shape bug for B > 1: + compute_random_augmentation(multiplicity) returns R of shape (M,3,3) + but atom_coords has shape (B*M,N,3) — einsum crashes when B > 1. + Mitigation: the identity augmentation mock returns (B*M,3,3), making + the einsum a no-op for any batch size. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + dtype = torch.float64 + size_batch_per_rank = 1 + B = size_batch_per_rank * grid_group_sizes["dp"] + size_cp = grid_group_sizes["cp"][0] + + min_val_init = -0.01 + max_val_init = 0.01 + scale_glorot = 0.05 + + num_sampling_steps = 2 + diffusion_samples = 2 + + seed = 42 + seed_by_rank(0, seed=seed) + + boltz2_model_params = create_boltz2_model_init_params(use_large_model=False) + boltz2_model_params["diffusion_process_args"]["alignment_reverse_diff"] = False + + predict_args = { + "recycling_steps": 0, + "sampling_steps": num_sampling_steps, + "diffusion_samples": diffusion_samples, + "max_parallel_samples": diffusion_samples, + "write_confidence_summary": False, + "write_full_pae": False, + } + boltz2_model_params["predict_args"] = predict_args + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + n_tokens = 30 * size_cp + W = boltz2_model_params["atoms_per_window_queries"] + n_atoms_raw = n_tokens * n_atoms_per_token_max + n_atoms = ((n_atoms_raw + W - 1) // W) * W + n_msa = size_cp * 2 + + assert n_atoms % size_cp == 0 + assert n_atoms % W == 0 + + input_feats_global_fp64 = random_features( + size_batch=B, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=n_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=device_type, + float_value_range=(min_val_init, max_val_init), + selected_keys=_BOLTZ2_SELECTED_KEYS, + num_disto_bins=boltz2_model_params["num_bins"], + ) + input_feats_global_fp64["msa"] = torch.randint( + 0, const.num_tokens, (B, n_msa, n_tokens), dtype=torch.int64, device=device_type + ) + + # ------------------------------------------------------------------ + # Build serial model in eval mode + # ------------------------------------------------------------------ + reference_module = SerialBoltz2(**boltz2_model_params) + init_module_params_glorot(reference_module, gain=scale_glorot) + reference_module.apply(SetModuleInfValues()) + reference_module.structure_module.coordinate_augmentation = False + module_state_dict_fp64 = reference_module.state_dict() + reference_module = reference_module.to(dtype=torch.float64, device=device_type).eval() + + # ------------------------------------------------------------------ + # Pre-generate deterministic noise for sampling. + # sample_schedule(N) returns N+1 sigmas (F.pad adds trailing 0), so + # sigmas_and_gammas has N entries → N denoising steps. Total + # torch.randn calls = 1 (init) + N (per-step eps) = N+1. + # ------------------------------------------------------------------ + _B_M = B * diffusion_samples + init_noise = torch.empty((_B_M, n_atoms, 3), device=device_type, dtype=dtype) + step_noise_list = [ + torch.empty((_B_M, n_atoms, 3), device=device_type, dtype=dtype) for _ in range(num_sampling_steps) + ] + init_tensors_uniform([init_noise, *step_noise_list], low=min_val_init, high=max_val_init) + + # ------------------------------------------------------------------ + # Monkeypatch serial sample() for determinism (Bug 1/3 mitigations) + # ------------------------------------------------------------------ + def _identity_compute_random_augmentation(multiplicity_arg, device=None, dtype=None): + R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(_B_M, -1, -1) + tr = torch.zeros(_B_M, 1, 3, device=device, dtype=dtype) + return R, tr + + _serial_randn_calls = [] + _serial_randn_sequence = [init_noise] + step_noise_list + + def _fixed_randn(*args, **kwargs): + idx = len(_serial_randn_calls) + _serial_randn_calls.append(idx) + return _serial_randn_sequence[idx].clone() + + _monkeypatch = pytest.MonkeyPatch() + _monkeypatch.setattr( + serial_diffusion_v2_module, "compute_random_augmentation", _identity_compute_random_augmentation + ) + _monkeypatch.setattr(serial_diffusion_v2_module.torch, "randn", _fixed_randn) + + # ------------------------------------------------------------------ + # Serial predict_step + # ------------------------------------------------------------------ + with torch.no_grad(): + serial_pred_dict = reference_module.predict_step(input_feats_global_fp64, batch_idx=0) + + assert serial_pred_dict["exception"] is False + serial_coords = serial_pred_dict["coords"] + serial_masks = serial_pred_dict["masks"] + + _monkeypatch.undo() + + # ------------------------------------------------------------------ + # Move everything to CPU for spawn_multiprocessing + # ------------------------------------------------------------------ + input_feats_host = {k: v.detach().to(device="cpu", copy=True) for k, v in input_feats_global_fp64.items()} + serial_coords_host = serial_coords.detach().to(device="cpu", copy=True) + serial_masks_host = serial_masks.detach().to(device="cpu", copy=True) + + spawn_multiprocessing( + parallel_assert_boltz2_model_predict_step, + world_size, + grid_group_sizes, + device_type, + backend, + dtype, + boltz2_model_params, + module_state_dict_fp64, + predict_args, + diffusion_samples, + num_sampling_steps, + input_feats_host, + init_noise.cpu(), + [n.cpu() for n in step_noise_list], + serial_coords_host, + serial_masks_host, + env_per_rank, + ) + + +# ====================================================================== # +# V14b: predict_step confidence output – no DTensor leaks # +# ====================================================================== # + + +def _worker_predict_step_confidence( + rank, + grid_group_sizes, + device_type, + backend, + dtype, + boltz2_model_params, + module_state_dict, + predict_args, + diffusion_samples, + num_sampling_steps, + input_feats_global_fp64_host, + init_noise_global_host, + step_noise_list_global_host, + env_per_rank=None, +): + """V14b worker: verify predict_step with confidence produces no DTensor leaks.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + reference_module = SerialBoltz2(**boltz2_model_params) + reference_module = reference_module.to(dtype=dtype) + reference_module.load_state_dict(module_state_dict) + reference_module.structure_module.coordinate_augmentation = False + reference_module.apply(SetModuleInfValues()) + reference_module = reference_module.to(device=manager.device) + module = Boltz2Distributed(reference_module, manager) + module.eval() + + host_tensor_keys = {k for k, v in input_feats_global_fp64_host.items() if isinstance(v, torch.Tensor)} + _placements = get_feature_placements( + token_keys=host_tensor_keys, + msa_keys=host_tensor_keys, + atom_keys={ + "ref_pos", + "atom_resolved_mask", + "ref_element", + "ref_charge", + "ref_atom_name_chars", + "ref_space_uid", + "coords", + "atom_pad_mask", + "atom_to_token", + "pair_mask", + "atom_counts_per_token", + "token_to_rep_atom", + "bfactor", + "plddt", + }, + model_io_keys={"noise"}, + model_io_fp32_keys=set(), + ) + _placements_token_features = _placements["token_features"] + _placements_msa_features = _placements["msa_features"] + _placements_cp_atom_features = _placements["cp_atom_features"] + _placements_cp_model_io = _placements["cp_model_io"] + _placements_model_io = _placements["model_io"] + + if manager.group_rank["world"] == 0: + input_feats_token_msa_global = { + k: v.to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements_token_features or k in _placements_msa_features + } + else: + input_feats_token_msa_global = None + + feats_token_msa = distribute_features( + input_feats_token_msa_global, + _placements_token_features | _placements_msa_features, + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + ) + + size_batch = input_feats_global_fp64_host["atom_pad_mask"].shape[0] + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements_cp_atom_features + } + + all_noise = [init_noise_global_host] + list(step_noise_list_global_host) + for i_noise, noise_host in enumerate(all_noise): + unflat = noise_host.unflatten(0, (size_batch, diffusion_samples)) + for i_mul in range(diffusion_samples): + inputs_atom[f"_noise_{i_noise}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + noise_cp_placements = {} + noise_placements = {} + for i_noise in range(len(all_noise)): + for i_mul in range(diffusion_samples): + key = f"_noise_{i_noise}_{i_mul}" + noise_cp_placements[key] = _placements_cp_model_io["noise"] + noise_placements[key] = _placements_model_io["noise"] + + feats_and_noise = distribute_atom_features( + inputs=inputs_atom, + placements_cp=_placements_cp_atom_features | noise_cp_placements, + placements_dp_cp=_placements["atom_features"] | noise_placements, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={f"_noise_{i}": diffusion_samples for i in range(len(all_noise))}, + ) + + noise_dts = [] + for i_noise in range(len(all_noise)): + noise_dts.append(feats_and_noise.pop(f"_noise_{i_noise}")) + init_noise_dt = noise_dts[0] + step_noise_dts = noise_dts[1:] + + feats_dt = {**feats_token_msa, **feats_and_noise} + + _orig_center_random_augmentation = distributed_diffusion_module.center_random_augmentation + + def _centering_only_augmentation(atom_coords, atom_mask, **kwargs): + kwargs["augmentation"] = False + kwargs["centering"] = True + return _orig_center_random_augmentation(atom_coords, atom_mask, **kwargs) + + _dt_randn_calls = [] + _dt_randn_sequence = [init_noise_dt] + step_noise_dts + + def _fixed_create_distributed_randn(shape, device_mesh, placements, dtype=torch.float32, scale=1.0): + idx = len(_dt_randn_calls) + _dt_randn_calls.append(idx) + noise_dt = _dt_randn_sequence[idx] + if scale != 1.0: + noise_dt = scalar_tensor_op(scale, noise_dt, ElementwiseOp.PROD) + return noise_dt + + monkeypatch.setattr(distributed_diffusion_module, "center_random_augmentation", _centering_only_augmentation) + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", _fixed_create_distributed_randn) + + module.predict_args = predict_args + with torch.no_grad(): + pred_dict = module.predict_step(feats_dt, batch_idx=0) + + assert pred_dict["exception"] is False, "predict_step raised an exception" + + expected_confidence_keys = {"pde", "plddt", "complex_plddt", "complex_iplddt", "complex_pde", "complex_ipde"} + missing = expected_confidence_keys - set(pred_dict.keys()) + assert not missing, f"Missing confidence keys in predict_step output: {missing}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (1, 1)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda val: f"{val[2]}-dp:{val[0][0]}-cp{'x'.join(map(str, val[0][1]))}", +) +def test_boltz2_predict_step_confidence(setup_env): + """V14b: predict_step with confidence_prediction=True produces no DTensor leaks. + + Runs on 1 GPU (dp=1, cp=1x1) and verifies that the predict_step output + dict contains only plain tensors, catching any DTensor-to-writer leaks + that would crash the BoltzWriter callback. Even with trivial CP sharding, + the model produces DTensor outputs; the _assert_no_dtensors_in_output guard + in predict_step catches any unconverted DTensors at runtime. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + dtype = torch.float64 + size_batch = 1 + size_cp = grid_group_sizes["cp"][0] + + num_sampling_steps = 2 + diffusion_samples = 1 + + seed = 42 + seed_by_rank(0, seed=seed) + + boltz2_model_params = create_boltz2_model_init_params(use_large_model=False) + boltz2_model_params["diffusion_process_args"]["alignment_reverse_diff"] = False + boltz2_model_params["num_bins"] = 64 + boltz2_model_params["confidence_prediction"] = True + boltz2_model_params["confidence_model_args"] = { + "pairformer_args": boltz2_model_params["pairformer_args"], + "confidence_args": {}, + } + + predict_args = { + "recycling_steps": 0, + "sampling_steps": num_sampling_steps, + "diffusion_samples": diffusion_samples, + "max_parallel_samples": diffusion_samples, + "write_confidence_summary": True, + "write_full_pae": True, + } + boltz2_model_params["predict_args"] = predict_args + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + n_tokens = 30 * size_cp + W = boltz2_model_params["atoms_per_window_queries"] + n_atoms_raw = n_tokens * n_atoms_per_token_max + n_atoms = ((n_atoms_raw + W - 1) // W) * W + n_msa = size_cp * 2 + + assert n_atoms % size_cp == 0 + assert n_atoms % W == 0 + + input_feats_global_fp64 = random_features( + size_batch=size_batch, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=n_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=device_type, + float_value_range=(-0.01, 0.01), + selected_keys=_BOLTZ2_SELECTED_KEYS, + num_disto_bins=boltz2_model_params["num_bins"], + ) + input_feats_global_fp64["msa"] = torch.randint( + 0, const.num_tokens, (size_batch, n_msa, n_tokens), dtype=torch.int64, device=device_type + ) + + reference_module = SerialBoltz2(**boltz2_model_params) + init_module_params_glorot(reference_module, gain=0.05) + reference_module.apply(SetModuleInfValues()) + reference_module.structure_module.coordinate_augmentation = False + module_state_dict_fp64 = reference_module.state_dict() + + _B_M = size_batch * diffusion_samples + init_noise = torch.empty((_B_M, n_atoms, 3), device=device_type, dtype=dtype) + step_noise_list = [ + torch.empty((_B_M, n_atoms, 3), device=device_type, dtype=dtype) for _ in range(num_sampling_steps) + ] + init_tensors_uniform([init_noise, *step_noise_list], low=-0.01, high=0.01) + + input_feats_host = {k: v.detach().to(device="cpu", copy=True) for k, v in input_feats_global_fp64.items()} + + spawn_multiprocessing( + _worker_predict_step_confidence, + world_size, + grid_group_sizes, + device_type, + backend, + dtype, + boltz2_model_params, + module_state_dict_fp64, + predict_args, + diffusion_samples, + num_sampling_steps, + input_feats_host, + init_noise.cpu(), + [n.cpu() for n in step_noise_list], + env_per_rank, + ) + + +# ====================================================================== # +# V13: Forward/backward parity (serial vs distributed) # +# ====================================================================== # + + +def _worker_forward_backward_parity( + rank, + grid_group_sizes, + device_type, + backend, + dtype, + boltz2_model_params, + module_state_dict, + n_recycles, + multiplicity_diffusion_train, + input_feats_global_fp64_host, + sigmas_expected_global_fp64_host, + noise_expected_global_fp64_host, + output_pdistogram_expected_global_host, + output_denoised_atom_coords_expected_global_host, + output_pbfactor_expected_global_host, + output_s_expected_global_host, + output_z_expected_global_host, + output_aligned_true_coords_expected_global_host, + d_output_pdistogram_expected_global_host, + d_output_denoised_atom_coords_expected_global_host, + d_output_pbfactor_expected_global_host, + expected_param_grads_global_host_dict, + env_per_rank=None, + use_random_features=True, + training_data_dir=None, + canonical_mols_dir=None, + base_seed=42, +): + """V13 multi-rank worker: verify distributed forward/backward matches serial.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + reference_module = SerialBoltz2(**boltz2_model_params) + reference_module = reference_module.to(dtype=dtype) + reference_module.load_state_dict(module_state_dict) + reference_module.structure_module.coordinate_augmentation = False + reference_module.apply(SetModuleInfValues()) + reference_module = reference_module.to(device=manager.device) + module = Boltz2Distributed(reference_module, manager) + module.train() + + _io_model_io_keys = {"noise", "denoised_atom_coords", "d_denoised_atom_coords", "aligned_true_atom_coords"} + + if use_random_features: + host_tensor_keys = {k for k, v in input_feats_global_fp64_host.items() if isinstance(v, torch.Tensor)} + _placements = get_feature_placements( + token_keys=host_tensor_keys, + msa_keys=host_tensor_keys, + atom_keys={ + "ref_pos", + "atom_resolved_mask", + "ref_element", + "ref_charge", + "ref_atom_name_chars", + "ref_space_uid", + "coords", + "atom_pad_mask", + "atom_to_token", + "pair_mask", + "atom_counts_per_token", + "token_to_rep_atom", + "bfactor", + "plddt", + }, + model_io_keys=_io_model_io_keys, + model_io_fp32_keys=set(), + ) + + # Token + MSA features: broadcast from rank 0 + if manager.group_rank["world"] == 0: + input_feats_token_msa_global = { + k: v.to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements["token_features"] or k in _placements["msa_features"] + } + else: + input_feats_token_msa_global = None + + feats_token_msa = distribute_features( + input_feats_token_msa_global, + _placements["token_features"] | _placements["msa_features"], + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + ) + + # Atom features: scatter across CP mesh + size_batch = input_feats_global_fp64_host["atom_pad_mask"].shape[0] + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements["cp_atom_features"] + } + + noise_unflat = noise_expected_global_fp64_host.unflatten(0, (size_batch, multiplicity_diffusion_train)) + denoised_unflat = output_denoised_atom_coords_expected_global_host.unflatten( + 0, (size_batch, multiplicity_diffusion_train) + ) + d_denoised_unflat = d_output_denoised_atom_coords_expected_global_host.unflatten( + 0, (size_batch, multiplicity_diffusion_train) + ) + aligned_coords_unflat = output_aligned_true_coords_expected_global_host.unflatten( + 0, (size_batch, multiplicity_diffusion_train) + ) + for i_mul in range(multiplicity_diffusion_train): + inputs_atom[f"noise_{i_mul}"] = noise_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"denoised_atom_coords_{i_mul}"] = denoised_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"d_denoised_atom_coords_{i_mul}"] = d_denoised_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"aligned_true_atom_coords_{i_mul}"] = aligned_coords_unflat[:, i_mul].to(dtype=dtype) + + placements_cp_model_io_mul = { + f"{k}_{i_mul}": v + for k, v in _placements["cp_model_io"].items() + for i_mul in range(multiplicity_diffusion_train) + } + placements_cp = _placements["cp_atom_features"] | placements_cp_model_io_mul + + placements_model_io_mul = { + f"{k}_{i_mul}": v + for k, v in _placements["model_io"].items() + for i_mul in range(multiplicity_diffusion_train) + } + placements_dp_cp = _placements["atom_features"] | placements_model_io_mul + + multiplicities = { + "noise": multiplicity_diffusion_train, + "denoised_atom_coords": multiplicity_diffusion_train, + "d_denoised_atom_coords": multiplicity_diffusion_train, + "aligned_true_atom_coords": multiplicity_diffusion_train, + } + + feats_atom = distribute_atom_features( + inputs=inputs_atom, + placements_cp=placements_cp, + placements_dp_cp=placements_dp_cp, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities=multiplicities, + ) + + feats = {**feats_token_msa, **feats_atom} + + noise_dt = feats_atom.pop("noise") + expected_denoised_dt = feats_atom.pop("denoised_atom_coords") + d_denoised_dt_expected = feats_atom.pop("d_denoised_atom_coords") + expected_aligned_coords_dt = feats_atom.pop("aligned_true_atom_coords") + + else: + # --- Dataloader path: load features from distributed data module --- + DistributedManager.create_group( + "world_cpu", manager.group_ranks["world"], backend="gloo", use_local_synchronization=True + ) + DistributedManager.create_group( + "cp_cpu", manager.group_ranks["cp"], backend="gloo", use_local_synchronization=True + ) + cp_device_mesh = map_subgroup_mesh_to_cpu(manager) + + _mp_data = pytest.MonkeyPatch() + cfg = setup_mock_training_datamodule_config(training_data_dir) + cfg.batch_size = 1 + cfg.samples_per_epoch = grid_group_sizes["dp"] + cfg.moldir = str(canonical_mols_dir) + cfg.return_train_symmetries = False + for ds_cfg in cfg.datasets: + ds_cfg.filters = None + seed_by_rank(0, seed=base_seed) + dm = BoltzTrainingDataModuleDTensor(cfg, manager.device_mesh_subgroups, cp_device_mesh) + _deterministic_getitem_monkeypatch(_mp_data, dm._serial_module._train_set, base_seed=base_seed) + dl = dm.train_dataloader() + batch_cpu = next(iter(dl)) + batch_gpu = dm.transfer_batch_to_device(batch_cpu, manager.device, 0) + _mp_data.undo() + + feats = {} + for k, v in batch_gpu.items(): + if isinstance(v, (DTensor, torch.Tensor)) and v.dtype.is_floating_point: + feats[k] = v.to(dtype=dtype) + elif isinstance(v, list): + feats[k] = [ + item.to(dtype=dtype) if isinstance(item, torch.Tensor) and item.dtype.is_floating_point else item + for item in v + ] + else: + feats[k] = v + + # Distribute noise + model I/O reference tensors via intersperse padding + _placements = get_feature_placements( + atom_keys=set(), + model_io_keys=_io_model_io_keys, + model_io_fp32_keys=set(), + ) + size_batch = input_feats_global_fp64_host["atom_pad_mask"].shape[0] + inputs_io = {"atom_counts_per_token": input_feats_global_fp64_host["atom_counts_per_token"].clone()} + + noise_unflat = noise_expected_global_fp64_host.unflatten(0, (size_batch, multiplicity_diffusion_train)) + denoised_unflat = output_denoised_atom_coords_expected_global_host.unflatten( + 0, (size_batch, multiplicity_diffusion_train) + ) + d_denoised_unflat = d_output_denoised_atom_coords_expected_global_host.unflatten( + 0, (size_batch, multiplicity_diffusion_train) + ) + aligned_coords_unflat = output_aligned_true_coords_expected_global_host.unflatten( + 0, (size_batch, multiplicity_diffusion_train) + ) + for i_mul in range(multiplicity_diffusion_train): + inputs_io[f"noise_{i_mul}"] = noise_unflat[:, i_mul].to(dtype=dtype) + inputs_io[f"denoised_atom_coords_{i_mul}"] = denoised_unflat[:, i_mul].to(dtype=dtype) + inputs_io[f"d_denoised_atom_coords_{i_mul}"] = d_denoised_unflat[:, i_mul].to(dtype=dtype) + inputs_io[f"aligned_true_atom_coords_{i_mul}"] = aligned_coords_unflat[:, i_mul].to(dtype=dtype) + + placements_cp_model_io_mul = { + f"{k}_{i_mul}": v + for k, v in _placements["cp_model_io"].items() + for i_mul in range(multiplicity_diffusion_train) + } + placements_cp = _placements["cp_atom_features"] | placements_cp_model_io_mul + + placements_model_io_mul = { + f"{k}_{i_mul}": v + for k, v in _placements["model_io"].items() + for i_mul in range(multiplicity_diffusion_train) + } + placements_dp_cp = placements_model_io_mul + + io_feats = distribute_atom_features( + inputs=inputs_io, + placements_cp=placements_cp, + placements_dp_cp=placements_dp_cp, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={k: multiplicity_diffusion_train for k in _io_model_io_keys}, + ) + + # distribute_atom_features applies intersperse padding per DP rank + # independently, but CollateDTensor additionally homogenizes local + # shard shapes across DP ranks via all-reduce MAX. When samples have + # different atom counts (7ylz vs 8b2e), the dataloader features are + # homogenized but the model-I/O DTensors from distribute_atom_features + # are not. Pad each DTensor's atom dim to match the batch. + target_atoms_global = feats["atom_pad_mask"].shape[-1] + for k in list(io_feats.keys()): + if io_feats[k].shape[1] < target_atoms_global: + io_feats[k] = pad_to_length(io_feats[k], dim=1, length=target_atoms_global) + + noise_dt = io_feats.pop("noise") + expected_denoised_dt = io_feats.pop("denoised_atom_coords") + d_denoised_dt_expected = io_feats.pop("d_denoised_atom_coords") + expected_aligned_coords_dt = io_feats.pop("aligned_true_atom_coords") + + # Monkeypatch deterministic noise for distributed forward + sigmas_device = sigmas_expected_global_fp64_host.to(device=manager.device, dtype=dtype) + sigmas_dt = distribute_tensor(sigmas_device, manager.device_mesh_subgroups, (Shard(0), Replicate(), Replicate())) + + monkeypatch.setattr(module.structure_module, "noise_distribution", lambda bs, dtype=None: sigmas_dt) + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", lambda *a, **kw: noise_dt) + + # Distributed forward + output_dict = module( + feats, + recycling_steps=n_recycles, + multiplicity_diffusion_train=multiplicity_diffusion_train, + ) + + assert "pdistogram" in output_dict + assert "denoised_atom_coords" in output_dict + assert "pbfactor" in output_dict + assert "s" in output_dict + assert "z" in output_dict + assert "sigmas" in output_dict + assert "aligned_true_atom_coords" in output_dict + + token_pad_mask_global = feats["token_pad_mask"].full_tensor() + token_pair_pad_mask_global = feats["token_pair_pad_mask"].full_tensor() + atom_pad_mask_global = feats["atom_pad_mask"].full_tensor() + atom_pad_mask_mul_global = atom_pad_mask_global[:, :, None].repeat_interleave(multiplicity_diffusion_train, 0) + + s_full = output_dict["s"].full_tensor() * token_pad_mask_global[:, :, None] + expected_s = output_s_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(s_full, expected_s) + + z_full = output_dict["z"].full_tensor() * token_pair_pad_mask_global[:, :, :, None] + expected_z = output_z_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(z_full, expected_z) + + pdistogram_full = output_dict["pdistogram"].full_tensor() * token_pair_pad_mask_global[:, :, :, None, None] + expected_pdistogram = output_pdistogram_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(pdistogram_full, expected_pdistogram) + + denoised_full = output_dict["denoised_atom_coords"].full_tensor() * atom_pad_mask_mul_global + expected_denoised_full = expected_denoised_dt.full_tensor() * atom_pad_mask_mul_global + torch.testing.assert_close(denoised_full, expected_denoised_full) + + pbfactor_full = output_dict["pbfactor"].full_tensor() * token_pad_mask_global[:, :, None] + expected_pbfactor = output_pbfactor_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(pbfactor_full, expected_pbfactor) + + sigmas_full = output_dict["sigmas"].full_tensor() + expected_sigmas = sigmas_expected_global_fp64_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(sigmas_full, expected_sigmas) + + aligned_coords_full = output_dict["aligned_true_atom_coords"].full_tensor() * atom_pad_mask_mul_global + expected_aligned_coords_full = expected_aligned_coords_dt.full_tensor() * atom_pad_mask_mul_global + torch.testing.assert_close(aligned_coords_full, expected_aligned_coords_full) + + # Backward pass + d_pdistogram = d_output_pdistogram_expected_global_host.to(device=manager.device, dtype=dtype) + d_pdistogram_dt = distribute_tensor( + d_pdistogram, manager.device_mesh_subgroups, output_dict["pdistogram"].placements + ) + + d_pbfactor = d_output_pbfactor_expected_global_host.to(device=manager.device, dtype=dtype) + d_pbfactor_dt = distribute_tensor(d_pbfactor, manager.device_mesh_subgroups, output_dict["pbfactor"].placements) + + torch.autograd.backward( + [output_dict["pdistogram"], output_dict["denoised_atom_coords"], output_dict["pbfactor"]], + [d_pdistogram_dt, d_denoised_dt_expected, d_pbfactor_dt], + ) + + num_grads_checked = 0 + num_nonzero_grads = 0 + for name, param in module.named_parameters(): + canonical_name = name.replace("._serial.", ".") + if canonical_name in expected_param_grads_global_host_dict: + expected_grad = expected_param_grads_global_host_dict[canonical_name].to(device=manager.device, dtype=dtype) + if param.grad is None: + raise AssertionError(f"Missing gradient for {canonical_name}") + actual_grad = param.grad.full_tensor() if isinstance(param.grad, DTensor) else param.grad + num_grads_checked += 1 + if expected_grad.abs().max().item() > 0: + num_nonzero_grads += 1 + torch.testing.assert_close( + actual_grad, + expected_grad, + msg=lambda msg, cn=canonical_name: f"Gradient mismatch for {cn}: {msg}", + ) + + assert num_grads_checked > 0, "No gradients compared — test is vacuous" + assert num_nonzero_grads > 0, "All compared gradients are zero — test is vacuous" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + ("setup_env", "use_random_features"), + [ + (((2, (2, 2)), True, "cuda", "ENV"), True), + (((2, (2, 2)), True, "cuda", "ENV"), False), + ], + indirect=["setup_env"], + ids=["cuda-dp2-cp2x2-random", "cuda-dp2-cp2x2-dataloader"], +) +def test_boltz2_forward_backward_parity( + setup_env, + use_random_features, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + tmp_path, +): + """V13: Forward/backward parity between distributed Boltz2 and serial Boltz2. + + Tests that the distributed Boltz2 wrapper produces numerically identical + forward outputs and backward gradients compared to the serial implementation, + with training=True and structure_prediction_training=True. + + This test uses custom upstream gradients (not loss-derived) to isolate + the forward/backward pipeline from the loss computation. The model + configuration, initialization, and features are intentionally aligned + with test_boltz2_training_step_parity so that a pass here proves the + forward/backward path is correct under the same numerical regime. + + Parametrized by use_random_features: + - True: synthetic random features (fast, small dimensions) + - False: real 7ylz + 8b2e training data via TrainingDataModule + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + dtype = torch.float64 + B = grid_group_sizes["dp"] + size_cp = grid_group_sizes["cp"][0] + + min_val_init = -0.1 + max_val_init = 0.1 + scale_glorot = 0.1 + + seed = 42 + seed_by_rank(0, seed=seed) + + boltz2_model_params = create_boltz2_model_init_params(use_large_model=False) + n_recycles = 0 + multiplicity_diffusion_train = 2 + + boltz2_model_params["training_args"].recycling_steps = n_recycles + boltz2_model_params["training_args"].diffusion_multiplicity = multiplicity_diffusion_train + boltz2_model_params["no_random_recycling_training"] = True + + if use_random_features: + n_tokens = 30 * size_cp + W = boltz2_model_params["atoms_per_window_queries"] + n_atoms_raw = n_tokens * 20 + n_atoms = ((n_atoms_raw + W - 1) // W) * W + n_msa = max(size_cp * 2, 2) + + input_feats_global_fp64 = random_features( + size_batch=B, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=n_msa, + atom_counts_per_token_range=(8, 20), + device=device_type, + float_value_range=(min_val_init, max_val_init), + selected_keys=_BOLTZ2_SELECTED_KEYS, + num_disto_bins=boltz2_model_params["num_bins"], + ) + input_feats_global_fp64["msa"] = torch.randint( + 0, const.num_tokens, (B, n_msa, n_tokens), dtype=torch.int64, device=device_type + ) + training_data_dir = None + else: + boltz2_model_params["num_bins"] = 64 + training_data_dir = _setup_training_data_7z64_8b2e( + tmp_path / "training_data", test_cp_training_base_data_dir_boltz2 + ) + _mp_data = pytest.MonkeyPatch() + cfg = setup_mock_training_datamodule_config(training_data_dir) + cfg.batch_size = B + cfg.samples_per_epoch = B + cfg.moldir = str(canonical_mols_dir) + cfg.return_train_symmetries = False + for ds_cfg in cfg.datasets: + ds_cfg.filters = None + seed_by_rank(0, seed=seed) + dm = Boltz2TrainingDataModuleSerial(cfg=cfg) + _deterministic_getitem_monkeypatch(_mp_data, dm._train_set, base_seed=seed) + dl = dm.train_dataloader() + raw_batch = next(iter(dl)) + input_feats_global_fp64 = {} + for k, v in raw_batch.items(): + if isinstance(v, torch.Tensor): + input_feats_global_fp64[k] = v.to( + device=device_type, dtype=dtype if v.dtype.is_floating_point else v.dtype + ) + elif isinstance(v, list): + input_feats_global_fp64[k] = [ + item.to(device=device_type, dtype=dtype if item.dtype.is_floating_point else item.dtype) + if isinstance(item, torch.Tensor) + else item + for item in v + ] + else: + input_feats_global_fp64[k] = v + n_atoms = input_feats_global_fp64["atom_pad_mask"].shape[-1] + _mp_data.undo() + + # Real data doesn't include token_pair_pad_mask; compute from token_pad_mask + if "token_pair_pad_mask" not in input_feats_global_fp64: + tpm = input_feats_global_fp64["token_pad_mask"] + input_feats_global_fp64["token_pair_pad_mask"] = tpm[:, :, None] * tpm[:, None, :] + + # Create serial reference module + reference_module = SerialBoltz2(**boltz2_model_params) + init_module_params_glorot(reference_module, gain=scale_glorot) + reference_module.apply(SetModuleInfValues()) + reference_module.structure_module.coordinate_augmentation = False + module_state_dict_fp64 = {k: v.detach().clone().cpu() for k, v in reference_module.state_dict().items()} + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + + # Pre-generate deterministic sigmas and non-zero noise + sigmas_expected_global_fp64 = reference_module.structure_module.noise_distribution( + B * multiplicity_diffusion_train + ).to(device=device_type, dtype=torch.float64) + noise_expected_global_fp64 = torch.empty( + B * multiplicity_diffusion_train, n_atoms, 3, device=device_type, dtype=torch.float64 + ) + init_tensors_uniform([noise_expected_global_fp64], low=min_val_init, high=max_val_init) + + # Monkeypatch serial noise for determinism + _monkeypatch = pytest.MonkeyPatch() + _monkeypatch.setattr( + reference_module.structure_module, "noise_distribution", lambda bs, dtype=None: sigmas_expected_global_fp64 + ) + _monkeypatch.setattr(serial_diffusion_v2_module.torch, "randn_like", lambda t: noise_expected_global_fp64.to(t)) + + original_feat_keys = set(input_feats_global_fp64.keys()) + coords_backup = input_feats_global_fp64["coords"].detach().clone() + + # Serial forward + output_dict_serial = reference_module( + input_feats_global_fp64, + recycling_steps=n_recycles, + multiplicity_diffusion_train=multiplicity_diffusion_train, + ) + + output_pdistogram = output_dict_serial["pdistogram"] + output_denoised = output_dict_serial["denoised_atom_coords"] + output_pbfactor = output_dict_serial["pbfactor"] + output_s = output_dict_serial["s"] + output_z = output_dict_serial["z"] + output_aligned_true_coords = output_dict_serial["aligned_true_atom_coords"] + + # Create upstream gradients + d_output_pdistogram = torch.empty_like(output_pdistogram) + d_output_denoised = torch.empty_like(output_denoised) + d_output_pbfactor = torch.empty_like(output_pbfactor) + init_tensors_uniform( + [d_output_pdistogram, d_output_denoised, d_output_pbfactor], low=min_val_init, high=max_val_init + ) + + # Mask upstream gradients + atom_pad_mask = input_feats_global_fp64["atom_pad_mask"] + atom_pad_mask_mul = atom_pad_mask[:, :, None].repeat_interleave(multiplicity_diffusion_train, 0) + d_output_denoised = d_output_denoised * atom_pad_mask_mul + + token_pair_pad_mask = input_feats_global_fp64["token_pair_pad_mask"] + d_output_pdistogram = d_output_pdistogram * token_pair_pad_mask[:, :, :, None, None] + + token_pad_mask = input_feats_global_fp64["token_pad_mask"] + d_output_pbfactor = d_output_pbfactor * token_pad_mask[:, :, None] + + # Serial backward + torch.autograd.backward( + [output_pdistogram, output_denoised, output_pbfactor], + [d_output_pdistogram, d_output_denoised, d_output_pbfactor], + ) + + grad_params_expected = { + name: param.grad.detach().to(dtype=dtype, device="cpu", copy=True) + for name, param in reference_module.named_parameters() + if param.grad is not None + } + + # Restore features — serial forward mutates coords in-place and may add keys + input_feats_global_fp64["coords"] = coords_backup + for key in list(input_feats_global_fp64.keys()): + if key not in original_feat_keys: + del input_feats_global_fp64[key] + + output_pdistogram_host = ( + (output_pdistogram * token_pair_pad_mask[:, :, :, None, None]).detach().to(device="cpu", copy=True) + ) + output_denoised_host = (output_denoised * atom_pad_mask_mul).detach().to(device="cpu", copy=True) + output_pbfactor_host = (output_pbfactor * token_pad_mask[:, :, None]).detach().to(device="cpu", copy=True) + output_s_host = (output_s * token_pad_mask[:, :, None]).detach().to(device="cpu", copy=True) + output_z_host = (output_z * token_pair_pad_mask[:, :, :, None]).detach().to(device="cpu", copy=True) + output_aligned_true_coords_host = ( + (output_aligned_true_coords * atom_pad_mask_mul).detach().to(device="cpu", copy=True) + ) + + sigmas_host = sigmas_expected_global_fp64.detach().to(device="cpu", copy=True) + noise_host = (noise_expected_global_fp64 * atom_pad_mask_mul).detach().to(device="cpu", copy=True) + d_output_pdistogram_host = d_output_pdistogram.detach().to(device="cpu", copy=True) + d_output_denoised_host = d_output_denoised.detach().to(device="cpu", copy=True) + d_output_pbfactor_host = d_output_pbfactor.detach().to(device="cpu", copy=True) + + input_feats_host = {} + for k, v in input_feats_global_fp64.items(): + if isinstance(v, torch.Tensor): + input_feats_host[k] = v.detach().to(device="cpu", copy=True) + elif isinstance(v, list): + input_feats_host[k] = [ + item.detach().to(device="cpu", copy=True) if isinstance(item, torch.Tensor) else item for item in v + ] + else: + input_feats_host[k] = v + + _monkeypatch.undo() + + spawn_multiprocessing( + _worker_forward_backward_parity, + world_size, + grid_group_sizes, + device_type, + backend, + dtype, + boltz2_model_params, + module_state_dict_fp64, + n_recycles, + multiplicity_diffusion_train, + input_feats_host, + sigmas_host, + noise_host, + output_pdistogram_host, + output_denoised_host, + output_pbfactor_host, + output_s_host, + output_z_host, + output_aligned_true_coords_host, + d_output_pdistogram_host, + d_output_denoised_host, + d_output_pbfactor_host, + grad_params_expected, + env_per_rank, + use_random_features, + training_data_dir, + canonical_mols_dir, + seed, + ) + + +# ====================================================================== # +# V16: Serial vs distributed training_step numerical parity # +# ====================================================================== # + + +# V15 (training_step smoke test) was removed — V16 (training_step parity) +# provides strictly stronger guarantees, and the recycling path is covered +# by test_boltz2_train_entrypoint / test_boltz2_stop_and_go in +# test_dtensor_boltz2_train.py. This matches the Boltz-1 test structure +# which has only forward parity + training_step parity (no smoke test). + + +def _worker_training_step_parity( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + module_state_dict, + boltz2_model_params, + input_feats_global_fp64_host, + sigmas_expected_host, + noise_expected_host, + serial_loss_host, + serial_log_metrics_host, + serial_grad_dict_host, + serial_post_opt_dict_host, + optimizer_step, + use_random_features=True, + training_data_dir=None, + canonical_mols_dir=None, + base_seed=42, +): + """V16 multi-rank worker: verify distributed training_step matches serial. + + Compares: + 1. Loss value (serial vs distributed) + 1b. Logged metric values via CSVLogger (serial vs distributed), + including component-wise grad_norms (compared after backward + + on_after_backward so grad_norm metrics are available) + 2. Per-parameter gradients after backward + on_after_backward + 3. Post-optimizer parameter values (if optimizer_step=True) + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + dtype = torch.float64 + + # Build distributed model from serial state dict + reference_module = SerialBoltz2(**boltz2_model_params) + reference_module = reference_module.to(dtype=dtype) + reference_module.load_state_dict(module_state_dict) + reference_module.structure_module.coordinate_augmentation = False + reference_module.apply(SetModuleInfValues()) + reference_module = reference_module.to(device=manager.device) + module = Boltz2Distributed(reference_module, manager) + module = module.to(device=manager.device) + module.train() + + # Capture initial parameter values for non-vacuous optimizer check + initial_params = { + name: (p.full_tensor().detach().clone() if isinstance(p, DTensor) else p.detach().clone()) + for name, p in module.named_parameters() + if p.requires_grad + } + + # Inject CSVLogger to exercise the logging code path (instead of no-op) + worker_csv_logger = CSVLogger(save_dir=tempfile.mkdtemp(), name=f"distributed_rank{rank}") + dist_log = _LogCapture(worker_csv_logger) + monkeypatch.setattr(module, "log", dist_log) + monkeypatch.setattr(module, "training_log", lambda *a, **kw: None) + + multiplicity = boltz2_model_params["training_args"].diffusion_multiplicity + + if use_random_features: + # Distribute features (same pattern as V13) + host_tensor_keys = {k for k, v in input_feats_global_fp64_host.items() if isinstance(v, torch.Tensor)} + _placements = get_feature_placements( + token_keys=host_tensor_keys, + msa_keys=host_tensor_keys, + atom_keys={ + "ref_pos", + "atom_resolved_mask", + "ref_element", + "ref_charge", + "ref_atom_name_chars", + "ref_space_uid", + "coords", + "atom_pad_mask", + "atom_to_token", + "pair_mask", + "atom_counts_per_token", + "token_to_rep_atom", + "bfactor", + "plddt", + }, + model_io_keys={"noise"}, + model_io_fp32_keys=set(), + ) + + # Token + MSA features: broadcast from rank 0 + if manager.group_rank["world"] == 0: + input_feats_token_msa_global = { + k: v.to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements["token_features"] or k in _placements["msa_features"] + } + else: + input_feats_token_msa_global = None + + feats_token_msa = distribute_features( + input_feats_token_msa_global, + _placements["token_features"] | _placements["msa_features"], + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + ) + + # Atom features + noise: scatter across CP mesh with multiplicity + size_batch = input_feats_global_fp64_host["atom_pad_mask"].shape[0] + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_global_fp64_host.items() + if k in _placements["cp_atom_features"] + } + + noise_unflat = noise_expected_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"noise_{i_mul}"] = noise_unflat[:, i_mul].to(dtype=dtype) + + placements_cp_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["cp_model_io"].items() for i_mul in range(multiplicity) + } + placements_cp = _placements["cp_atom_features"] | placements_cp_model_io_mul + + placements_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["model_io"].items() for i_mul in range(multiplicity) + } + placements_dp_cp = _placements["atom_features"] | placements_model_io_mul + + feats_atom = distribute_atom_features( + inputs=inputs_atom, + placements_cp=placements_cp, + placements_dp_cp=placements_dp_cp, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={"noise": multiplicity}, + ) + + noise_dt = feats_atom.pop("noise") + batch = {**feats_token_msa, **feats_atom} + + else: + # --- Dataloader path: load features from distributed data module --- + DistributedManager.create_group( + "world_cpu", manager.group_ranks["world"], backend="gloo", use_local_synchronization=True + ) + DistributedManager.create_group( + "cp_cpu", manager.group_ranks["cp"], backend="gloo", use_local_synchronization=True + ) + cp_device_mesh = map_subgroup_mesh_to_cpu(manager) + + _mp_data = pytest.MonkeyPatch() + cfg = setup_mock_training_datamodule_config(training_data_dir) + cfg.batch_size = 1 + cfg.samples_per_epoch = grid_group_sizes["dp"] + cfg.moldir = str(canonical_mols_dir) + cfg.return_train_symmetries = False + for ds_cfg in cfg.datasets: + ds_cfg.filters = None + seed_by_rank(0, seed=base_seed) + dm = BoltzTrainingDataModuleDTensor(cfg, manager.device_mesh_subgroups, cp_device_mesh) + _deterministic_getitem_monkeypatch(_mp_data, dm._serial_module._train_set, base_seed=base_seed) + dl = dm.train_dataloader() + batch_cpu = next(iter(dl)) + batch_gpu = dm.transfer_batch_to_device(batch_cpu, manager.device, 0) + _mp_data.undo() + + feats = {} + for k, v in batch_gpu.items(): + if isinstance(v, (DTensor, torch.Tensor)) and v.dtype.is_floating_point: + feats[k] = v.to(dtype=dtype) + elif isinstance(v, list): + feats[k] = [ + item.to(dtype=dtype) if isinstance(item, torch.Tensor) and item.dtype.is_floating_point else item + for item in v + ] + else: + feats[k] = v + + # Distribute noise via intersperse padding + homogenization + _io_keys = {"noise"} + _placements = get_feature_placements( + atom_keys=set(), + model_io_keys=_io_keys, + model_io_fp32_keys=set(), + ) + size_batch = input_feats_global_fp64_host["atom_pad_mask"].shape[0] + inputs_io = {"atom_counts_per_token": input_feats_global_fp64_host["atom_counts_per_token"].clone()} + + noise_unflat = noise_expected_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_io[f"noise_{i_mul}"] = noise_unflat[:, i_mul].to(dtype=dtype) + + placements_cp_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["cp_model_io"].items() for i_mul in range(multiplicity) + } + placements_cp = _placements["cp_atom_features"] | placements_cp_model_io_mul + + placements_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["model_io"].items() for i_mul in range(multiplicity) + } + placements_dp_cp = placements_model_io_mul + + io_feats = distribute_atom_features( + inputs=inputs_io, + placements_cp=placements_cp, + placements_dp_cp=placements_dp_cp, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={"noise": multiplicity}, + ) + + # Homogenize model I/O DTensors to match dataloader batch atom dim + target_atoms_global = feats["atom_pad_mask"].shape[-1] + for k in list(io_feats.keys()): + if io_feats[k].shape[1] < target_atoms_global: + io_feats[k] = pad_to_length(io_feats[k], dim=1, length=target_atoms_global) + + noise_dt = io_feats.pop("noise") + batch = feats + + # Monkeypatch deterministic noise for distributed forward + sigmas_device = sigmas_expected_host.to(device=manager.device, dtype=dtype) + sigmas_dt = distribute_tensor(sigmas_device, manager.device_mesh_subgroups, (Shard(0), Replicate(), Replicate())) + monkeypatch.setattr(module.structure_module, "noise_distribution", lambda bs, dtype=None: sigmas_dt) + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", lambda *a, **kw: noise_dt) + + # Run distributed training_step + loss = module.training_step(batch, batch_idx=0) + + # Assert 1: loss matches serial + loss_local = loss.to_local() if isinstance(loss, DTensor) else loss + serial_loss_device = serial_loss_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close( + loss_local, + serial_loss_device, + msg=lambda msg: f"Rank {rank}: Loss mismatch: {msg}", + ) + + # Backward + loss_local.backward() + + # on_after_backward redistributes gradients to Replicate and logs grad_norm metrics + module.on_after_backward() + + # Assert 1b: logged metrics parity (CSVLogger output). + # Compared after backward + on_after_backward so grad_norm metrics are included. + assert len(dist_log.metrics) > 0, f"Rank {rank}: No metrics logged — test is vacuous" + assert set(dist_log.metrics.keys()) == set(serial_log_metrics_host.keys()), ( + f"Rank {rank}: Logged metric keys differ. " + f"Serial: {sorted(serial_log_metrics_host.keys())}, " + f"Distributed: {sorted(dist_log.metrics.keys())}" + ) + for key in sorted(serial_log_metrics_host.keys()): + torch.testing.assert_close( + torch.tensor(dist_log.metrics[key], dtype=torch.float64), + torch.tensor(serial_log_metrics_host[key], dtype=torch.float64), + msg=lambda msg, k=key: f"Rank {rank}: Logged metric mismatch for {k}: {msg}", + ) + dist_log.flush(step=0) + + # Assert 2: per-parameter gradient parity (default fp64 tolerances). + num_grads_checked = 0 + num_nonzero_grads = 0 + for name, param in module.named_parameters(): + canonical_name = name.replace("._serial.", ".") + if canonical_name not in serial_grad_dict_host: + continue + expected_grad = serial_grad_dict_host[canonical_name].to(device=manager.device, dtype=dtype) + assert param.grad is not None, f"Rank {rank}: Missing gradient for {canonical_name}" + actual_grad = param.grad.full_tensor() if isinstance(param.grad, DTensor) else param.grad + num_grads_checked += 1 + if expected_grad.abs().max().item() > 0: + num_nonzero_grads += 1 + torch.testing.assert_close( + actual_grad, + expected_grad, + msg=lambda msg, cn=canonical_name: ( + f"Rank {rank}: Gradient mismatch for {cn}. " + f"Serial grad norm: {expected_grad.norm().item():.10f}, " + f"Distributed grad norm: {actual_grad.norm().item():.10f}. {msg}" + ), + ) + + assert num_grads_checked > 0, f"Rank {rank}: No gradients compared — test is vacuous" + assert num_nonzero_grads > 0, f"Rank {rank}: All compared gradients are zero — test is vacuous" + + # Assert 3: optimizer step parity (if requested) + if optimizer_step: + optimizer = torch.optim.Adam(module.parameters(), lr=1e-3, betas=(0.9, 0.999)) + optimizer.step() + + num_params_checked = 0 + num_params_changed = 0 + for name, param in module.named_parameters(): + canonical_name = name.replace("._serial.", ".") + if canonical_name not in serial_post_opt_dict_host: + continue + expected_val = serial_post_opt_dict_host[canonical_name].to(device=manager.device, dtype=dtype) + actual_val = param.full_tensor() if isinstance(param, DTensor) else param.data + num_params_checked += 1 + if name in initial_params: + initial_val = initial_params[name].to(device=manager.device, dtype=dtype) + if not torch.equal(actual_val, initial_val): + num_params_changed += 1 + torch.testing.assert_close( + actual_val, + expected_val, + msg=lambda msg, cn=canonical_name: f"Rank {rank}: Post-optimizer mismatch for {cn}. {msg}", + ) + assert num_params_checked > 0, f"Rank {rank}: No post-optimizer params compared" + assert num_params_changed > 0, f"Rank {rank}: No parameters changed after optimizer step — test is vacuous" + + # Assert 4: cross-rank loss identity + torch.distributed.barrier() + + +@pytest.mark.slow +@pytest.mark.parametrize( + ("setup_env", "optimizer_step", "use_random_features"), + [ + (((2, (2, 2)), True, "cuda", "ENV"), True, True), + (((2, (2, 2)), True, "cuda", "ENV"), True, False), + ], + indirect=["setup_env"], + ids=["cuda-dp2-cp2x2-random", "cuda-dp2-cp2x2-dataloader"], +) +def test_boltz2_training_step_parity( + setup_env, + optimizer_step, + use_random_features, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + tmp_path, +): + """V16: Serial vs distributed training_step numerical parity. + + Verifies that the distributed Boltz2Distributed.training_step produces + numerically identical loss, gradients, post-optimizer parameters, and + logged metrics compared to the serial Boltz2.training_step, with all + randomness sources controlled: + + 1. Recycling: no_random_recycling_training=True (fixed recycling_steps) + 2. Diffusion noise: monkeypatched noise_distribution and randn_like / create_distributed_randn + 3. Coordinate augmentation: disabled + 4. Sampling steps: fixed (no sampling_steps_random) + 5. Dropout: 0.0 in all modules + + Both serial and distributed sessions inject a CSVLogger-backed + ``_LogCapture`` so that ``self.log()`` calls are exercised (not + silenced) and the logged metric keys/values are compared. + + Extends V13 (forward/backward parity) to the full training_step control + flow including loss aggregation, recycling-step broadcasting, and + gradient redistribution. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + dtype = torch.float64 + min_val_init = -0.1 + max_val_init = 0.1 + scale_glorot = 0.08 + + seed = 42 + seed_by_rank(0, seed=seed) + + # Small model config — dropout=0 for determinism + boltz2_model_params = create_boltz2_model_init_params(use_large_model=False) + recycling_steps = 0 + multiplicity = 2 + + boltz2_model_params["training_args"].recycling_steps = recycling_steps + boltz2_model_params["training_args"].diffusion_multiplicity = multiplicity + boltz2_model_params["training_args"].sampling_steps = -1 # not used in training path but read by training_step + boltz2_model_params["no_random_recycling_training"] = True + boltz2_model_params["predict_bfactor"] = True + boltz2_model_params["training_args"].bfactor_loss_weight = 1.0 + boltz2_model_params["validate_structure"] = True + + size_cp = grid_group_sizes["cp"][0] + B = grid_group_sizes["dp"] + + if use_random_features: + n_tokens = 30 * size_cp + W = boltz2_model_params["atoms_per_window_queries"] + n_atoms_raw = n_tokens * 20 + n_atoms = ((n_atoms_raw + W - 1) // W) * W + n_msa = max(size_cp * 2, 2) + + input_feats_global_fp64 = random_features( + size_batch=B, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=n_msa, + atom_counts_per_token_range=(8, 20), + device=device_type, + float_value_range=(min_val_init, max_val_init), + selected_keys=_BOLTZ2_SELECTED_KEYS, + num_disto_bins=boltz2_model_params["num_bins"], + ) + input_feats_global_fp64["msa"] = torch.randint( + 0, const.num_tokens, (B, n_msa, n_tokens), dtype=torch.int64, device=device_type + ) + input_feats_global_fp64["disto_target"] = input_feats_global_fp64["disto_target"].unsqueeze(3) + training_data_dir = None + else: + boltz2_model_params["num_bins"] = 64 + training_data_dir = _setup_training_data_7z64_8b2e( + tmp_path / "training_data", test_cp_training_base_data_dir_boltz2 + ) + _mp_data = pytest.MonkeyPatch() + cfg = setup_mock_training_datamodule_config(training_data_dir) + cfg.batch_size = B + cfg.samples_per_epoch = B + cfg.moldir = str(canonical_mols_dir) + cfg.return_train_symmetries = False + for ds_cfg in cfg.datasets: + ds_cfg.filters = None + seed_by_rank(0, seed=seed) + dm = Boltz2TrainingDataModuleSerial(cfg=cfg) + _deterministic_getitem_monkeypatch(_mp_data, dm._train_set, base_seed=seed) + dl = dm.train_dataloader() + raw_batch = next(iter(dl)) + input_feats_global_fp64 = {} + for k, v in raw_batch.items(): + if isinstance(v, torch.Tensor): + input_feats_global_fp64[k] = v.to( + device=device_type, dtype=dtype if v.dtype.is_floating_point else v.dtype + ) + elif isinstance(v, list): + input_feats_global_fp64[k] = [ + item.to(device=device_type, dtype=dtype if item.dtype.is_floating_point else item.dtype) + if isinstance(item, torch.Tensor) + else item + for item in v + ] + else: + input_feats_global_fp64[k] = v + n_atoms = input_feats_global_fp64["atom_pad_mask"].shape[-1] + _mp_data.undo() + + if "token_pair_pad_mask" not in input_feats_global_fp64: + tpm = input_feats_global_fp64["token_pad_mask"] + input_feats_global_fp64["token_pair_pad_mask"] = tpm[:, :, None] * tpm[:, None, :] + + # Create serial model with deterministic init + serial_model = SerialBoltz2(**boltz2_model_params) + init_module_params_glorot(serial_model, gain=scale_glorot) + serial_model.apply(SetModuleInfValues()) + serial_model.structure_module.coordinate_augmentation = False + serial_model = serial_model.to(dtype=dtype, device=device_type) + serial_model.train() + + # Save state dict for distributed model (before serial forward mutates state) + module_state_dict = {k: v.detach().clone().cpu() for k, v in serial_model.state_dict().items()} + + # Pre-generate deterministic sigmas and non-zero noise for serial model + sigmas_serial = serial_model.structure_module.noise_distribution(B * multiplicity).to(dtype=dtype) + noise_serial = torch.empty(B * multiplicity, n_atoms, 3, device=device_type, dtype=dtype) + init_tensors_uniform([noise_serial], low=min_val_init, high=max_val_init) + atom_pad_mask_mul = input_feats_global_fp64["atom_pad_mask"][:, :, None].repeat_interleave(multiplicity, 0) + noise_serial = noise_serial * atom_pad_mask_mul + + _serial_mp = pytest.MonkeyPatch() + serial_csv_logger = CSVLogger(save_dir=str(tmp_path), name="serial") + serial_log = _LogCapture(serial_csv_logger) + _serial_mp.setattr(serial_model, "log", serial_log) + _serial_mp.setattr(serial_model, "training_log", lambda *a, **kw: None) + _serial_mp.setattr(serial_model.structure_module, "noise_distribution", lambda bs, dtype=None: sigmas_serial) + _serial_mp.setattr(serial_diffusion_v2_module.torch, "randn_like", lambda t: noise_serial.to(t)) + # Monkeypatch serial smooth_lddt_loss to use dense pairwise distances. + # The original serial code uses sparse indexing (nonzero + F.pairwise_distance) + # which creates a different autograd backward graph than the distributed dense + # matrix computation (replicate_to_shard_outer_op CDIST). The different + # accumulation patterns amplify ~1e-12 forward differences into ~4.5e-7 + # gradient errors. Using dense distances here aligns the backward structure. + import boltz.model.loss.diffusionv2 as _serial_loss_mod + + def _smooth_lddt_loss_dense( + pred_coords, + true_coords, + is_nucleotide, + coords_mask=None, + nucleic_acid_cutoff=30.0, + other_cutoff=15.0, + multiplicity=1, + ): + compute_dtype = torch.promote_types(pred_coords.dtype, torch.float32) + N = pred_coords.shape[1] + lddt = [] + for i in range(true_coords.shape[0]): + true_dists = torch.cdist(true_coords[i], true_coords[i]) + + is_nuc_i = is_nucleotide[i // multiplicity] + mask_i = coords_mask[i // multiplicity] + + is_nuc_pair = is_nuc_i.unsqueeze(-1).expand(-1, is_nuc_i.shape[-1]) + + mask = is_nuc_pair * (true_dists < nucleic_acid_cutoff).to(compute_dtype) + mask += (1 - is_nuc_pair) * (true_dists < other_cutoff).to(compute_dtype) + mask *= 1 - torch.eye(N, device=pred_coords.device) + mask *= mask_i.unsqueeze(-1) + mask *= mask_i.unsqueeze(-2) + + diff = pred_coords[i].unsqueeze(0) - pred_coords[i].unsqueeze(1) + pred_dists = (diff * diff).sum(-1).add(1e-30).sqrt() + + dist_diff = (true_dists - pred_dists).abs() + + eps = ( + torch.sigmoid(0.5 - dist_diff) + + torch.sigmoid(1.0 - dist_diff) + + torch.sigmoid(2.0 - dist_diff) + + torch.sigmoid(4.0 - dist_diff) + ) * 0.25 + + lddt_i = (eps * mask).sum() / (mask.sum() + 1e-5) + lddt.append(lddt_i) + + return 1 - sum(lddt) / len(lddt) + + _serial_mp.setattr(_serial_loss_mod, "smooth_lddt_loss", _smooth_lddt_loss_dense) + _serial_mp.setattr(serial_diffusion_v2_module, "smooth_lddt_loss", _smooth_lddt_loss_dense) + + # Save coords — serial forward mutates feats["coords"] in-place (flattens ensemble dim) + coords_backup = input_feats_global_fp64["coords"].detach().clone() + original_feat_keys = set(input_feats_global_fp64.keys()) + + # Run serial training_step + serial_loss = serial_model.training_step(input_feats_global_fp64, batch_idx=0) + assert serial_loss is not None, "Serial training_step returned None" + assert serial_loss.isfinite(), f"Serial loss is not finite: {serial_loss.item()}" + + # Backward on serial model + serial_loss.backward() + + # on_after_backward logs grad_norm metrics (component-wise and global) + serial_model.on_after_backward() + + # Flush serial CSVLogger and collect logged metrics (after backward so + # grad_norm metrics are included alongside the training_step metrics) + serial_log.flush(step=0) + serial_log_metrics = dict(serial_log.metrics) + assert len(serial_log_metrics) > 0, "Serial model logged no metrics — validate_structure may be False" + + # Collect serial gradients + serial_grad_dict = {} + for name, param in serial_model.named_parameters(): + if param.grad is not None: + serial_grad_dict[name] = param.grad.detach().clone().cpu() + + # Collect serial post-optimizer parameters (if needed) + serial_post_opt_dict = {} + if optimizer_step: + # Use a simple Adam (no weight decay) for reproducible parity. + # configure_optimizers uses AdamW with lr schedulers which complicates comparison. + serial_optimizer = torch.optim.Adam(serial_model.parameters(), lr=1e-3, betas=(0.9, 0.999)) + serial_optimizer.step() + for name, param in serial_model.named_parameters(): + serial_post_opt_dict[name] = param.data.detach().clone().cpu() + + serial_loss_host = serial_loss.detach().clone().cpu() + sigmas_host = sigmas_serial.detach().clone().cpu() + noise_host = noise_serial.detach().clone().cpu() + + # Restore features — serial forward mutates coords in-place and may add keys + input_feats_global_fp64["coords"] = coords_backup + for key in list(input_feats_global_fp64.keys()): + if key not in original_feat_keys: + del input_feats_global_fp64[key] + _serial_mp.undo() + + # Move features to CPU for spawned workers + input_feats_host = {} + for k, v in input_feats_global_fp64.items(): + if isinstance(v, torch.Tensor): + input_feats_host[k] = v.detach().to(device="cpu", copy=True) + elif isinstance(v, list): + input_feats_host[k] = [ + item.detach().to(device="cpu", copy=True) if isinstance(item, torch.Tensor) else item for item in v + ] + else: + input_feats_host[k] = v + + # Spawn parallel test + spawn_multiprocessing( + _worker_training_step_parity, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + module_state_dict, + boltz2_model_params, + input_feats_host, + sigmas_host, + noise_host, + serial_loss_host, + serial_log_metrics, + serial_grad_dict, + serial_post_opt_dict, + optimizer_step, + use_random_features, + training_data_dir, + canonical_mols_dir, + seed, + ) + + +# ====================================================================== # +# V15: setup() – validator wiring, predict no-op, datamodule=None # +# ====================================================================== # + + +def _worker_setup( + rank, + grid_group_sizes, + device_type, + backend, + env_map=None, +): + """Worker: verify setup() wires validators and handles predict stage.""" + monkeypatch, dm = _init_distributed(rank, grid_group_sizes, device_type, backend, env_map) + + validators = [DistributedRCSBValidator(val_names=["RCSB"], confidence_prediction=False, physicalism_metrics=True)] + serial_model = _create_minimal_serial_boltz2(validate_structure=True, validators=validators, num_val_datasets=1).to( + dm.device + ) + dist_model = Boltz2Distributed(serial_model, dm) + + # --- setup("predict") does not populate validator_mapper --- + dist_model._trainer = SimpleNamespace( + datamodule=SimpleNamespace( + val_group_mapper={0: {"label": "RCSB", "symmetry_correction": False}}, + ), + ) + dist_model.setup("predict") + assert len(dist_model.validator_mapper) == 0, "predict stage should not wire validators" + + # --- setup("fit") populates validator_mapper from trainer.datamodule --- + dist_model.setup("fit") + assert len(dist_model.validator_mapper) == 1, "fit stage should populate validator_mapper" + assert 0 in dist_model.validator_mapper + assert isinstance(dist_model.validator_mapper[0], DistributedRCSBValidator) + + # --- val_group_mapper updated from datamodule --- + assert len(dist_model.val_group_mapper) == 1 + assert dist_model.val_group_mapper[0]["label"] == "RCSB" + assert dist_model.val_group_mapper[0]["symmetry_correction"] is False + + del dist_model, serial_model + + # --- setup without datamodule is a no-op --- + validators2 = [DistributedRCSBValidator(val_names=["RCSB"], confidence_prediction=False, physicalism_metrics=True)] + serial2 = _create_minimal_serial_boltz2(validate_structure=True, validators=validators2, num_val_datasets=1).to( + dm.device + ) + dist2 = Boltz2Distributed(serial2, dm) + dist2._trainer = SimpleNamespace(datamodule=None) + dist2.setup("fit") + assert len(dist2.validator_mapper) == 0, "No datamodule should leave validator_mapper empty" + del dist2, serial2 + + # --- setup with validate_structure=False is a no-op regardless of stage --- + serial3 = _create_minimal_serial_boltz2(validate_structure=False).to(dm.device) + dist3 = Boltz2Distributed(serial3, dm) + dist3._trainer = SimpleNamespace( + datamodule=SimpleNamespace( + val_group_mapper={0: {"label": "RCSB", "symmetry_correction": False}}, + ), + ) + dist3.setup("fit") + assert not hasattr(dist3, "validator_mapper"), "validate_structure=False model should not have validator_mapper" + del dist3, serial3 + + # --- setup("fit") with mismatched num_val_datasets raises AssertionError --- + validators_mismatch = [ + DistributedRCSBValidator(val_names=["RCSB"], confidence_prediction=False, physicalism_metrics=True) + ] + serial_mm = _create_minimal_serial_boltz2( + validate_structure=True, validators=validators_mismatch, num_val_datasets=1 + ).to(dm.device) + dist_mm = Boltz2Distributed(serial_mm, dm) + dist_mm._trainer = SimpleNamespace( + datamodule=SimpleNamespace( + val_group_mapper={ + 0: {"label": "RCSB", "symmetry_correction": False}, + 1: {"label": "EXTRA", "symmetry_correction": False}, + }, + ), + ) + with pytest.raises(AssertionError, match="num_val_datasets"): + dist_mm.setup("fit") + del dist_mm, serial_mm + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda val: f"{val[2]}-dp:{val[0][0]}-cp{'x'.join(map(str, val[0][1]))}", +) +def test_setup(setup_env): + """V15: setup() wires validators correctly and predict stage is a no-op. + + Verifies: + - setup("predict") does not populate validator_mapper + - setup("fit") populates validator_mapper from trainer.datamodule + - setup without datamodule is a no-op + - validate_structure=False means setup("fit") doesn't touch validators + - Mismatched num_val_datasets raises AssertionError + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + _worker_setup, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +# ====================================================================== # +# V18: validation_step and on_validation_epoch_end (real forward pass) # +# ====================================================================== # + + +def _worker_validation_step_parity( + rank, + grid_group_sizes, + device_type, + backend, + boltz2_model_params, + module_state_dict, + input_feats_host, + noise_host_list, + serial_per_sample, + serial_epoch_end_metrics, + env_per_rank=None, +): + """V18 multi-rank worker: verify distributed validation_step matches serial. + + Phase 1: Compare raw accumulated metrics after validation_step against + the serial per-sample reference for this DP rank's sample. + + Phase 2: Compare aggregated metrics after on_validation_epoch_end + against the serial epoch-end reference. + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + dtype = torch.float64 + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + dp_rank = manager.group_rank["dp"] + + reference_module = SerialBoltz2(**boltz2_model_params) + reference_module = reference_module.to(dtype=dtype) + reference_module.load_state_dict(module_state_dict) + reference_module.structure_module.coordinate_augmentation = False + reference_module.apply(SetModuleInfValues()) + reference_module = reference_module.to(device=manager.device) + module = Boltz2Distributed(reference_module, manager) + module.eval() + + # Wire validators via setup + symmetry_correction = boltz2_model_params["validation_args"].symmetry_correction + num_validators = boltz2_model_params["num_val_datasets"] + val_names = [v.val_names[0] for v in boltz2_model_params["validators"]] + worker_val_group_mapper = { + vi: {"label": val_names[vi], "symmetry_correction": symmetry_correction} for vi in range(num_validators) + } + module._trainer = SimpleNamespace( + datamodule=SimpleNamespace(val_group_mapper=worker_val_group_mapper), + sanity_checking=False, + ) + module.setup("fit") + assert len(module.validator_mapper) == num_validators + + # ------------------------------------------------------------------ + # Distribute token + MSA + atom features for this rank's sample + # ------------------------------------------------------------------ + diffusion_samples = boltz2_model_params["validation_args"].diffusion_samples + num_sampling_steps = boltz2_model_params["validation_args"].sampling_steps + + host_tensor_keys = {k for k, v in input_feats_host.items() if isinstance(v, torch.Tensor)} + _placements = get_feature_placements( + token_keys=host_tensor_keys, + msa_keys=host_tensor_keys, + atom_keys=host_tensor_keys + & { + "ref_pos", + "atom_resolved_mask", + "ref_element", + "ref_charge", + "ref_atom_name_chars", + "ref_space_uid", + "coords", + "atom_pad_mask", + "atom_to_token", + "pair_mask", + "atom_counts_per_token", + "token_to_rep_atom", + "bfactor", + "plddt", + }, + model_io_keys={"noise"}, + model_io_fp32_keys=set(), + ) + + token_msa_placements = _placements["token_features"] | _placements["msa_features"] + + if manager.group_rank["world"] == 0: + input_feats_token_msa_global = { + k: v.to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_host.items() + if k in token_msa_placements + } + else: + input_feats_token_msa_global = None + + feats_token_msa = distribute_features( + input_feats_token_msa_global, + token_msa_placements, + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + ) + + size_batch = input_feats_host["atom_pad_mask"].shape[0] + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in input_feats_host.items() + if k in _placements["cp_atom_features"] + } + + n_noise = 1 + num_sampling_steps + for i_noise, noise_host in enumerate(noise_host_list): + unflat = noise_host.unflatten(0, (size_batch, diffusion_samples)) + for i_mul in range(diffusion_samples): + inputs_atom[f"_noise_{i_noise}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + noise_cp_placements = {} + noise_placements = {} + for i_noise in range(n_noise): + for i_mul in range(diffusion_samples): + key = f"_noise_{i_noise}_{i_mul}" + noise_cp_placements[key] = _placements["cp_model_io"]["noise"] + noise_placements[key] = _placements["model_io"]["noise"] + + feats_and_noise = distribute_atom_features( + inputs=inputs_atom, + placements_cp=_placements["cp_atom_features"] | noise_cp_placements, + placements_dp_cp=_placements["atom_features"] | noise_placements, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={f"_noise_{i}": diffusion_samples for i in range(n_noise)}, + ) + + noise_dts = [] + for i_noise in range(n_noise): + noise_dts.append(feats_and_noise.pop(f"_noise_{i_noise}")) + init_noise_dt = noise_dts[0] + step_noise_dts = noise_dts[1:] + + feats_dt = {**feats_token_msa, **feats_and_noise} + + # Add symmetry features as non-sharded plain tensors/lists (DP-local slice) + if symmetry_correction: + _SYM_KEYS = { + "all_coords", + "all_resolved_mask", + "crop_to_all_atom_map", + "chain_swaps", + "amino_acids_symmetries", + "ligand_symmetries", + } + for sk in _SYM_KEYS: + if sk not in input_feats_host: + continue + val = input_feats_host[sk] + if isinstance(val, torch.Tensor): + feats_dt[sk] = val[dp_rank : dp_rank + 1].to( + device=manager.device, dtype=dtype if val.dtype.is_floating_point else val.dtype + ) + elif isinstance(val, list): + elem = val[dp_rank] + if isinstance(elem, torch.Tensor): + feats_dt[sk] = elem.unsqueeze(0).to( + device=manager.device, dtype=dtype if elem.dtype.is_floating_point else elem.dtype + ) + else: + feats_dt[sk] = [elem] + + # Non-sharded physicalism keys (clash + PB): DP-sliced, not CP-sharded + _PHYS_KEYS = LIGAND_GEOMETRY_FEATURES | {"chain_symmetries"} + for pk in _PHYS_KEYS: + if pk not in input_feats_host: + continue + val = input_feats_host[pk] + if isinstance(val, torch.Tensor): + feats_dt[pk] = val[dp_rank : dp_rank + 1].to( + device=manager.device, dtype=dtype if val.dtype.is_floating_point else val.dtype + ) + elif isinstance(val, list): + elem = val[dp_rank] + if isinstance(elem, torch.Tensor): + feats_dt[pk] = elem.unsqueeze(0).to( + device=manager.device, dtype=dtype if elem.dtype.is_floating_point else elem.dtype + ) + else: + feats_dt[pk] = [elem] + + # ------------------------------------------------------------------ + # Monkeypatch distributed sample() for determinism + # ------------------------------------------------------------------ + _orig_center_random_augmentation = distributed_diffusion_module.center_random_augmentation + + def _centering_only_augmentation(atom_coords, atom_mask, **kwargs): + kwargs["augmentation"] = False + kwargs["centering"] = True + return _orig_center_random_augmentation(atom_coords, atom_mask, **kwargs) + + _dt_randn_calls = [] + _dt_randn_sequence = [init_noise_dt] + step_noise_dts + + def _fixed_create_distributed_randn(shape, device_mesh, placements, dtype=torch.float32, scale=1.0): + idx = len(_dt_randn_calls) + _dt_randn_calls.append(idx) + noise_dt = _dt_randn_sequence[idx] + if scale != 1.0: + noise_dt = scalar_tensor_op(scale, noise_dt, ElementwiseOp.PROD) + return noise_dt + + monkeypatch.setattr(distributed_diffusion_module, "center_random_augmentation", _centering_only_augmentation) + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", _fixed_create_distributed_randn) + + # ------------------------------------------------------------------ + # Phase 1: Run validation_step (validator 0), compare accumulated metrics + # ------------------------------------------------------------------ + feats_dt["idx_dataset"] = [torch.tensor([0], device=manager.device)] + + with torch.no_grad(): + module.validation_step(feats_dt, batch_idx=0) + + validator = module.validator_mapper[0] + fm = validator.folding_metrics + val_idx = 0 + + serial_ref = serial_per_sample[dp_rank] + compared_phase1 = 0 + + disto_loss_metric = fm["disto_loss"][val_idx]["disto_loss"] + if disto_loss_metric.weight > 0: + dist_disto_loss = disto_loss_metric.compute().item() + serial_disto_loss_avg = sum(s["disto_loss"] for s in serial_per_sample) / len(serial_per_sample) + torch.testing.assert_close( + torch.tensor(dist_disto_loss, dtype=dtype), + torch.tensor(serial_disto_loss_avg, dtype=dtype), + msg=lambda msg: f"Rank {rank}: Phase 1 disto_loss mismatch: {msg}", + ) + compared_phase1 += 1 + + for key in [*const.out_types, "pocket_ligand_protein", "contact_protein_protein"]: + if key in fm["disto_lddt"][val_idx]: + metric = fm["disto_lddt"][val_idx][key] + if metric.weight > 0 and key in serial_ref.get("disto_lddt", {}): + dist_val = metric.compute().item() + serial_val = serial_ref["disto_lddt"][key] + torch.testing.assert_close( + torch.tensor(dist_val, dtype=dtype), + torch.tensor(serial_val, dtype=dtype), + msg=lambda msg, k=key: f"Rank {rank}: Phase 1 disto_lddt_{k} mismatch: {msg}", + ) + compared_phase1 += 1 + + for key in [*const.out_types, "pocket_ligand_protein", "contact_protein_protein"]: + if key in fm["lddt"][val_idx]: + metric = fm["lddt"][val_idx][key] + if metric.weight > 0 and key in serial_ref.get("lddt", {}): + dist_val = metric.compute().item() + serial_val = serial_ref["lddt"][key] + torch.testing.assert_close( + torch.tensor(dist_val, dtype=dtype), + torch.tensor(serial_val, dtype=dtype), + msg=lambda msg, k=key: f"Rank {rank}: Phase 1 lddt_{k} mismatch: {msg}", + ) + compared_phase1 += 1 + + for key in [*const.out_types, "pocket_ligand_protein", "contact_protein_protein"]: + if key in fm["complex_lddt"][val_idx]: + metric = fm["complex_lddt"][val_idx][key] + if metric.weight > 0 and key in serial_ref.get("complex_lddt", {}): + dist_val = metric.compute().item() + serial_val = serial_ref["complex_lddt"][key] + torch.testing.assert_close( + torch.tensor(dist_val, dtype=dtype), + torch.tensor(serial_val, dtype=dtype), + msg=lambda msg, k=key: f"Rank {rank}: Phase 1 complex_lddt_{k} mismatch: {msg}", + ) + compared_phase1 += 1 + + assert compared_phase1 >= 3, f"Rank {rank}: Phase 1 compared only {compared_phase1} metrics — test may be vacuous" + + # Run remaining validators for Phase 2 accumulation + for vi in range(1, num_validators): + feats_dt["idx_dataset"] = [torch.tensor([vi], device=manager.device)] + + _dt_randn_calls_vi = [] + _dt_randn_sequence_vi = [init_noise_dt] + step_noise_dts + + def _fixed_randn_vi( + shape, + device_mesh, + placements, + dtype=torch.float32, + scale=1.0, + _seq=_dt_randn_sequence_vi, + _calls=_dt_randn_calls_vi, + ): + idx = len(_calls) + _calls.append(idx) + noise_dt = _seq[idx] + if scale != 1.0: + noise_dt = scalar_tensor_op(scale, noise_dt, ElementwiseOp.PROD) + return noise_dt + + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", _fixed_randn_vi) + + # In Phase 1, only 1 validator is used. For tests with multiple validators, validation_step is called for other validators. + with torch.no_grad(): + module.validation_step(feats_dt, batch_idx=vi) + + # ------------------------------------------------------------------ + # Phase 2: on_validation_epoch_end and compare aggregated metrics + # ------------------------------------------------------------------ + dist_log = _LogCapture(CSVLogger(save_dir=tempfile.mkdtemp(), name=f"dist_val_rank{rank}")) + monkeypatch.setattr(module, "log", dist_log) + + module.on_validation_epoch_end() + + _forward_dependent_prefixes = ("val/lddt", "val/complex_lddt", "val/clash", "val/pb", "val/rmsd") + compared_phase2 = 0 + for key in sorted(serial_epoch_end_metrics): + if key in dist_log.metrics: + got = torch.tensor(dist_log.metrics[key], dtype=dtype) + exp = torch.tensor(serial_epoch_end_metrics[key], dtype=dtype) + if any(key.startswith(p) for p in _forward_dependent_prefixes): + torch.testing.assert_close( + got, + exp, + msg=lambda msg, k=key: f"Rank {rank}: Phase 2 epoch-end metric '{k}' mismatch: {msg}", + ) + else: + torch.testing.assert_close( + got, + exp, + msg=lambda msg, k=key: f"Rank {rank}: Phase 2 epoch-end metric '{k}' mismatch: {msg}", + ) + compared_phase2 += 1 + + assert compared_phase2 >= 3, f"Rank {rank}: Phase 2 compared only {compared_phase2} metrics — test may be vacuous" + + required_base_metrics = ("val/lddt", "val/disto_lddt", "val/complex_lddt") + for base_metric in required_base_metrics: + for vn in val_names: + suffix = "" if vn == "RCSB" else f"__{vn}" + required_metric = f"{base_metric}{suffix}" + assert required_metric in serial_epoch_end_metrics, ( + f"Rank {rank}: serial epoch-end metrics missing '{required_metric}' — " + f"available: {sorted(serial_epoch_end_metrics)}" + ) + assert required_metric in dist_log.metrics, ( + f"Rank {rank}: distributed epoch-end metrics missing '{required_metric}' — " + f"available: {sorted(dist_log.metrics)}" + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + ("setup_env", "use_random_features", "symmetry_correction", "num_validators"), + [ + (((2, (2, 2)), True, "cuda", "ENV"), True, False, 2), + (((2, (2, 2)), True, "cuda", "ENV"), False, True, 1), + ], + indirect=["setup_env"], + ids=["cuda-dp2-cp2x2-random-2val", "cuda-dp2-cp2x2-dataloader-sc-1val"], +) +def test_boltz2_validation_step_parity( + setup_env, + use_random_features, + symmetry_correction, + num_validators, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + tmp_path, +): + """V18: validation_step parity between distributed and serial Boltz2. + + Two-phase comparison: + Phase 1: After validation_step, compare validator MeanMetric values + (disto_loss, disto_lddt_*, lddt_*, complex_lddt_*) between serial + (per-sample) and distributed (per DP-rank sample). + Phase 2: After on_validation_epoch_end (which DP-all-reduces + MeanMetric internals), compare aggregated logged metrics between + serial and distributed. + + Uses DP=2, CP=(2,2) -> 8 GPUs, FP64 with default tolerances. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + if use_random_features and symmetry_correction: + pytest.skip("Symmetry correction not supported with random features") + + if num_validators > 1 and not use_random_features: + pytest.skip("Multiple validators require multiple val datasets; only supported with use_random_features=True") + + dtype = torch.float64 + min_val_init = -0.01 + max_val_init = 0.01 + scale_glorot = 0.05 + + num_sampling_steps = 2 + diffusion_samples = 1 + + seed = 42 + seed_by_rank(0, seed=seed) + + boltz2_model_params = create_boltz2_model_init_params(use_large_model=False) + boltz2_model_params["diffusion_process_args"]["alignment_reverse_diff"] = False + boltz2_model_params["validate_structure"] = True + val_names = [f"RCSB_{i}" for i in range(num_validators)] if num_validators > 1 else ["RCSB"] + boltz2_model_params["validators"] = [ + RCSBValidator(val_names=[vn], confidence_prediction=False, physicalism_metrics=False) for vn in val_names + ] + boltz2_model_params["num_val_datasets"] = num_validators + boltz2_model_params["confidence_prediction"] = False + boltz2_model_params["validation_args"] = _make_validation_args( + recycling_steps=0, + sampling_steps=num_sampling_steps, + diffusion_samples=diffusion_samples, + symmetry_correction=symmetry_correction, + ) + + size_cp = grid_group_sizes["cp"][0] + B = grid_group_sizes["dp"] + + if use_random_features: + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + n_tokens = 30 * size_cp + W = boltz2_model_params["atoms_per_window_queries"] + n_atoms_raw = n_tokens * n_atoms_per_token_max + n_atoms = ((n_atoms_raw + W - 1) // W) * W + n_msa = size_cp * 2 + + assert n_atoms % size_cp == 0 + assert n_atoms % W == 0 + + input_feats_global_fp64 = random_features( + size_batch=B, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=n_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=device_type, + float_value_range=(min_val_init, max_val_init), + selected_keys=_BOLTZ2_SELECTED_KEYS, + num_disto_bins=boltz2_model_params["num_bins"], + ) + input_feats_global_fp64["msa"] = torch.randint( + 0, const.num_tokens, (B, n_msa, n_tokens), dtype=torch.int64, device=device_type + ) + + input_feats_global_fp64["disto_target"] = input_feats_global_fp64["disto_target"].unsqueeze(3) + + token_to_rep_atom = input_feats_global_fp64["token_to_rep_atom"] + coords = input_feats_global_fp64["coords"] + disto_coords_ensemble = torch.bmm(token_to_rep_atom.to(dtype=dtype), coords[:, 0]) + input_feats_global_fp64["disto_coords_ensemble"] = disto_coords_ensemble + + input_feats_global_fp64["connections_edge_index"] = [ + torch.empty(2, 0, dtype=torch.long, device=device_type) for _ in range(B) + ] + input_feats_global_fp64["chain_symmetries"] = [[] for _ in range(B)] + else: + boltz2_model_params["num_bins"] = 64 + training_data_dir = _setup_training_data_7z64_8b2e( + tmp_path / "training_data", test_cp_training_base_data_dir_boltz2 + ) + cfg = setup_mock_training_datamodule_config(training_data_dir) + cfg.overfit = B + cfg.samples_per_epoch = B + cfg.moldir = str(canonical_mols_dir) + cfg.return_train_symmetries = symmetry_correction + cfg.pad_to_max_tokens = True + cfg.pad_to_max_atoms = True + cfg.pad_to_max_seqs = True + W = boltz2_model_params["atoms_per_window_queries"] + token_align = size_cp + atom_align = math.lcm(W, size_cp) + cfg.max_tokens = ((cfg.max_tokens + token_align - 1) // token_align) * token_align + cfg.max_atoms = ((cfg.max_atoms + atom_align - 1) // atom_align) * atom_align + cfg.max_seqs = ((cfg.max_seqs + size_cp - 1) // size_cp) * size_cp + for ds_cfg in cfg.datasets: + ds_cfg.filters = None + seed_by_rank(0, seed=seed) + dm = Boltz2TrainingDataModuleSerial(cfg=cfg) + dl = dm.val_dataloader() + dl_iter = iter(dl) + raw_samples = [next(dl_iter) for _ in range(B)] + + def _unwrap_bs1(v): + if isinstance(v, torch.Tensor): + return v.squeeze(0) + if isinstance(v, list) and len(v) == 1: + return v[0] + return v + + raw_batch = collate([{k: _unwrap_bs1(v) for k, v in s.items()} for s in raw_samples]) + + input_feats_global_fp64 = {} + for k, v in raw_batch.items(): + if isinstance(v, torch.Tensor): + input_feats_global_fp64[k] = v.to( + device=device_type, dtype=dtype if v.dtype.is_floating_point else v.dtype + ) + elif isinstance(v, list): + input_feats_global_fp64[k] = [ + item.to(device=device_type, dtype=dtype if item.dtype.is_floating_point else item.dtype) + if isinstance(item, torch.Tensor) + else item + for item in v + ] + else: + input_feats_global_fp64[k] = v + n_atoms = input_feats_global_fp64["atom_pad_mask"].shape[-1] + + if "token_pair_pad_mask" not in input_feats_global_fp64: + tpm = input_feats_global_fp64["token_pad_mask"] + input_feats_global_fp64["token_pair_pad_mask"] = tpm[:, :, None] * tpm[:, None, :] + + # ------------------------------------------------------------------ + # Slice global batch into individual samples for serial validation + # ------------------------------------------------------------------ + def _slice_batch(feats, idx): + batch_i = {} + for k, v in feats.items(): + if isinstance(v, torch.Tensor): + batch_i[k] = v[idx : idx + 1].clone() + elif isinstance(v, list): + elem = v[idx] + if isinstance(elem, torch.Tensor): + batch_i[k] = elem.unsqueeze(0).clone() + else: + batch_i[k] = [elem] + else: + batch_i[k] = v + batch_i["idx_dataset"] = torch.tensor([0], device=device_type) + return batch_i + + num_val_samples = B + + # ------------------------------------------------------------------ + # Build serial model + # ------------------------------------------------------------------ + reference_module = SerialBoltz2(**boltz2_model_params) + init_module_params_glorot(reference_module, gain=scale_glorot) + reference_module.apply(SetModuleInfValues()) + reference_module.structure_module.coordinate_augmentation = False + module_state_dict = reference_module.state_dict() + reference_module = reference_module.to(dtype=dtype, device=device_type).eval() + + serial_validators = [] + reference_module.val_group_mapper = {} + reference_module.validator_mapper = {} + for vi in range(num_validators): + vn = val_names[vi] + v = RCSBValidator(val_names=[vn], confidence_prediction=False, physicalism_metrics=True) + v = v.to(device=device_type, dtype=dtype) + serial_validators.append(v) + reference_module.val_group_mapper[vi] = {"label": vn, "symmetry_correction": symmetry_correction} + reference_module.validator_mapper[vi] = v + + # ------------------------------------------------------------------ + # Pre-generate deterministic noise for sampling + # ------------------------------------------------------------------ + _B_M = B * diffusion_samples + init_noise = torch.empty((_B_M, n_atoms, 3), device=device_type, dtype=dtype) + step_noise_list = [ + torch.empty((_B_M, n_atoms, 3), device=device_type, dtype=dtype) for _ in range(num_sampling_steps) + ] + init_tensors_uniform([init_noise, *step_noise_list], low=min_val_init, high=max_val_init) + all_noise = [init_noise] + step_noise_list + + # ------------------------------------------------------------------ + # Phase 1 serial: per-sample metrics + # ------------------------------------------------------------------ + serial_per_sample = [{} for _ in range(num_val_samples)] + + _original_torch_randn = torch.randn + + def _run_serial_validation_step_batch(batch_i, sample_idx): + _serial_randn_calls = [] + noise_for_sample = [n[sample_idx : sample_idx + 1].clone() for n in all_noise] + _serial_randn_sequence = noise_for_sample + + def _fixed_randn(*args, _seq=_serial_randn_sequence, _calls=_serial_randn_calls, **kwargs): + idx = len(_calls) + _calls.append(idx) + if idx < len(_seq): + return _seq[idx].clone() + return _original_torch_randn(*args, **kwargs) + + _serial_mp = pytest.MonkeyPatch() + _serial_mp.setattr( + serial_diffusion_v2_module, + "compute_random_augmentation", + lambda mult, device=None, dtype=None: ( + torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(diffusion_samples, -1, -1), + torch.zeros(diffusion_samples, 1, 3, device=device, dtype=dtype), + ), + ) + _serial_mp.setattr(serial_diffusion_v2_module.torch, "randn", _fixed_randn) + _serial_mp.setattr(reference_module, "log", lambda *a, **kw: None) + + with torch.no_grad(): + reference_module.validation_step(batch_i, batch_idx=sample_idx) + + _serial_mp.undo() + + def _extract_validator_metrics(validator): + fm = validator.folding_metrics + val_idx = 0 + sample_metrics = {} + disto_loss_metric = fm["disto_loss"][val_idx]["disto_loss"] + sample_metrics["disto_loss"] = disto_loss_metric.compute().item() + sample_metrics["disto_lddt"] = {} + sample_metrics["lddt"] = {} + sample_metrics["complex_lddt"] = {} + for m_ in [*const.out_types, "pocket_ligand_protein", "contact_protein_protein"]: + if m_ in fm["disto_lddt"][val_idx]: + val = fm["disto_lddt"][val_idx][m_].compute() + if not torch.isnan(val): + sample_metrics["disto_lddt"][m_] = val.item() + if m_ in fm["lddt"][val_idx]: + val = fm["lddt"][val_idx][m_].compute() + if not torch.isnan(val): + sample_metrics["lddt"][m_] = val.item() + if m_ in fm["complex_lddt"][val_idx]: + val = fm["complex_lddt"][val_idx][m_].compute() + if not torch.isnan(val): + sample_metrics["complex_lddt"][m_] = val.item() + return sample_metrics + + def _reset_validator_metrics(validator): + fm = validator.folding_metrics + val_idx = 0 + for metric_group in ["lddt", "disto_lddt", "complex_lddt", "disto_loss"]: + for k, metric_obj in fm[metric_group][val_idx].items(): + metric_obj.reset() + # Reset physicalism metrics so the next sample does not accumulate on top + if getattr(validator, "physicalism_metrics", None) and hasattr(validator.physicalism_metrics, "keys"): + for group in ["clash", "pb"]: + if group in validator.physicalism_metrics: + for metric_obj in validator.physicalism_metrics[group][val_idx].values(): + metric_obj.reset() + + for sample_idx in range(B): + batch_i = _slice_batch(input_feats_global_fp64, sample_idx) + batch_i["idx_dataset"] = torch.tensor([0], device=device_type) + _run_serial_validation_step_batch(batch_i, sample_idx) + + serial_per_sample[sample_idx] = _extract_validator_metrics(serial_validators[0]) + _reset_validator_metrics(serial_validators[0]) + + # ------------------------------------------------------------------ + # Phase 2 serial: epoch-end metrics (accumulate both samples) + # ------------------------------------------------------------------ + for vi in range(num_validators): + for sample_idx in range(B): + batch_i = _slice_batch(input_feats_global_fp64, sample_idx) + batch_i["idx_dataset"] = torch.tensor([vi], device=device_type) + _run_serial_validation_step_batch(batch_i, sample_idx) + + serial_log = _LogCapture(CSVLogger(save_dir=tempfile.mkdtemp(), name="serial_val")) + _serial_mp2 = pytest.MonkeyPatch() + _serial_mp2.setattr(reference_module, "log", serial_log) + + reference_module.on_validation_epoch_end() + _serial_mp2.undo() + + serial_epoch_end_metrics = dict(serial_log.metrics) + + assert len(serial_per_sample[0]) > 0, "Serial phase 1 produced no metrics for sample 0" + assert len(serial_per_sample[1]) > 0, "Serial phase 1 produced no metrics for sample 1" + assert len(serial_epoch_end_metrics) > 0, "Serial phase 2 produced no epoch-end metrics" + + # ------------------------------------------------------------------ + # Move to CPU for spawn_multiprocessing + # ------------------------------------------------------------------ + input_feats_host = {} + for k, v in input_feats_global_fp64.items(): + if isinstance(v, torch.Tensor): + input_feats_host[k] = v.detach().to(device="cpu", copy=True) + elif isinstance(v, list): + input_feats_host[k] = [ + item.detach().to(device="cpu", copy=True) if isinstance(item, torch.Tensor) else item for item in v + ] + else: + input_feats_host[k] = v + noise_host_list = [n.detach().cpu() for n in all_noise] + serial_per_sample_cpu = list(serial_per_sample) + + boltz2_model_params["validators"] = [ + DistributedRCSBValidator(val_names=[vn], confidence_prediction=False, physicalism_metrics=True) + for vn in val_names + ] + spawn_multiprocessing( + _worker_validation_step_parity, + world_size, + grid_group_sizes, + device_type, + backend, + boltz2_model_params, + module_state_dict, + input_feats_host, + noise_host_list, + serial_per_sample_cpu, + serial_epoch_end_metrics, + env_per_rank, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/model/modules/__init__.py b/tests/distributed/model/modules/__init__.py new file mode 100644 index 000000000..b14afb317 --- /dev/null +++ b/tests/distributed/model/modules/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for distributed model modules.""" diff --git a/tests/distributed/model/modules/test_dtensor_adaln.py b/tests/distributed/model/modules/test_dtensor_adaln.py new file mode 100644 index 000000000..2ce22a9c1 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_adaln.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor AdaLN module. + +Tests both Boltz-1x and Boltz-2 serial AdaLN modules against the unified +DTensor AdaLN implementation, verifying forward and backward equivalence. + +""" + +import pytest +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.transformers import AdaLN as DTensorAdaLN +from boltz.model.modules.transformers import AdaLN as AdaLNSerialBoltz1 +from boltz.model.modules.transformersv2 import AdaLN as AdaLNSerialBoltz2 +from boltz.testing.utils import ( + assert_tensors_identical, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_adaln( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_module_version: str, + dim: int, + dim_single_cond: int, + B: int, + N: int, + layer_state_dict, + a_global_host: torch.Tensor, + s_global_host: torch.Tensor, + d_out_global_host: torch.Tensor, + out_expected_global_host: torch.Tensor, + d_a_expected_global_host: torch.Tensor, + d_s_expected_global_host: torch.Tensor, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Create serial module from state dict + AdaLNSerial = AdaLNSerialBoltz1 if serial_module_version == "boltz1" else AdaLNSerialBoltz2 + module_serial = AdaLNSerial(dim=dim, dim_single_cond=dim_single_cond) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device).train() + + # Create DTensor module from serial + module_dt = DTensorAdaLN( + ada_layer_norm=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Placements: shard batch over dp (dim 0), shard tokens over cp axis-0 (dim 1), + # replicate over cp axis-1 + placements = (Shard(0), Shard(1), Replicate()) if manager.device_mesh_subgroups.ndim == 3 else (Shard(0), Shard(1)) + + a_dt = distribute_tensor( + a_global_host.to(device=manager.device), manager.device_mesh_subgroups, placements + ).requires_grad_(True) + s_dt = distribute_tensor( + s_global_host.to(device=manager.device), manager.device_mesh_subgroups, placements + ).requires_grad_(True) + + # Copies to verify inputs aren't modified + a_dt_copy = a_dt.detach().clone().requires_grad_(True) + s_dt_copy = s_dt.detach().clone().requires_grad_(True) + + # Forward pass + out_dt: DTensor = module_dt(a_dt, s_dt) + + # Ensure no input mutation + assert_tensors_identical(a_dt_copy.to_local(), a_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(s_dt_copy.to_local(), s_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Forward compare (full gather) + torch.testing.assert_close(out_dt.full_tensor().cpu(), out_expected_global_host) + + # Backward pass + d_out_dt = distribute_tensor(d_out_global_host.to(device=manager.device), manager.device_mesh_subgroups, placements) + out_dt.backward(d_out_dt) + + # Compare input gradients + torch.testing.assert_close(a_dt.grad.full_tensor().cpu(), d_a_expected_global_host) + torch.testing.assert_close(s_dt.grad.full_tensor().cpu(), d_s_expected_global_host) + + # Compare parameter gradients + for name, param in module_dt.named_parameters(): + assert param.grad is not None, f"Parameter {name} has no gradient" + expected_grad = expected_param_grads_global_host_dict[name] + torch.testing.assert_close( + param.grad.full_tensor().cpu(), + expected_grad, + msg=lambda m: f"Parameter gradient mismatch for {name}: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize("serial_module_version", ["boltz1", "boltz2"]) +def test_adaln(setup_env, serial_module_version: str): + """Test AdaLN DTensor vs serial equivalence for both Boltz-1x and Boltz-2.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + # Module dimensions + dim = 64 + dim_single_cond = 32 + + # Data dimensions - must be divisible by dp * cp + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 8 # tokens + + seed_by_rank(0, seed=42) + + # Create serial module + AdaLNSerial = AdaLNSerialBoltz1 if serial_module_version == "boltz1" else AdaLNSerialBoltz2 + module_serial = AdaLNSerial(dim=dim, dim_single_cond=dim_single_cond) + module_serial = module_serial.train() + layer_state_dict = module_serial.state_dict() + + # Create input tensors + a_global = torch.randn(B, N, dim, requires_grad=True) + s_global = torch.randn(B, N, dim_single_cond, requires_grad=True) + + # Serial forward pass + out_serial = module_serial(a_global, s_global) + + # Create upstream gradient + d_out = torch.randn_like(out_serial) + + # Serial backward pass + out_serial.backward(d_out) + + # Collect expected results + out_expected = out_serial.detach().clone().cpu() + d_a_expected = a_global.grad.detach().clone().cpu() + d_s_expected = s_global.grad.detach().clone().cpu() + + expected_param_grads = {} + for name, param in module_serial.named_parameters(): + assert param.grad is not None, f"Serial parameter {name} has no gradient" + expected_param_grads[name] = param.grad.detach().clone().cpu() + + # Launch parallel test + spawn_multiprocessing( + parallel_assert_adaln, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_module_version, + dim, + dim_single_cond, + B, + N, + layer_state_dict, + a_global.detach().clone().cpu(), + s_global.detach().clone().cpu(), + d_out.detach().clone().cpu(), + out_expected, + d_a_expected, + d_s_expected, + expected_param_grads, + ) diff --git a/tests/distributed/model/modules/test_dtensor_atom_attn_decoder_wb.py b/tests/distributed/model/modules/test_dtensor_atom_attn_decoder_wb.py new file mode 100644 index 000000000..62203c919 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_atom_attn_decoder_wb.py @@ -0,0 +1,736 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor AtomAttentionDecoder with window batching.""" + +from functools import partial + +import pytest +import torch +from torch.distributed.tensor import distribute_tensor + +from boltz.distributed.data.feature.featurizer import pack_atom_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_unflatten_sharded +from boltz.distributed.model.layers.utils import convert_single_repr_to_window_batched_key +from boltz.distributed.model.modules.encoders import ( + AtomAttentionDecoder as DistributedAtomAttentionDecoder, +) +from boltz.model.modules.encoders import AtomAttentionDecoder as SerialAtomAttentionDecoderBoltz1 +from boltz.model.modules.encoders import get_indexing_matrix as get_indexing_matrix_v1 +from boltz.model.modules.encoders import single_to_keys as single_to_keys_v1 +from boltz.model.modules.encodersv2 import AtomAttentionDecoder as SerialAtomAttentionDecoderBoltz2 +from boltz.model.modules.encodersv2 import get_indexing_matrix as get_indexing_matrix_v2 +from boltz.model.modules.encodersv2 import single_to_keys as single_to_keys_v2 +from boltz.testing.utils import ( + SetModuleInfValues, + assert_all_identical, + assert_tensors_close_with_pad, + distribute_atom_features, + get_feature_placements, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + pad_or_shrink_to_length, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +# Subset of keys needed for AtomAttentionDecoder window batching test +_selected_atom_keys = { + "atom_pad_mask", + "atom_to_token", + "atom_counts_per_token", # Required by pad_and_scatter_atom_features_dtensor +} + +_placements = get_feature_placements( + token_keys=set(), + msa_keys=set(), + atom_keys=_selected_atom_keys, + model_io_keys=set(), + model_io_fp32_keys=set(), +) +_placements_single = _placements["single"] +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_atom_features = _placements["atom_features"] + + +def parallel_assert_atom_attention_decoder_wb( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_module_version: str, + dtype: torch.dtype, + multiplicity: int, + atom_s: int, + atom_z: int, + token_s: int, + W: int, + H: int, + atom_decoder_depth: int, + atom_decoder_heads: int, + layer_state_dict, + feats_global_host: dict[str, torch.Tensor], + a_global_host: torch.Tensor, + q_global_host: torch.Tensor, + c_global_host: torch.Tensor, + p_global_host: torch.Tensor, + d_r_update_global_host: torch.Tensor, + r_update_expected_global_host: torch.Tensor, + d_a_expected_global_host: torch.Tensor, + d_q_expected_global_host: torch.Tensor, + d_c_expected_global_host: torch.Tensor, + d_p_expected_global_host: torch.Tensor, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + """Parallel worker function for testing DTensor AtomAttentionDecoder with window batching.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Recreate serial module from state dict + if serial_module_version == "boltz1": + module_serial = SerialAtomAttentionDecoderBoltz1( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + attn_window_queries=W, + attn_window_keys=H, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + ) + else: + module_serial = SerialAtomAttentionDecoderBoltz2( + atom_s=atom_s, + token_s=token_s, + attn_window_queries=W, + attn_window_keys=H, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + ) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device, dtype=dtype).train() + module_serial.apply(SetModuleInfValues()) + + # Create distributed module + module = DistributedAtomAttentionDecoder( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Get global masks + atom_pad_mask_global = feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=torch.bool) + atom_pad_mask_expanded_global = atom_pad_mask_global.unsqueeze(-1) + atom_pad_mask_expanded_global_mul = atom_pad_mask_expanded_global.repeat_interleave(multiplicity, dim=0) + + # Distribute atom features + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + feats_dt = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features, + _placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + + # Pack atom features + feats_dt_packed = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + N_atoms_packed = feats_dt_packed["atom_pad_mask"].shape[1] + K_packed = N_atoms_packed // W + + # Distribute input tensors + a_dt = distribute_tensor( + a_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_single, + ).requires_grad_(True) + + q_adjusted = pad_or_shrink_to_length( + q_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + q_dt = distribute_tensor(q_adjusted, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + + c_adjusted = pad_or_shrink_to_length( + c_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + c_dt = distribute_tensor(c_adjusted, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + + p_adjusted = pad_or_shrink_to_length( + p_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=K_packed + ) + p_dt = distribute_tensor(p_adjusted, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + + # Forward pass + r_update_dt = module(a=a_dt, q=q_dt, c=c_dt, p=p_dt, feats=feats_dt_packed, multiplicity=multiplicity) + + # Forward comparison + r_update_expected_device = r_update_expected_global_host.to(device=manager.device, dtype=dtype) + r_update_dt_full = r_update_dt.full_tensor() + + mask_dt_full = feats_dt_packed["atom_pad_mask"].full_tensor() + mask_dt_full_mul = mask_dt_full.repeat_interleave(multiplicity, dim=0) + mask_dt_full_mul_expanded = mask_dt_full_mul.unsqueeze(-1) + + assert_tensors_close_with_pad( + r_update_dt_full * mask_dt_full_mul_expanded, + r_update_expected_device * atom_pad_mask_expanded_global_mul, + axis=1, + pad_val=0, + ) + + # Backward pass + d_r_update_padded = pad_or_shrink_to_length( + d_r_update_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=r_update_dt.shape[1] + ) + d_r_update_dtensor = distribute_tensor(d_r_update_padded, manager.device_mesh_subgroups, r_update_dt.placements) + torch.autograd.backward([r_update_dt], [d_r_update_dtensor]) + + # Check a gradient + d_a_expected_device = d_a_expected_global_host.to(device=manager.device, dtype=dtype) + d_a_dt_full = a_dt.grad.full_tensor() + torch.testing.assert_close(d_a_dt_full, d_a_expected_device) + + # Check q gradient + d_q_expected_device = d_q_expected_global_host.to(device=manager.device, dtype=dtype) + q_grad_full = q_dt.grad.full_tensor() + assert_tensors_close_with_pad( + q_grad_full * mask_dt_full_mul_expanded, + d_q_expected_device * atom_pad_mask_expanded_global_mul, + axis=1, + pad_val=0, + ) + + # Check c gradient + d_c_expected_device = d_c_expected_global_host.to(device=manager.device, dtype=dtype) + c_grad_full = c_dt.grad.full_tensor() + assert_tensors_close_with_pad( + c_grad_full * mask_dt_full_mul_expanded, + d_c_expected_device * atom_pad_mask_expanded_global_mul, + axis=1, + pad_val=0, + ) + + # Check p gradient + d_p_expected_device = d_p_expected_global_host.to(device=manager.device, dtype=dtype) + p_grad_full = p_dt.grad.full_tensor() + + mask_dt_query = shardwise_unflatten_sharded( + feats_dt_packed["atom_pad_mask"], axis=1, sizes=(feats_dt_packed["atom_pad_mask"].shape[1] // W, W) + ) + mask_dt_query_full = mask_dt_query.full_tensor() + mask_dt_query_full_expanded = mask_dt_query_full[:, :, :, None, None] + mask_dt_key = convert_single_repr_to_window_batched_key(feats_dt_packed["atom_pad_mask"], W, H) + mask_dt_key_full = mask_dt_key.full_tensor() + mask_dt_key_full_expanded = mask_dt_key_full[:, :, None, :, None] + mask_dt_pair_full_expanded = mask_dt_query_full_expanded * mask_dt_key_full_expanded + + N_atoms_serial = feats_global_host["atom_pad_mask"].shape[1] + K_serial = N_atoms_serial // W + if serial_module_version == "boltz1": + index_matrix = get_indexing_matrix_v1(K_serial, W, H, manager.device).to(dtype=dtype) + to_keys_fn = partial(single_to_keys_v1, indexing_matrix=index_matrix, W=W, H=H) + else: + index_matrix = get_indexing_matrix_v2(K_serial, W, H, manager.device).to(dtype=dtype) + to_keys_fn = partial(single_to_keys_v2, indexing_matrix=index_matrix, W=W, H=H) + + mask_key_expected = to_keys_fn( + feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=dtype).unsqueeze(-1) + ) + mask_key_expected_expanded = mask_key_expected[:, :, None, :, :] + mask_query_expected_expanded = atom_pad_mask_expanded_global.unflatten( + 1, (atom_pad_mask_expanded_global.shape[1] // W, W) + )[:, :, :, None, :] + mask_pair_expected_expanded = mask_query_expected_expanded * mask_key_expected_expanded + + assert_tensors_close_with_pad( + p_grad_full * mask_dt_pair_full_expanded, + d_p_expected_device * mask_pair_expected_expanded, + axis=1, + pad_val=0, + ) + + # Parameter grads comparison + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + + if hasattr(grad_param, "full_tensor"): + grad_global_host = grad_param.full_tensor().cpu() + grad_to_check = grad_param.full_tensor() + else: + grad_global_host = grad_param.detach().cpu() + grad_to_check = grad_param + + torch.testing.assert_close(grad_global_host, grad_expected_global.to(dtype=dtype)) + assert_all_identical(grad_to_check, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, multiplicity", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32, 1), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, 4), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, device_type:{x[0][2]}, dtype:{x[1]}, multiplicity:{x[2]}" + for x in params_test + ], +) +@pytest.mark.parametrize("serial_module_version", ["boltz1", "boltz2"]) +def test_atom_attention_decoder_window_batching(setup_env, dtype, multiplicity, serial_module_version): + """Test DTensor AtomAttentionDecoder with window batching for V1 and V2.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + W = 32 + H = 128 + val_init_min_max = (-0.1, 0.1) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 100 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + atom_z = 8 + token_s = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + + # For Boltz-2, p last dim = num_heads * depth (pre-computed bias) + # For Boltz-1, p last dim = atom_z (pair representation) + p_last_dim = atom_z if serial_module_version == "boltz1" else atom_decoder_heads * atom_decoder_depth + + selected_keys = list(_selected_atom_keys) + + assert N_tokens % size_cp == 0 + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + N_atoms_actual = feats["atom_pad_mask"].shape[1] + assert N_atoms_actual % W == 0 + K = N_atoms_actual // W + + # Generate input tensors + a = torch.empty((B * multiplicity, N_tokens, token_s * 2), device=device_type, dtype=dtype, requires_grad=True) + q = torch.empty((B * multiplicity, N_atoms_actual, atom_s), device=device_type, dtype=dtype, requires_grad=True) + c = torch.empty((B * multiplicity, N_atoms_actual, atom_s), device=device_type, dtype=dtype, requires_grad=True) + p = torch.empty((B, K, W, H, p_last_dim), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([a, q, c, p], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Build serial module + if serial_module_version == "boltz1": + get_indexing_matrix = get_indexing_matrix_v1 + single_to_keys = single_to_keys_v1 + reference_module = SerialAtomAttentionDecoderBoltz1( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + attn_window_queries=W, + attn_window_keys=H, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + ).to(device=device_type, dtype=dtype) + else: + get_indexing_matrix = get_indexing_matrix_v2 + single_to_keys = single_to_keys_v2 + reference_module = SerialAtomAttentionDecoderBoltz2( + atom_s=atom_s, + token_s=token_s, + attn_window_queries=W, + attn_window_keys=H, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + ).to(device=device_type, dtype=dtype) + + reference_module.train() + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = reference_module.state_dict() + + # Serial forward + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + a_serial = a.detach().clone().requires_grad_(True) + q_serial = q.detach().clone().requires_grad_(True) + c_serial = c.detach().clone().requires_grad_(True) + p_serial = p.detach().clone().requires_grad_(True) + + index_matrix = get_indexing_matrix(K, W, H, device_type).to(dtype=dtype) + to_keys = partial(single_to_keys, indexing_matrix=index_matrix, W=W, H=H) + + if serial_module_version == "boltz1": + r_update_expected = reference_module( + a=a_serial, + q=q_serial, + c=c_serial, + p=p_serial, + feats=feats_serial, + to_keys=to_keys, + multiplicity=multiplicity, + model_cache=None, + ) + else: + r_update_expected = reference_module( + a=a_serial, + q=q_serial, + c=c_serial, + atom_dec_bias=p_serial, + feats=feats_serial, + to_keys=to_keys, + multiplicity=multiplicity, + ) + + # Upstream gradient + d_r_update = torch.empty_like(r_update_expected) + init_tensors_uniform([d_r_update], low=val_init_min_max[0], high=val_init_min_max[1]) + d_r_update = d_r_update * feats_serial["atom_pad_mask"].unsqueeze(-1).repeat_interleave(multiplicity, dim=0) + + torch.autograd.backward([r_update_expected], [d_r_update]) + + # Collect expected outputs + spawn_multiprocessing( + parallel_assert_atom_attention_decoder_wb, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_module_version, + dtype, + multiplicity, + atom_s, + atom_z, + token_s, + W, + H, + atom_decoder_depth, + atom_decoder_heads, + layer_state_dict, + {k: v.detach().cpu() for k, v in feats.items()}, + a.detach().cpu(), + q.detach().cpu(), + c.detach().cpu(), + p.detach().cpu(), + d_r_update.detach().cpu(), + r_update_expected.detach().cpu(), + a_serial.grad.detach().cpu(), + q_serial.grad.detach().cpu(), + c_serial.grad.detach().cpu(), + p_serial.grad.detach().cpu(), + { + name: param.grad.detach().cpu() + for name, param in reference_module.named_parameters() + if param.grad is not None + }, + ) + + +# ====================================================================== +# Test 2: AtomAttentionDecoder under autocast bf16 (dtype-only comparison) +# ====================================================================== + + +def parallel_assert_atom_attention_decoder_wb_autocast_bf16( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + atom_s: int, + token_s: int, + W: int, + H: int, + atom_decoder_depth: int, + atom_decoder_heads: int, + layer_state_dict, + feats_global_host: dict[str, torch.Tensor], + a_global_host: torch.Tensor, + q_global_host: torch.Tensor, + c_global_host: torch.Tensor, + p_global_host: torch.Tensor, + serial_output_dtype: torch.dtype, + serial_grad_dtypes: dict[str, torch.dtype], + serial_param_grad_dtypes: dict[str, torch.dtype], +): + """Parallel worker for bf16 autocast dtype test on AtomAttentionDecoder.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + dtype = torch.float32 + multiplicity = 1 + + module_serial = SerialAtomAttentionDecoderBoltz2( + atom_s=atom_s, + token_s=token_s, + attn_window_queries=W, + attn_window_keys=H, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + ).to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + + module = DistributedAtomAttentionDecoder( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Distribute atom features + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + feats_dt = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features, + _placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + feats_dt_packed = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + N_atoms_packed = feats_dt_packed["atom_pad_mask"].shape[1] + K_packed = N_atoms_packed // W + + # Distribute inputs + a_dt = distribute_tensor( + a_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_single, + ).requires_grad_(True) + q_padded = pad_or_shrink_to_length( + q_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + q_dt = distribute_tensor(q_padded, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + c_padded = pad_or_shrink_to_length( + c_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + c_dt = distribute_tensor(c_padded, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + p_padded = pad_or_shrink_to_length( + p_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=K_packed + ) + p_dt = distribute_tensor(p_padded, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + + # Forward under autocast + with torch.autocast("cuda", dtype=torch.bfloat16): + r_update_dt = module(a=a_dt, q=q_dt, c=c_dt, p=p_dt, feats=feats_dt_packed, multiplicity=multiplicity) + + torch.autograd.backward([r_update_dt], [torch.ones_like(r_update_dt)]) + + # Assert output dtype + assert ( + r_update_dt.dtype == serial_output_dtype + ), f"r_update dtype mismatch: DTensor {r_update_dt.dtype} vs serial {serial_output_dtype}" + + # Assert input grad dtypes + for name, dt_tensor in [("a", a_dt), ("q", q_dt), ("c", c_dt)]: + assert dt_tensor.grad is not None, f"{name} grad is None" + assert ( + dt_tensor.grad.dtype == serial_grad_dtypes[name] + ), f"{name} grad dtype mismatch: DTensor {dt_tensor.grad.dtype} vs serial {serial_grad_dtypes[name]}" + + # Assert param grad dtypes + for name, param in module.named_parameters(): + if name in serial_param_grad_dtypes and param.grad is not None: + grad_dtype = param.grad.full_tensor().dtype if hasattr(param.grad, "full_tensor") else param.grad.dtype + assert ( + grad_dtype == serial_param_grad_dtypes[name] + ), f"param '{name}' grad dtype mismatch: DTensor {grad_dtype} vs serial {serial_param_grad_dtypes[name]}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), True, "cuda", "ENV")], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +def test_atom_attention_decoder_wb_autocast_bf16(setup_env): + """Test DTensor AtomAttentionDecoder output dtypes under autocast bf16.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + B = 1 + W = 32 + H = 128 + val_init_min_max = (-0.1, 0.1) + dtype = torch.float32 + multiplicity = 1 + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + token_s = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + p_last_dim = atom_decoder_heads * atom_decoder_depth + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=list(_selected_atom_keys), + ) + feats = {k: v.to(dtype=dtype) if v.dtype.is_floating_point else v for k, v in feats.items()} + N_atoms_actual = feats["atom_pad_mask"].shape[1] + K = N_atoms_actual // W + + a = torch.empty((B * multiplicity, N_tokens, token_s * 2), device=device_type, dtype=dtype, requires_grad=True) + q = torch.empty((B * multiplicity, N_atoms_actual, atom_s), device=device_type, dtype=dtype, requires_grad=True) + c = torch.empty((B * multiplicity, N_atoms_actual, atom_s), device=device_type, dtype=dtype, requires_grad=True) + p = torch.empty((B, K, W, H, p_last_dim), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([a, q, c, p], low=val_init_min_max[0], high=val_init_min_max[1]) + + reference_module = SerialAtomAttentionDecoderBoltz2( + atom_s=atom_s, + token_s=token_s, + attn_window_queries=W, + attn_window_keys=H, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + ).to(device=device_type, dtype=dtype) + reference_module.train() + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = reference_module.state_dict() + + # Serial forward under autocast + a_serial = a.detach().clone().requires_grad_(True) + q_serial = q.detach().clone().requires_grad_(True) + c_serial = c.detach().clone().requires_grad_(True) + p_serial = p.detach().clone().requires_grad_(True) + + index_matrix = get_indexing_matrix_v2(K, W, H, device_type).to(dtype=dtype) + to_keys = partial(single_to_keys_v2, indexing_matrix=index_matrix, W=W, H=H) + + with torch.autocast("cuda", dtype=torch.bfloat16): + r_update_serial = reference_module( + a=a_serial, + q=q_serial, + c=c_serial, + atom_dec_bias=p_serial, + feats={k: v.clone() for k, v in feats.items()}, + to_keys=to_keys, + multiplicity=multiplicity, + ) + + torch.autograd.backward([r_update_serial], [torch.ones_like(r_update_serial)]) + + serial_output_dtype = r_update_serial.dtype + serial_grad_dtypes = {"a": a_serial.grad.dtype, "q": q_serial.grad.dtype, "c": c_serial.grad.dtype} + serial_param_grad_dtypes = { + name: param.grad.dtype for name, param in reference_module.named_parameters() if param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_atom_attention_decoder_wb_autocast_bf16, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + atom_s, + token_s, + W, + H, + atom_decoder_depth, + atom_decoder_heads, + {k: v.detach().cpu() for k, v in layer_state_dict.items()}, + {k: v.detach().cpu() for k, v in feats.items()}, + a.detach().cpu(), + q.detach().cpu(), + c.detach().cpu(), + p.detach().cpu(), + serial_output_dtype, + serial_grad_dtypes, + serial_param_grad_dtypes, + ) diff --git a/tests/distributed/model/modules/test_dtensor_atom_attn_encoder_wb.py b/tests/distributed/model/modules/test_dtensor_atom_attn_encoder_wb.py new file mode 100644 index 000000000..7625cf52e --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_atom_attn_encoder_wb.py @@ -0,0 +1,1040 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor AtomAttentionEncoder with window batching. + +Tests the DTensor AtomAttentionEncoder against V1 and V2 serial references, +verifying forward and backward numerical equivalence. + +Uses float64 to enable exact (default tolerance) comparison between +serial and distributed computation paths. +""" + +from functools import partial + +import pytest +import torch +from torch.distributed.tensor import distribute_tensor + +from boltz.distributed.data.feature.featurizer import pack_atom_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_unflatten_sharded +from boltz.distributed.model.layers.utils import convert_single_repr_to_window_batched_key +from boltz.distributed.model.modules.encoders import ( + AtomAttentionEncoder as DistributedAtomAttentionEncoder, +) +from boltz.model.modules.encoders import AtomAttentionEncoder as SerialAtomAttentionEncoderBoltz1 +from boltz.model.modules.encoders import get_indexing_matrix as get_indexing_matrix_v1 +from boltz.model.modules.encoders import single_to_keys as single_to_keys_v1 +from boltz.model.modules.encodersv2 import AtomAttentionEncoder as SerialAtomAttentionEncoderBoltz2 +from boltz.model.modules.encodersv2 import AtomEncoder as SerialAtomEncoderV2 +from boltz.model.modules.encodersv2 import get_indexing_matrix as get_indexing_matrix_v2 +from boltz.model.modules.encodersv2 import single_to_keys as single_to_keys_v2 +from boltz.testing.utils import ( + SetModuleInfValues, + assert_all_identical, + assert_tensors_close_with_pad, + distribute_atom_features, + get_feature_placements, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + pad_or_shrink_to_length, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +# Subset of keys needed for AtomAttentionEncoder test +_selected_atom_keys = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + "atom_counts_per_token", + "token_pad_mask", +} + +_placements = get_feature_placements( + token_keys=set(), + msa_keys=set(), + atom_keys=_selected_atom_keys, + model_io_keys=set(), + model_io_fp32_keys=set(), +) +_placements_single = _placements["single"] +_placements_pair = _placements["pair"] +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_atom_features = _placements["atom_features"] + + +def parallel_assert_atom_attention_encoder_wb( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_module_version: str, + dtype: torch.dtype, + multiplicity: int, + atom_s: int, + atom_z: int, + token_s: int, + token_z: int, + atom_feature_dim: int, + W: int, + H: int, + atom_encoder_depth: int, + atom_encoder_heads: int, + structure_prediction: bool, + layer_state_dict, + # V2 only: AtomEncoder state dict for generating q/c/p inside worker + atom_encoder_state_dict, + feats_global_host: dict[str, torch.Tensor], + s_trunk_global_host: torch.Tensor | None, + z_global_host: torch.Tensor | None, + r_global_host: torch.Tensor | None, + # Upstream gradients + d_a_global_host: torch.Tensor, + d_q_out_global_host: torch.Tensor, + d_c_out_global_host: torch.Tensor, + d_p_out_global_host: torch.Tensor, + # Expected outputs + a_expected_global_host: torch.Tensor, + q_out_expected_global_host: torch.Tensor, + c_out_expected_global_host: torch.Tensor, + p_out_expected_global_host: torch.Tensor, + # Expected input grads + d_s_trunk_expected_global_host: torch.Tensor | None, + d_z_expected_global_host: torch.Tensor | None, + d_r_expected_global_host: torch.Tensor | None, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + """Parallel worker for DTensor AtomAttentionEncoder window batching test.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Recreate serial module -- move to device/dtype BEFORE load_state_dict + if serial_module_version == "boltz1": + module_serial = SerialAtomAttentionEncoderBoltz1( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + structure_prediction=structure_prediction, + ) + else: + module_serial = SerialAtomAttentionEncoderBoltz2( + atom_s=atom_s, + token_s=token_s, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + structure_prediction=structure_prediction, + ) + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.train() + module_serial.apply(SetModuleInfValues()) + + # Create distributed module + module = DistributedAtomAttentionEncoder( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Global masks + token_pad_mask_global = feats_global_host.pop("token_pad_mask").to(device=manager.device, dtype=torch.bool) + token_pad_mask_expanded_global = token_pad_mask_global.unsqueeze(-1) + + atom_pad_mask_global = feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=torch.bool) + atom_pad_mask_expanded_global = atom_pad_mask_global.unsqueeze(-1) + atom_pad_mask_expanded_global_mul = atom_pad_mask_expanded_global.repeat_interleave(multiplicity, dim=0) + + # ======================================================================== + # Distribute atom features + # ======================================================================== + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + feats_dt = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features, + _placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + + # Pack atom features + feats_dt_packed = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + N_atoms_packed = feats_dt_packed["atom_pad_mask"].shape[1] + K_packed = N_atoms_packed // W + + # Distribute token-level tensors + s_trunk_dt = None + z_dt = None + r_dt = None + if structure_prediction: + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_single, + ).requires_grad_(True) + + z_dt = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_pair, + ).requires_grad_(True) + + r_padded = pad_or_shrink_to_length( + r_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + r_dt = distribute_tensor(r_padded, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + + # For V2: compute q, c from DTensor AtomEncoder; bias is passed in separately + q_dt = None + c_dt = None + atom_enc_bias_dt = None + if serial_module_version == "boltz2": + atom_encoder_serial = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=structure_prediction, + ) + atom_encoder_serial = atom_encoder_serial.to(device=manager.device, dtype=dtype) + atom_encoder_serial.load_state_dict(atom_encoder_state_dict) + atom_encoder_serial.train() + atom_encoder_serial.apply(SetModuleInfValues()) + + from boltz.distributed.model.modules.encoders import AtomEncoder as DistributedAtomEncoder + + atom_encoder_dt = DistributedAtomEncoder( + layer=atom_encoder_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + q_dt, c_dt, _ = atom_encoder_dt(feats=feats_dt_packed, s_trunk=s_trunk_dt, z=z_dt) + + # Distribute the pre-generated atom_enc_bias (p_out_expected_global_host for V2 is the bias) + # Pad to packed K windows + p_padded = pad_or_shrink_to_length( + p_out_expected_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=K_packed + ) + atom_enc_bias_dt = distribute_tensor( + p_padded, manager.device_mesh_subgroups, _placements_single + ).requires_grad_(True) + + # ======================================================================== + # Forward pass + # ======================================================================== + a_dt, q_out_dt, c_out_dt, p_out_dt = module( + feats=feats_dt_packed, + q=q_dt, + c=c_dt, + atom_enc_bias=atom_enc_bias_dt, + s_trunk=s_trunk_dt if serial_module_version == "boltz1" else None, + z=z_dt if serial_module_version == "boltz1" else None, + r=r_dt if structure_prediction else None, + multiplicity=multiplicity, + ) + + # ======================================================================== + # Forward comparison + # ======================================================================== + # a is token feature - compare with token mask + a_expected_device = a_expected_global_host.to(device=manager.device, dtype=dtype) + token_pad_mask_expanded_global_mul = token_pad_mask_expanded_global.repeat_interleave(multiplicity, dim=0) + torch.testing.assert_close( + a_dt.full_tensor() * token_pad_mask_expanded_global_mul, + a_expected_device * token_pad_mask_expanded_global_mul, + ) + + # q_out and c_out are atom features with multiplicity + mask_dt_full = feats_dt_packed["atom_pad_mask"].full_tensor() + mask_dt_full_mul = mask_dt_full.repeat_interleave(multiplicity, dim=0) + mask_dt_full_mul_expanded = mask_dt_full_mul.unsqueeze(-1) + + q_out_expected_device = q_out_expected_global_host.to(device=manager.device, dtype=dtype) + assert_tensors_close_with_pad( + q_out_dt.full_tensor() * mask_dt_full_mul_expanded, + q_out_expected_device * atom_pad_mask_expanded_global_mul, + axis=1, + pad_val=0, + ) + + c_out_expected_device = c_out_expected_global_host.to(device=manager.device, dtype=dtype) + assert_tensors_close_with_pad( + c_out_dt.full_tensor() * mask_dt_full_mul_expanded, + c_out_expected_device * atom_pad_mask_expanded_global_mul, + axis=1, + pad_val=0, + ) + + # Compare only the valid 'key' region of the pair repr. + # Due to pack_atom_features and the resulting difference in atom length, + # the two pair repr (DTensor vs serial) can have different number of (W, H) windows + # and the extra windows in either case should be invalid by definition of pack_atom_features' + # guaranteeing not removing valid atoms. However, for comparing the two pair repr for numerical + # consistency, we need to mask both (W, H) axes because otherwise the last window + # can contain non-zero values for the invalid query atoms, failing assert_tensors_close_with_pad. + # Example of last two windows' mask (from Boltz-1x CP test with W=32, H=128): + # mask_dt_key_full_expanded[0, -2:, 0, :, 0] -- key mask shows partial validity: + # window -2: [1,1,...,1, 0,0,...,0] (51 valid keys, 77 padding) + # window -1: [1,1,...,1, 0,0,...,0] (19 valid keys, 109 padding) + # mask_dt_query_full_expanded[0, -2:, :, 0, 0] -- query mask shows partial validity: + # window -2: [1,1,1, 0,...,0] (3 valid queries, 29 padding) + # window -1: [0,0,...,0] (all padding -- entirely invalid window) + # Without masking both axes, the all-padding window -1 would have non-zero pair values + # from the forward pass (computed on garbage padding data) that don't exist in the serial. + N_atoms_serial = feats_global_host["atom_pad_mask"].shape[1] + K_serial = N_atoms_serial // W + + mask_dt_query = shardwise_unflatten_sharded(feats_dt_packed["atom_pad_mask"], axis=1, sizes=(K_packed, W)) + mask_dt_query_full = mask_dt_query.full_tensor() + mask_dt_query_full_expanded = mask_dt_query_full[:, :, :, None, None] + mask_dt_key = convert_single_repr_to_window_batched_key(feats_dt_packed["atom_pad_mask"], W, H) + mask_dt_key_full = mask_dt_key.full_tensor() + mask_dt_key_full_expanded = mask_dt_key_full[:, :, None, :, None] + mask_dt_pair_full_expanded = mask_dt_query_full_expanded * mask_dt_key_full_expanded + + if serial_module_version == "boltz1": + index_matrix = get_indexing_matrix_v1(K_serial, W, H, manager.device).to(dtype=dtype) + to_keys_fn = partial(single_to_keys_v1, indexing_matrix=index_matrix, W=W, H=H) + else: + compute_dtype = torch.promote_types(dtype, torch.float32) + index_matrix = get_indexing_matrix_v2(K_serial, W, H, manager.device).to(dtype=compute_dtype) + to_keys_fn = partial(single_to_keys_v2, indexing_matrix=index_matrix, W=W, H=H) + + mask_key_expected = to_keys_fn( + feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=dtype).unsqueeze(-1) + ) + mask_key_expected_expanded = mask_key_expected[:, :, None, :, :] + mask_query_expected_expanded = atom_pad_mask_expanded_global.unflatten( + 1, (atom_pad_mask_expanded_global.shape[1] // W, W) + )[:, :, :, None, :] + mask_pair_expected_expanded = mask_query_expected_expanded * mask_key_expected_expanded + + p_out_expected_device = p_out_expected_global_host.to(device=manager.device, dtype=dtype) + assert_tensors_close_with_pad( + p_out_dt.full_tensor() * mask_dt_pair_full_expanded, + p_out_expected_device * mask_pair_expected_expanded, + axis=1, + pad_val=0, + ) + + # ======================================================================== + # Backward pass + # ======================================================================== + d_a_dt = distribute_tensor( + d_a_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + a_dt.placements, + ) + + d_q_out_padded = pad_or_shrink_to_length( + d_q_out_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=q_out_dt.shape[1] + ) + d_q_out_dt = distribute_tensor(d_q_out_padded, manager.device_mesh_subgroups, q_out_dt.placements) + + d_c_out_padded = pad_or_shrink_to_length( + d_c_out_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=c_out_dt.shape[1] + ) + d_c_out_dt = distribute_tensor(d_c_out_padded, manager.device_mesh_subgroups, c_out_dt.placements) + + d_p_out_padded = pad_or_shrink_to_length( + d_p_out_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=p_out_dt.shape[1] + ) + d_p_out_dt = distribute_tensor(d_p_out_padded, manager.device_mesh_subgroups, p_out_dt.placements) + + torch.autograd.backward( + [a_dt, q_out_dt, c_out_dt, p_out_dt], + [d_a_dt, d_q_out_dt, d_c_out_dt, d_p_out_dt], + ) + + # Check input gradients (only for V1 where s_trunk/z flow through AtomAttentionEncoder) + if structure_prediction and s_trunk_dt is not None and d_s_trunk_expected_global_host is not None: + d_s_trunk_expected_device = d_s_trunk_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(s_trunk_dt.grad.full_tensor(), d_s_trunk_expected_device) + + if structure_prediction and z_dt is not None and d_z_expected_global_host is not None: + d_z_expected_device = d_z_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(z_dt.grad.full_tensor(), d_z_expected_device) + + if structure_prediction and r_dt is not None: + d_r_expected_device = d_r_expected_global_host.to(device=manager.device, dtype=dtype) + r_grad_full = r_dt.grad.full_tensor() + assert_tensors_close_with_pad( + r_grad_full * mask_dt_full_mul_expanded[:, :, :3], + d_r_expected_device * atom_pad_mask_expanded_global_mul[:, :, :3], + axis=1, + pad_val=0, + ) + + # Parameter grads + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + + if hasattr(grad_param, "full_tensor"): + grad_to_check = grad_param.full_tensor() + else: + grad_to_check = grad_param + + torch.testing.assert_close(grad_to_check.cpu(), grad_expected_global.to(dtype=dtype)) + assert_all_identical(grad_to_check, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, multiplicity", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, 1), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, 4), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, device_type:{x[0][2]}, dtype:{x[1]}, multiplicity:{x[2]}" + for x in params_test + ], +) +# TODO: Add "boltz1" to serial_module_version to test the V1 internalized_AtomEncoder path. +# Requirements for boltz1 test: +# - V1 serial AtomAttentionEncoder (monolithic: embed + pair + r_to_q + transformer + scatter) +# - V1 feature set includes atom_pad_mask in the atom_feats concat (V2 does not) +# - V1 r_to_q_trans takes 10-dim input: concat([r, zeros(B*M, N, 7)]) +# - V1 serial forward signature: forward(feats, s_trunk, z, r, multiplicity, model_cache) +# - V1 returns 5 values: (a, q, c, p, to_keys) vs V2's 4: (a, q, c, to_keys) +# - The DTensor path exercises _atom_encoder() shared function through the V1 code path +@pytest.mark.parametrize("serial_module_version", ["boltz2"]) +def test_atom_attention_encoder_window_batching(setup_env, dtype, multiplicity, serial_module_version): + """Test DTensor AtomAttentionEncoder with window batching.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + structure_prediction = multiplicity > 1 + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + W = 32 + H = 128 + # Small init range needed because the serial AtomAttentionEncoder uses 1e-6 epsilon in + # atom_to_token mean normalization (/ (count + 1e-6)) while DTensor uses exact scatter mean. + # The ~1e-6 relative error scales with value magnitude; (-0.02, 0.02) keeps it within + # float64 default tolerance (atol=1e-7) while maintaining non-trivial gradient magnitudes + # (transition layers ~1e-7, attention layers ~1e-4, scatter ~1e-2). + val_init_min_max = (-0.03, 0.03) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 50 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + atom_z = 8 + token_s = 2 + token_z = 2 + atom_encoder_depth = 2 + atom_encoder_heads = 2 + + from boltz.data import const as boltz_const + + atom_feature_dim = 3 + 1 + boltz_const.num_elements + 4 * 64 + + selected_keys = list(_selected_atom_keys) + assert N_tokens % size_cp == 0 + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + N_atoms_actual = feats["atom_pad_mask"].shape[1] + K = N_atoms_actual // W + + # Build serial modules + # For Boltz-2, p last dim = num_heads * depth (pre-computed bias split for DiffusionTransformer) + # For Boltz-1, p last dim = atom_z (pair representation) + p_last_dim = atom_z if serial_module_version == "boltz1" else atom_encoder_heads * atom_encoder_depth + + atom_encoder_state_dict = None + if serial_module_version == "boltz2": + # V2: need AtomEncoder to generate q, c (but NOT p -- bias is separate) + atom_encoder_module = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=structure_prediction, + ).to(device=device_type, dtype=dtype) + atom_encoder_module.train() + init_module_params_uniform(atom_encoder_module, low=val_init_min_max[0], high=val_init_min_max[1]) + atom_encoder_module.apply(SetModuleInfValues()) + atom_encoder_state_dict = {k: v.detach().cpu() for k, v in atom_encoder_module.state_dict().items()} + + reference_module = SerialAtomAttentionEncoderBoltz2( + atom_s=atom_s, + token_s=token_s, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + structure_prediction=structure_prediction, + ).to(device=device_type, dtype=dtype) + else: + reference_module = SerialAtomAttentionEncoderBoltz1( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + structure_prediction=structure_prediction, + ).to(device=device_type, dtype=dtype) + + reference_module.train() + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = {k: v.detach().cpu() for k, v in reference_module.state_dict().items()} + + # Prepare inputs + s_trunk = None + z = None + r = None + if structure_prediction: + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype, requires_grad=True) + r = torch.empty((B * multiplicity, N_atoms_actual, 3), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([s_trunk, z, r], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Serial forward + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) if s_trunk is not None else None + z_serial = z.detach().clone().requires_grad_(True) if z is not None else None + r_serial = r.detach().clone().requires_grad_(True) if r is not None else None + + if serial_module_version == "boltz2": + # Run AtomEncoder first to get q, c (we use a separate random bias) + q_enc, c_enc, _, to_keys_enc = atom_encoder_module( + feats=feats_serial, + s_trunk=s_trunk_serial, + z=z_serial, + ) + # Create random atom_enc_bias with correct shape (B, K, W, H, num_heads*depth) + atom_enc_bias = torch.empty((B, K, W, H, p_last_dim), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([atom_enc_bias], low=val_init_min_max[0], high=val_init_min_max[1]) + atom_enc_bias_serial = atom_enc_bias.detach().clone().requires_grad_(True) + + # Run AtomAttentionEncoder + a_expected, q_out_expected, c_out_expected, _ = reference_module( + feats=feats_serial, + q=q_enc, + c=c_enc, + atom_enc_bias=atom_enc_bias_serial, + to_keys=to_keys_enc, + r=r_serial, + multiplicity=multiplicity, + ) + p_out_expected = atom_enc_bias_serial + else: + a_expected, q_out_expected, c_out_expected, p_out_expected, _ = reference_module( + feats=feats_serial, + s_trunk=s_trunk_serial, + z=z_serial, + r=r_serial, + multiplicity=multiplicity, + model_cache=None, + ) + + # Generate upstream gradients + d_a = torch.empty_like(a_expected) + d_q_out = torch.empty_like(q_out_expected) + d_c_out = torch.empty_like(c_out_expected) + d_p_out = torch.empty_like(p_out_expected) + init_tensors_uniform([d_a, d_q_out, d_c_out, d_p_out], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Mask upstream gradients + compute_dtype = torch.promote_types(dtype, torch.float32) + if serial_module_version == "boltz1": + index_matrix = get_indexing_matrix_v1(K, W, H, device_type).to(dtype=dtype) + to_keys_mask = partial(single_to_keys_v1, indexing_matrix=index_matrix, W=W, H=H) + else: + index_matrix = get_indexing_matrix_v2(K, W, H, device_type).to(dtype=compute_dtype) + to_keys_mask = partial(single_to_keys_v2, indexing_matrix=index_matrix, W=W, H=H) + + mask_key_expected_full = to_keys_mask( + feats_serial["atom_pad_mask"].to(dtype=compute_dtype, device=d_p_out.device).unsqueeze(-1) + ) + d_a = d_a * feats_serial["token_pad_mask"].unsqueeze(-1).repeat_interleave(multiplicity, dim=0) + d_q_out = d_q_out * feats_serial["atom_pad_mask"].unsqueeze(-1).repeat_interleave(multiplicity, dim=0) + d_c_out = d_c_out * feats_serial["atom_pad_mask"].unsqueeze(-1).repeat_interleave(multiplicity, dim=0) + d_p_out = d_p_out * mask_key_expected_full[:, :, None, :, :] + + # Serial backward + torch.autograd.backward( + [a_expected, q_out_expected, c_out_expected, p_out_expected], + [d_a, d_q_out, d_c_out, d_p_out], + ) + + expected_param_grads = { + name: param.grad.detach().cpu() for name, param in reference_module.named_parameters() if param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_atom_attention_encoder_wb, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_module_version, + dtype, + multiplicity, + atom_s, + atom_z, + token_s, + token_z, + atom_feature_dim, + W, + H, + atom_encoder_depth, + atom_encoder_heads, + structure_prediction, + layer_state_dict, + atom_encoder_state_dict, + {k: v.detach().cpu() for k, v in feats.items()}, + s_trunk.detach().cpu() if s_trunk is not None else None, + z.detach().cpu() if z is not None else None, + r.detach().cpu() if r is not None else None, + d_a.detach().cpu(), + d_q_out.detach().cpu(), + d_c_out.detach().cpu(), + d_p_out.detach().cpu(), + a_expected.detach().cpu(), + q_out_expected.detach().cpu(), + c_out_expected.detach().cpu(), + p_out_expected.detach().cpu(), + s_trunk_serial.grad.detach().cpu() if s_trunk_serial is not None and s_trunk_serial.grad is not None else None, + z_serial.grad.detach().cpu() if z_serial is not None and z_serial.grad is not None else None, + r_serial.grad.detach().cpu() if r_serial is not None and r_serial.grad is not None else None, + expected_param_grads, + ) + + +# ====================================================================== +# Test 2: AtomAttentionEncoder under autocast bf16 (dtype-only comparison) +# ====================================================================== + + +def parallel_assert_atom_attention_encoder_wb_autocast_bf16( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + atom_s: int, + atom_z: int, + token_s: int, + token_z: int, + atom_feature_dim: int, + W: int, + H: int, + atom_encoder_depth: int, + atom_encoder_heads: int, + layer_state_dict, + atom_encoder_state_dict, + feats_global_host: dict[str, torch.Tensor], + s_trunk_global_host: torch.Tensor, + z_global_host: torch.Tensor, + r_global_host: torch.Tensor, + serial_output_dtypes: dict[str, torch.dtype], + serial_grad_dtypes: dict[str, torch.dtype], + serial_param_grad_dtypes: dict[str, torch.dtype], +): + """Parallel worker for bf16 autocast dtype test on AtomAttentionEncoder.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + dtype = torch.float32 + multiplicity = 1 + structure_prediction = True + + module_serial = SerialAtomAttentionEncoderBoltz2( + atom_s=atom_s, + token_s=token_s, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + structure_prediction=structure_prediction, + ) + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + + module = DistributedAtomAttentionEncoder( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Distribute atom features + feats_global_host.pop("token_pad_mask", None) + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + feats_dt = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features, + _placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + feats_dt_packed = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + N_atoms_packed = feats_dt_packed["atom_pad_mask"].shape[1] + + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_single, + ).requires_grad_(True) + z_dt = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_pair, + ).requires_grad_(True) + r_padded = pad_or_shrink_to_length( + r_global_host.to(device=manager.device, dtype=dtype), + axis=1, + target_length=N_atoms_packed, + ) + r_dt = distribute_tensor(r_padded, manager.device_mesh_subgroups, _placements_single).requires_grad_(True) + + # Create DTensor AtomEncoder for q/c + from boltz.distributed.model.modules.encoders import AtomEncoder as DistributedAtomEncoder + + atom_encoder_serial = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=structure_prediction, + ).to(device=manager.device, dtype=dtype) + atom_encoder_serial.load_state_dict(atom_encoder_state_dict) + atom_encoder_serial.eval() + + atom_encoder_dt = DistributedAtomEncoder( + layer=atom_encoder_serial, device_mesh=manager.device_mesh_subgroups + ).eval() + + with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16): + q_dt, c_dt, _ = atom_encoder_dt(feats=feats_dt_packed, s_trunk=s_trunk_dt.detach(), z=z_dt.detach()) + + p_last_dim = atom_encoder_heads * atom_encoder_depth + K_packed = N_atoms_packed // W + atom_enc_bias_dt = distribute_tensor( + torch.randn(1, K_packed, W, H, p_last_dim, device=manager.device, dtype=dtype).expand( + s_trunk_dt.shape[0], -1, -1, -1, -1 + ), + manager.device_mesh_subgroups, + _placements_single, + ) + + # Forward under autocast + with torch.autocast("cuda", dtype=torch.bfloat16): + a_dt, q_out_dt, c_out_dt, _ = module( + feats=feats_dt_packed, + q=q_dt, + c=c_dt, + atom_enc_bias=atom_enc_bias_dt, + s_trunk=s_trunk_dt, + z=z_dt, + r=r_dt, + multiplicity=multiplicity, + ) + + outputs_with_grad = [(n, t) for n, t in [("a", a_dt), ("q_out", q_out_dt), ("c_out", c_out_dt)] if t.requires_grad] + torch.autograd.backward( + [t for _, t in outputs_with_grad], + [torch.ones_like(t) for _, t in outputs_with_grad], + ) + + # Assert output dtypes + for name, dt_tensor in [("a", a_dt), ("q_out", q_out_dt), ("c_out", c_out_dt)]: + assert ( + dt_tensor.dtype == serial_output_dtypes[name] + ), f"{name} dtype mismatch: DTensor {dt_tensor.dtype} vs serial {serial_output_dtypes[name]}" + + # Assert input grad dtypes + for name, dt_tensor in [("s_trunk", s_trunk_dt), ("z", z_dt), ("r", r_dt)]: + if name not in serial_grad_dtypes: + continue + assert dt_tensor.grad is not None, f"{name} grad is None" + assert ( + dt_tensor.grad.dtype == serial_grad_dtypes[name] + ), f"{name} grad dtype mismatch: DTensor {dt_tensor.grad.dtype} vs serial {serial_grad_dtypes[name]}" + + # Assert param grad dtypes + for name, param in module.named_parameters(): + if name in serial_param_grad_dtypes and param.grad is not None: + grad_dtype = param.grad.full_tensor().dtype if hasattr(param.grad, "full_tensor") else param.grad.dtype + assert ( + grad_dtype == serial_param_grad_dtypes[name] + ), f"param '{name}' grad dtype mismatch: DTensor {grad_dtype} vs serial {serial_param_grad_dtypes[name]}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), True, "cuda", "ENV")], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +def test_atom_attention_encoder_wb_autocast_bf16(setup_env): + """Test DTensor AtomAttentionEncoder output dtypes under autocast bf16.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + B = 1 + W = 32 + H = 128 + val_init_min_max = (-0.1, 0.1) + dtype = torch.float32 + multiplicity = 1 + structure_prediction = True + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + atom_z = 8 + token_s = 2 + token_z = 2 + atom_encoder_depth = 2 + atom_encoder_heads = 2 + + from boltz.data import const as boltz_const + + atom_feature_dim = 3 + 1 + boltz_const.num_elements + 4 * 64 + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=list(_selected_atom_keys), + ) + feats = {k: v.to(dtype=dtype) if v.dtype.is_floating_point else v for k, v in feats.items()} + N_atoms_actual = feats["atom_pad_mask"].shape[1] + K = N_atoms_actual // W + p_last_dim = atom_encoder_heads * atom_encoder_depth + + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype, requires_grad=True) + r = torch.empty((B * multiplicity, N_atoms_actual, 3), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([s_trunk, z, r], low=val_init_min_max[0], high=val_init_min_max[1]) + + atom_encoder_module = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=structure_prediction, + ).to(device=device_type, dtype=dtype) + atom_encoder_module.eval() + init_module_params_uniform(atom_encoder_module, low=val_init_min_max[0], high=val_init_min_max[1]) + atom_encoder_module.apply(SetModuleInfValues()) + atom_encoder_state_dict = {k: v.detach().cpu() for k, v in atom_encoder_module.state_dict().items()} + + reference_module = SerialAtomAttentionEncoderBoltz2( + atom_s=atom_s, + token_s=token_s, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + structure_prediction=structure_prediction, + ).to(device=device_type, dtype=dtype) + reference_module.train() + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = {k: v.detach().cpu() for k, v in reference_module.state_dict().items()} + + # Serial forward under autocast + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) + z_serial = z.detach().clone().requires_grad_(True) + r_serial = r.detach().clone().requires_grad_(True) + + with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16): + q_enc, c_enc, _, _ = atom_encoder_module( + feats={k: v.clone() for k, v in feats.items()}, + s_trunk=s_trunk.detach(), + z=z.detach(), + ) + + atom_enc_bias = torch.randn(B, K, W, H, p_last_dim, device=device_type, dtype=dtype) + + with torch.autocast("cuda", dtype=torch.bfloat16): + a_serial, q_out_serial, c_out_serial, _ = reference_module( + feats={k: v.clone() for k, v in feats.items()}, + q=q_enc, + c=c_enc, + atom_enc_bias=atom_enc_bias, + to_keys=partial( + single_to_keys_v2, + indexing_matrix=get_indexing_matrix_v2(K, W, H, device_type).to(dtype=torch.float32), + W=W, + H=H, + ), + r=r_serial, + multiplicity=multiplicity, + ) + + outputs_with_grad = [ + (n, t) for n, t in [("a", a_serial), ("q_out", q_out_serial), ("c_out", c_out_serial)] if t.requires_grad + ] + torch.autograd.backward( + [t for _, t in outputs_with_grad], + [torch.ones_like(t) for _, t in outputs_with_grad], + ) + + serial_output_dtypes = {"a": a_serial.dtype, "q_out": q_out_serial.dtype, "c_out": c_out_serial.dtype} + serial_grad_dtypes = {} + if s_trunk_serial.grad is not None: + serial_grad_dtypes["s_trunk"] = s_trunk_serial.grad.dtype + if z_serial.grad is not None: + serial_grad_dtypes["z"] = z_serial.grad.dtype + if r_serial.grad is not None: + serial_grad_dtypes["r"] = r_serial.grad.dtype + serial_param_grad_dtypes = { + name: param.grad.dtype for name, param in reference_module.named_parameters() if param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_atom_attention_encoder_wb_autocast_bf16, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + atom_s, + atom_z, + token_s, + token_z, + atom_feature_dim, + W, + H, + atom_encoder_depth, + atom_encoder_heads, + layer_state_dict, + atom_encoder_state_dict, + {k: v.detach().cpu() for k, v in feats.items()}, + s_trunk.detach().cpu(), + z.detach().cpu(), + r.detach().cpu(), + serial_output_dtypes, + serial_grad_dtypes, + serial_param_grad_dtypes, + ) diff --git a/tests/distributed/model/modules/test_dtensor_atom_encoder_wb.py b/tests/distributed/model/modules/test_dtensor_atom_encoder_wb.py new file mode 100644 index 000000000..aecfee5b7 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_atom_encoder_wb.py @@ -0,0 +1,745 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor AtomEncoder (V2) with window batching. + +Tests the DTensor AtomEncoder against the V2 serial AtomEncoder reference, +verifying forward and backward numerical equivalence. + +Uses float64 to enable exact (default tolerance) comparison between +serial and distributed computation paths. +""" + +from functools import partial + +import pytest +import torch +from torch.distributed.tensor import distribute_tensor + +from boltz.distributed.data.feature.featurizer import pack_atom_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_unflatten_sharded +from boltz.distributed.model.layers.utils import convert_single_repr_to_window_batched_key +from boltz.distributed.model.modules.encoders import AtomEncoder as DistributedAtomEncoder +from boltz.model.modules.encodersv2 import AtomEncoder as SerialAtomEncoderV2 +from boltz.model.modules.encodersv2 import get_indexing_matrix, single_to_keys +from boltz.testing.utils import ( + SetModuleInfValues, + assert_all_identical, + assert_tensors_close_with_pad, + distribute_atom_features, + get_feature_placements, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + pad_or_shrink_to_length, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +# Subset of keys needed for AtomEncoder test +_selected_atom_keys = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", # Needed by serial module and pack_atom_features (creates atom_to_token_ids_global) + "atom_counts_per_token", # Required by pad_and_scatter_atom_features_dtensor +} + +_placements = get_feature_placements( + token_keys=set(), + msa_keys=set(), + atom_keys=_selected_atom_keys, + model_io_keys=set(), + model_io_fp32_keys=set(), +) +_placements_single = _placements["single"] +_placements_pair = _placements["pair"] +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_atom_features = _placements["atom_features"] + + +def parallel_assert_atom_encoder_wb( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + dtype: torch.dtype, + atom_s: int, + atom_z: int, + token_s: int, + token_z: int, + atom_feature_dim: int, + W: int, + H: int, + structure_prediction: bool, + layer_state_dict, + feats_global_host: dict[str, torch.Tensor], + s_trunk_global_host: torch.Tensor | None, + z_global_host: torch.Tensor | None, + # Expected outputs + q_expected_global_host: torch.Tensor, + c_expected_global_host: torch.Tensor, + p_expected_global_host: torch.Tensor, + # Upstream gradients + d_q_global_host: torch.Tensor, + d_c_global_host: torch.Tensor, + d_p_global_host: torch.Tensor, + # Expected input grads + d_s_trunk_expected_global_host: torch.Tensor | None, + d_z_expected_global_host: torch.Tensor | None, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + """Parallel worker function for testing DTensor AtomEncoder.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Recreate serial module from state dict + module_serial = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=structure_prediction, + ) + # CRITICAL: Move module to target device/dtype BEFORE loading state dict. + # Otherwise float64 state dict values get truncated to float32 during copy_() + # into the default float32 nn.Linear params, then .to(float64) can't recover precision. + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.train() + module_serial.apply(SetModuleInfValues()) + + # Create distributed module + module = DistributedAtomEncoder( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Get global masks + atom_pad_mask_global = feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=torch.bool) + atom_pad_mask_expanded_global = atom_pad_mask_global.unsqueeze(-1) + + # ======================================================================== + # Distribute atom features + # ======================================================================== + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + feats_dt = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features, + _placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + + # Pack atom features + feats_dt_packed = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + N_atoms_packed = feats_dt_packed["atom_pad_mask"].shape[1] + + # Distribute token-level tensors (s_trunk, z) + s_trunk_dt = None + z_dt = None + if structure_prediction and s_trunk_global_host is not None: + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_single, + ).requires_grad_(True) + if structure_prediction and z_global_host is not None: + z_dt = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_pair, + ).requires_grad_(True) + + # ======================================================================== + # Forward pass + # ======================================================================== + q_dt, c_dt, p_dt = module(feats=feats_dt_packed, s_trunk=s_trunk_dt, z=z_dt) + + # ======================================================================== + # Forward comparison + # ======================================================================== + q_expected_device = q_expected_global_host.to(device=manager.device, dtype=dtype) + c_expected_device = c_expected_global_host.to(device=manager.device, dtype=dtype) + p_expected_device = p_expected_global_host.to(device=manager.device, dtype=dtype) + + mask_dt_full = feats_dt_packed["atom_pad_mask"].full_tensor() + mask_dt_full_expanded = mask_dt_full.unsqueeze(-1) + + # q and c: (B, N_atoms_packed, atom_s) + assert_tensors_close_with_pad( + q_dt.full_tensor() * mask_dt_full_expanded, + q_expected_device * atom_pad_mask_expanded_global, + axis=1, + pad_val=0, + ) + assert_tensors_close_with_pad( + c_dt.full_tensor() * mask_dt_full_expanded, + c_expected_device * atom_pad_mask_expanded_global, + axis=1, + pad_val=0, + ) + + # Compare only the valid 'key' region of the pair repr. + # Due to pack_atom_features and the resulting difference in atom length, + # the two pair repr (DTensor vs serial) can have different number of (W, H) windows + # and the extra windows in either case should be invalid by definition of pack_atom_features' + # guaranteeing not removing valid atoms. However, for comparing the two pair repr for numerical + # consistency, we need to mask both (W, H) axes because otherwise the last window + # can contain non-zero values for the invalid query atoms, failing assert_tensors_close_with_pad. + # Example of last two windows' mask (from Boltz-1x CP test with W=32, H=128): + # mask_dt_key_full_expanded[0, -2:, 0, :, 0] -- key mask shows partial validity: + # window -2: [1,1,...,1, 0,0,...,0] (51 valid keys, 77 padding) + # window -1: [1,1,...,1, 0,0,...,0] (19 valid keys, 109 padding) + # mask_dt_query_full_expanded[0, -2:, :, 0, 0] -- query mask shows partial validity: + # window -2: [1,1,1, 0,...,0] (3 valid queries, 29 padding) + # window -1: [0,0,...,0] (all padding -- entirely invalid window) + # Without masking both axes, the all-padding window -1 would have non-zero pair values + # from the forward pass (computed on garbage padding data) that don't exist in the serial. + K_packed = N_atoms_packed // W + N_atoms_serial = feats_global_host["atom_pad_mask"].shape[1] + K_serial = N_atoms_serial // W + + mask_dt_query = shardwise_unflatten_sharded(feats_dt_packed["atom_pad_mask"], axis=1, sizes=(K_packed, W)) + mask_dt_query_full = mask_dt_query.full_tensor() + mask_dt_query_full_expanded = mask_dt_query_full[:, :, :, None, None] + mask_dt_key = convert_single_repr_to_window_batched_key(feats_dt_packed["atom_pad_mask"], W, H) + mask_dt_key_full = mask_dt_key.full_tensor() + mask_dt_key_full_expanded = mask_dt_key_full[:, :, None, :, None] + mask_dt_pair_full_expanded = mask_dt_query_full_expanded * mask_dt_key_full_expanded + + index_matrix = get_indexing_matrix(K_serial, W, H, manager.device).to(dtype=dtype) + to_keys_fn = partial(single_to_keys, indexing_matrix=index_matrix, W=W, H=H) + + mask_key_expected = to_keys_fn( + feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=dtype).unsqueeze(-1) + ) + mask_key_expected_expanded = mask_key_expected[:, :, None, :, :] + mask_query_expected_expanded = atom_pad_mask_expanded_global.unflatten( + 1, (atom_pad_mask_expanded_global.shape[1] // W, W) + )[:, :, :, None, :] + mask_pair_expected_expanded = mask_query_expected_expanded * mask_key_expected_expanded + + assert_tensors_close_with_pad( + p_dt.full_tensor() * mask_dt_pair_full_expanded, + p_expected_device * mask_pair_expected_expanded, + axis=1, + pad_val=0, + ) + + # ======================================================================== + # Backward pass + # ======================================================================== + d_q_padded = pad_or_shrink_to_length( + d_q_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + d_c_padded = pad_or_shrink_to_length( + d_c_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + d_p_padded = pad_or_shrink_to_length( + d_p_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=K_packed + ) + + d_q_dt = distribute_tensor(d_q_padded, manager.device_mesh_subgroups, q_dt.placements) + d_c_dt = distribute_tensor(d_c_padded, manager.device_mesh_subgroups, c_dt.placements) + d_p_dt = distribute_tensor(d_p_padded, manager.device_mesh_subgroups, p_dt.placements) + + torch.autograd.backward([q_dt, c_dt, p_dt], [d_q_dt, d_c_dt, d_p_dt]) + + # Check token-level input gradients + if structure_prediction and s_trunk_dt is not None: + d_s_trunk_expected_device = d_s_trunk_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(s_trunk_dt.grad.full_tensor(), d_s_trunk_expected_device) + + if structure_prediction and z_dt is not None: + d_z_expected_device = d_z_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(z_dt.grad.full_tensor(), d_z_expected_device) + + # Parameter grads + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + + if hasattr(grad_param, "full_tensor"): + grad_global_host = grad_param.full_tensor().cpu() + grad_to_check = grad_param.full_tensor() + else: + grad_global_host = grad_param.detach().cpu() + grad_to_check = grad_param + + torch.testing.assert_close(grad_global_host, grad_expected_global.to(dtype=dtype)) + assert_all_identical(grad_to_check, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64), + ] + ), + indirect=["setup_env"], + ids=[f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, device_type:{x[0][2]}, dtype:{x[1]}" for x in params_test], +) +@pytest.mark.parametrize("structure_prediction", [True, False], ids=lambda x: f"sp:{x}") +def test_atom_encoder_wb(setup_env, dtype, structure_prediction): + """Test DTensor AtomEncoder (V2) with window batching.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + W = 32 + H = 128 + val_init_min_max = (-0.1, 0.1) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 100 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + atom_z = 8 + token_s = 4 + token_z = 4 + + # Compute atom_feature_dim: ref_pos(3) + ref_charge(1) + ref_element(128) + ref_atom_name_chars(256) + from boltz.data import const as boltz_const + + atom_feature_dim = 3 + 1 + boltz_const.num_elements + 4 * 64 # 388 with default settings + + selected_keys = list(_selected_atom_keys) + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + N_atoms_actual = feats["atom_pad_mask"].shape[1] + K = N_atoms_actual // W + + # Token-level inputs (only for structure_prediction) + s_trunk = None + z = None + if structure_prediction: + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([s_trunk, z], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Build serial reference module + reference_module = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=structure_prediction, + ).to(device=device_type, dtype=dtype) + reference_module.train() + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = reference_module.state_dict() + + # Serial forward + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) if s_trunk is not None else None + z_serial = z.detach().clone().requires_grad_(True) if z is not None else None + + q_expected, c_expected, p_expected, _ = reference_module( + feats=feats_serial, + s_trunk=s_trunk_serial, + z=z_serial, + ) + + # Upstream gradients + d_q = torch.empty_like(q_expected) + d_c = torch.empty_like(c_expected) + d_p = torch.empty_like(p_expected) + init_tensors_uniform([d_q, d_c, d_p], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Apply masks to upstream gradients to zero invalid positions + mask_expanded = feats_serial["atom_pad_mask"].unsqueeze(-1) + d_q = d_q * mask_expanded + d_c = d_c * mask_expanded + + # Mask d_p with pair mask (query AND key masks) -- matches V1x test pattern + compute_dtype = torch.promote_types(dtype, torch.float32) + index_matrix = get_indexing_matrix(K, W, H, device_type).to(dtype=compute_dtype) + to_keys_fn_serial = partial(single_to_keys, indexing_matrix=index_matrix, W=W, H=H) + mask_key_serial = to_keys_fn_serial( + feats_serial["atom_pad_mask"].to(dtype=compute_dtype, device=d_p.device).unsqueeze(-1) + ) + # d_p: (B, K, W, H, atom_z) * mask_key_serial: (B, K, 1, H, 1) → (B, K, W, H, atom_z) + d_p = d_p * mask_key_serial[:, :, None, :, :] + + torch.autograd.backward([q_expected, c_expected, p_expected], [d_q, d_c, d_p]) + + expected_param_grads = { + name: param.grad.detach().cpu() for name, param in reference_module.named_parameters() if param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_atom_encoder_wb, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + atom_s, + atom_z, + token_s, + token_z, + atom_feature_dim, + W, + H, + structure_prediction, + { + k: v.detach().cpu() for k, v in layer_state_dict.items() + }, # CPU state dict avoids cross-process CUDA IPC issues + {k: v.detach().cpu() for k, v in feats.items()}, + s_trunk.detach().cpu() if s_trunk is not None else None, + z.detach().cpu() if z is not None else None, + q_expected.detach().cpu(), + c_expected.detach().cpu(), + p_expected.detach().cpu(), + d_q.detach().cpu(), + d_c.detach().cpu(), + d_p.detach().cpu(), + s_trunk_serial.grad.detach().cpu() if s_trunk_serial is not None and s_trunk_serial.grad is not None else None, + z_serial.grad.detach().cpu() if z_serial is not None and z_serial.grad is not None else None, + expected_param_grads, + ) + + +# ====================================================================== +# Test 2: AtomEncoder under autocast bf16 (dtype-only comparison) +# ====================================================================== + + +def parallel_assert_atom_encoder_wb_autocast_bf16( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + atom_s: int, + atom_z: int, + token_s: int, + token_z: int, + atom_feature_dim: int, + W: int, + H: int, + layer_state_dict, + feats_global_host: dict[str, torch.Tensor], + s_trunk_global_host: torch.Tensor, + z_global_host: torch.Tensor, + q_serial_dtype: torch.dtype, + c_serial_dtype: torch.dtype, + p_serial_dtype: torch.dtype, + serial_grad_dtypes: dict[str, torch.dtype], + serial_param_grad_dtypes: dict[str, torch.dtype], +): + """Parallel worker for bf16 autocast dtype test. + + Runs DTensor AtomEncoder forward + backward under + torch.autocast("cuda", dtype=torch.bfloat16) and asserts output and + gradient dtypes match the serial reference (computed in main process). + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + dtype = torch.float32 + + module_serial = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=True, + ) + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + + module_dt = DistributedAtomEncoder( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Distribute atom features (same pattern as existing test) + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + feats_dt = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features, + _placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + feats_dt_packed = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_single, + ).requires_grad_(True) + z_dt = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + _placements_pair, + ).requires_grad_(True) + + # DTensor forward + backward under autocast + with torch.autocast("cuda", dtype=torch.bfloat16): + q_dt, c_dt, p_dt = module_dt( + feats=feats_dt_packed, + s_trunk=s_trunk_dt, + z=z_dt, + ) + + torch.autograd.backward( + [q_dt, c_dt, p_dt], + [torch.ones_like(q_dt), torch.ones_like(c_dt), torch.ones_like(p_dt)], + ) + + # Assert forward output dtypes match serial reference + assert q_dt.dtype == q_serial_dtype, f"q dtype mismatch: DTensor {q_dt.dtype} vs serial {q_serial_dtype}" + assert c_dt.dtype == c_serial_dtype, f"c dtype mismatch: DTensor {c_dt.dtype} vs serial {c_serial_dtype}" + assert p_dt.dtype == p_serial_dtype, f"p dtype mismatch: DTensor {p_dt.dtype} vs serial {p_serial_dtype}" + + # Assert input gradient dtypes match serial reference + assert s_trunk_dt.grad is not None, "s_trunk_dt.grad is None" + assert z_dt.grad is not None, "z_dt.grad is None" + assert ( + s_trunk_dt.grad.dtype == serial_grad_dtypes["s_trunk"] + ), f"s_trunk grad dtype mismatch: DTensor {s_trunk_dt.grad.dtype} vs serial {serial_grad_dtypes['s_trunk']}" + assert ( + z_dt.grad.dtype == serial_grad_dtypes["z"] + ), f"z grad dtype mismatch: DTensor {z_dt.grad.dtype} vs serial {serial_grad_dtypes['z']}" + + # Assert parameter gradient dtypes match serial reference + for name, param in module_dt.named_parameters(): + if name in serial_param_grad_dtypes: + grad = param.grad + if grad is None: + continue + if hasattr(grad, "full_tensor"): + grad_dtype = grad.full_tensor().dtype + else: + grad_dtype = grad.dtype + assert ( + grad_dtype == serial_param_grad_dtypes[name] + ), f"param '{name}' grad dtype mismatch: DTensor {grad_dtype} vs serial {serial_param_grad_dtypes[name]}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), True, "cuda", "ENV")], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +def test_atom_encoder_wb_autocast_bf16(setup_env): + """Test DTensor AtomEncoder output dtypes under autocast bf16. + + Verifies that DTensor AtomEncoder produces the same output dtypes as the + V2 serial AtomEncoder when both run under torch.autocast("cuda", dtype=torch.bfloat16). + Uses dp=1, cp=(1,1) (1 GPU) since this is a dtype consistency test, not a CP correctness test. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + B = 1 * grid_group_sizes["dp"] + W = 32 + H = 128 + val_init_min_max = (-0.1, 0.1) + dtype = torch.float32 + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + size_cp = grid_group_sizes["cp"][0] + N_tokens = 30 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + atom_z = 8 + token_s = 4 + token_z = 4 + + from boltz.data import const as boltz_const + + atom_feature_dim = 3 + 1 + boltz_const.num_elements + 4 * 64 + + selected_keys = list(_selected_atom_keys) + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype.is_floating_point else v for k, v in feats.items()} + + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype) + z = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + init_tensors_uniform([s_trunk, z], low=val_init_min_max[0], high=val_init_min_max[1]) + + reference_module = SerialAtomEncoderV2( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=atom_feature_dim, + structure_prediction=True, + ).to(device=device_type, dtype=dtype) + reference_module.eval() + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = reference_module.state_dict() + + # Serial forward + backward under autocast (in main process) + reference_module.train() + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) + z_serial = z.detach().clone().requires_grad_(True) + + with torch.autocast("cuda", dtype=torch.bfloat16): + q_serial, c_serial, p_serial, _ = reference_module( + feats={k: v.clone() for k, v in feats.items()}, + s_trunk=s_trunk_serial, + z=z_serial, + ) + + torch.autograd.backward( + [q_serial, c_serial, p_serial], + [torch.ones_like(q_serial), torch.ones_like(c_serial), torch.ones_like(p_serial)], + ) + + serial_grad_dtypes = { + "s_trunk": s_trunk_serial.grad.dtype, + "z": z_serial.grad.dtype, + } + serial_param_grad_dtypes = {} + for name, param in reference_module.named_parameters(): + if param.grad is not None: + serial_param_grad_dtypes[name] = param.grad.dtype + + spawn_multiprocessing( + parallel_assert_atom_encoder_wb_autocast_bf16, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + atom_s, + atom_z, + token_s, + token_z, + atom_feature_dim, + W, + H, + {k: v.detach().cpu() for k, v in layer_state_dict.items()}, + {k: v.detach().cpu() for k, v in feats.items()}, + s_trunk.detach().cpu(), + z.detach().cpu(), + q_serial.dtype, + c_serial.dtype, + p_serial.dtype, + serial_grad_dtypes, + serial_param_grad_dtypes, + ) diff --git a/tests/distributed/model/modules/test_dtensor_atom_transformer.py b/tests/distributed/model/modules/test_dtensor_atom_transformer.py new file mode 100644 index 000000000..d2bdad86f --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_atom_transformer.py @@ -0,0 +1,670 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor AtomTransformer module (window batching). + +Tests both Boltz-1x and Boltz-2 serial AtomTransformer modules against the +unified DTensor AtomTransformer implementation, verifying forward and backward +equivalence. + +""" + +import pytest +import torch +from torch.distributed.tensor import ( + DTensor, + Replicate, + Shard, + distribute_tensor, +) + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.utils import convert_single_repr_to_window_batched_key +from boltz.distributed.model.modules.transformers import ( + AtomTransformer as DistributedAtomTransformer, +) +from boltz.model.modules.encoders import get_indexing_matrix as get_indexing_matrix_v1 +from boltz.model.modules.encoders import single_to_keys as single_to_keys_v1 +from boltz.model.modules.encodersv2 import get_indexing_matrix as get_indexing_matrix_v2 +from boltz.model.modules.encodersv2 import single_to_keys as single_to_keys_v2 +from boltz.model.modules.transformers import AtomTransformer as SerialAtomTransformerBoltz1 +from boltz.model.modules.transformersv2 import AtomTransformer as SerialAtomTransformerBoltz2 +from boltz.testing.utils import ( + SetModuleInfValues, + assert_all_identical, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_atom_transformer( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_module_version: str, + dtype: torch.dtype, + multiplicity: int, + depth: int, + heads: int, + dim: int, + dim_single_cond: int, + dim_pairwise: int, + W: int, + H: int, + layer_state_dict, + q_global_host: torch.Tensor, + c_global_host: torch.Tensor, + p_global_host: torch.Tensor, + mask_global_host: torch.Tensor, + d_out_global_host: torch.Tensor, + out_expected_global_host: torch.Tensor, + d_q_expected_global_host: torch.Tensor, + d_c_expected_global_host: torch.Tensor, + d_p_expected_global_host: torch.Tensor, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Recreate serial module from state dict + if serial_module_version == "boltz1": + module_serial = SerialAtomTransformerBoltz1( + depth=depth, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + dim_pairwise=dim_pairwise, + attn_window_queries=W, + attn_window_keys=H, + ) + else: + module_serial = SerialAtomTransformerBoltz2( + attn_window_queries=W, + attn_window_keys=H, + depth=depth, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + ) + + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device, dtype=dtype).train() + module_serial.apply(SetModuleInfValues()) + + module = DistributedAtomTransformer( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # AtomTransformer inputs are in single repr view: (B, N, D) where N = K * W + placements_single = ( + (Shard(0), Shard(1), Replicate()) if manager.device_mesh_subgroups.ndim == 3 else (Shard(0), Shard(1)) + ) + # For pair representation p: (B, K, W, H, D) + placements_pair = ( + (Shard(0), Shard(1), Replicate()) if manager.device_mesh_subgroups.ndim == 3 else (Shard(0), Shard(1)) + ) + + q_dt = distribute_tensor( + q_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_single + ).requires_grad_(True) + c_dt = distribute_tensor( + c_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_single + ).requires_grad_(True) + p_dt = distribute_tensor( + p_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_pair + ).requires_grad_(True) + mask_dt = distribute_tensor( + mask_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_single + ).requires_grad_(False) + + # Copies to ensure inputs aren't modified in-place + q_dt_copy = q_dt.detach().clone().requires_grad_(True) + c_dt_copy = c_dt.detach().clone().requires_grad_(True) + p_dt_copy = p_dt.detach().clone().requires_grad_(True) + mask_dt_copy = mask_dt.detach().clone() + + # multiplicity must be 1 for window batching + out_dt: DTensor = module( + q=q_dt, + c=c_dt, + p=p_dt, + mask=mask_dt, + multiplicity=1, + model_cache=None, + pair_mask=None, + ) + + # Ensure no input mutation + assert_tensors_identical(q_dt_copy.to_local(), q_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(c_dt_copy.to_local(), c_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(p_dt_copy.to_local(), p_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(mask_dt_copy.to_local(), mask_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Forward compare (local shards + full gather) + out_expected_dt = distribute_tensor( + out_expected_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + torch.testing.assert_close(out_dt.to_local(), out_expected_dt.to_local()) + torch.testing.assert_close(out_dt.full_tensor().cpu(), out_expected_global_host.to(dtype=dtype)) + + # Backward compare + d_out_dt = distribute_tensor( + d_out_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_single + ) + d_out_dt_copy = d_out_dt.detach().clone() + out_dt.backward(d_out_dt) + assert_tensors_identical(d_out_dt_copy.to_local(), d_out_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Input grad checks + d_q_expected_dt = distribute_tensor( + d_q_expected_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + d_c_expected_dt = distribute_tensor( + d_c_expected_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + torch.testing.assert_close(q_dt.grad.to_local(), d_q_expected_dt.to_local()) + torch.testing.assert_close(c_dt.grad.to_local(), d_c_expected_dt.to_local()) + + p_grad_full = p_dt.grad.full_tensor() + torch.testing.assert_close(p_grad_full.cpu(), d_p_expected_global_host.to(dtype=dtype)) + + # Mask-based gradient checks: gradients at invalid (masked) positions should be zero + mask_dt_local = ( + mask_dt.to_local().repeat_interleave(multiplicity, 0).unsqueeze(-1).bool() + ) # (B*mult_local, N_local, 1) + torch.testing.assert_close(q_dt.grad.to_local() * ~(mask_dt_local), torch.zeros_like(q_dt.grad.to_local())) + torch.testing.assert_close(c_dt.grad.to_local() * ~(mask_dt_local), torch.zeros_like(c_dt.grad.to_local())) + + # For pair repr grad masking, compute key mask from single repr mask (B, N) -> (B, K, H) + mask_key_dt = convert_single_repr_to_window_batched_key(mask_dt, W=W, H=H) # DTensor (B, K, H) + # (B, K_local, H) -> (B, K_local, 1, H, 1) for masking pair view (B, K, W, H, D) + mask_key_local = mask_key_dt.to_local().unsqueeze(2).unsqueeze(-1).bool() + torch.testing.assert_close(p_dt.grad.to_local() * ~(mask_key_local), torch.zeros_like(p_dt.grad.to_local())) + + # Parameter grads (gather full tensors) + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + + if isinstance(grad_param, DTensor): + grad_global_host = grad_param.full_tensor().cpu() + grad_to_check = grad_param.full_tensor() + else: + grad_global_host = grad_param.detach().cpu() + grad_to_check = grad_param + + torch.testing.assert_close(grad_global_host, grad_expected_global.to(dtype=dtype)) + assert_all_identical(grad_to_check, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +def _create_serial_reference( + serial_module_version: str, + depth: int, + heads: int, + dim: int, + dim_single_cond: int, + dim_pairwise: int, + W: int, + H: int, + B: int, + K: int, + multiplicity: int, + device_type: str, + dtype: torch.dtype, + val_init_min_max: tuple[float, float], +): + """Create serial module and compute reference forward/backward outputs.""" + N = K * W + + if serial_module_version == "boltz1": + get_indexing_matrix = get_indexing_matrix_v1 + single_to_keys = single_to_keys_v1 + reference_module = SerialAtomTransformerBoltz1( + depth=depth, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + dim_pairwise=dim_pairwise, + attn_window_queries=W, + attn_window_keys=H, + ).to(device=device_type, dtype=dtype) + else: + get_indexing_matrix = get_indexing_matrix_v2 + single_to_keys = single_to_keys_v2 + reference_module = SerialAtomTransformerBoltz2( + attn_window_queries=W, + attn_window_keys=H, + depth=depth, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + ).to(device=device_type, dtype=dtype) + + reference_module.train() + + # Inputs in single repr view: (B*M, N, D) for q, c + q = torch.empty((B * multiplicity, N, dim), device=device_type, dtype=dtype, requires_grad=True) + c = torch.empty((B * multiplicity, N, dim_single_cond), device=device_type, dtype=dtype, requires_grad=True) + mask = torch.ones((B, N), device=device_type, dtype=dtype) + mask[0, N // 2 :] = 0 # mask out second half for first sample + + # Pair repr in window-batched view: (B, K, W, H, D_z) + # Boltz-1: D_z = dim_pairwise (projected per layer) + # Boltz-2: D_z = num_heads * depth (pre-computed bias, split across layers) + z_last_dim = dim_pairwise if serial_module_version == "boltz1" else heads * depth + p = torch.empty((B, K, W, H, z_last_dim), device=device_type, dtype=dtype, requires_grad=True) + + init_tensors_uniform([q, c, p], low=val_init_min_max[0], high=val_init_min_max[1]) + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + + layer_state_dict = reference_module.state_dict() + + # Serial forward + q_serial = q.detach().clone().requires_grad_(True) + c_serial = c.detach().clone().requires_grad_(True) + p_serial = p.detach().clone().requires_grad_(True) + mask_multiplexed = mask.repeat_interleave(multiplicity, 0) + + # to_keys for serial AtomTransformer + indexing_matrix = get_indexing_matrix(K=K, W=W, H=H, device=device_type).to(dtype=dtype) + + def to_keys_serial(x: torch.Tensor) -> torch.Tensor: + return single_to_keys(x, indexing_matrix, W=W, H=H) + + if serial_module_version == "boltz1": + out_expected = reference_module( + q=q_serial, + c=c_serial, + p=p_serial, + mask=mask_multiplexed, + multiplicity=multiplicity, + to_keys=to_keys_serial, + model_cache=None, + ) + else: + out_expected = reference_module( + q=q_serial, + c=c_serial, + bias=p_serial, + to_keys=to_keys_serial, + mask=mask_multiplexed, + multiplicity=multiplicity, + ) + + d_out = torch.empty_like(out_expected) + init_tensors_uniform([d_out], low=val_init_min_max[0], high=val_init_min_max[1]) + d_out = d_out * mask_multiplexed.unsqueeze(-1) + + out_expected.backward(d_out) + + return ( + layer_state_dict, + q.detach().cpu(), + c.detach().cpu(), + p.detach().cpu(), + mask.detach().cpu(), + d_out.detach().cpu(), + out_expected.detach().cpu(), + q_serial.grad.detach().cpu(), + c_serial.grad.detach().cpu(), + p_serial.grad.detach().cpu(), + { + name: param.grad.detach().cpu() + for name, param in reference_module.named_parameters() + if param.grad is not None + }, + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("multiplicity", [1, 4], ids=lambda m: f"multiplicity:{m}") +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize("serial_module_version", ["boltz1", "boltz2"]) +def test_atom_transformer(setup_env, multiplicity: int, serial_module_version: str): + """Test AtomTransformer DTensor vs serial equivalence (window batching).""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + dtype = torch.float32 + seed = 42 + seed_by_rank(0, seed=seed) + + B = 2 * grid_group_sizes["dp"] + K = 10 * grid_group_sizes["cp"][0] # windows, divisible by cp size + W = 32 + H = 128 + val_init_min_max = (-0.2, 0.2) + + dim = 32 + dim_single_cond = dim + dim_pairwise = 32 + heads = 2 + depth = 2 + + ( + layer_state_dict, + q_host, + c_host, + p_host, + mask_host, + d_out_host, + out_expected_host, + d_q_expected_host, + d_c_expected_host, + d_p_expected_host, + expected_param_grads_host, + ) = _create_serial_reference( + serial_module_version=serial_module_version, + depth=depth, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + dim_pairwise=dim_pairwise, + W=W, + H=H, + B=B, + K=K, + multiplicity=multiplicity, + device_type=device_type, + dtype=dtype, + val_init_min_max=val_init_min_max, + ) + + spawn_multiprocessing( + parallel_assert_atom_transformer, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_module_version, + dtype, + multiplicity, + depth, + heads, + dim, + dim_single_cond, + dim_pairwise, + W, + H, + layer_state_dict, + q_host, + c_host, + p_host, + mask_host, + d_out_host, + out_expected_host, + d_q_expected_host, + d_c_expected_host, + d_p_expected_host, + expected_param_grads_host, + ) + + +# ====================================================================== +# Test 2: AtomTransformer under autocast bf16 (dtype-only comparison) +# ====================================================================== + + +def parallel_assert_atom_transformer_autocast_bf16( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + depth: int, + heads: int, + dim: int, + dim_single_cond: int, + W: int, + H: int, + layer_state_dict, + q_global_host: torch.Tensor, + c_global_host: torch.Tensor, + p_global_host: torch.Tensor, + mask_global_host: torch.Tensor, + serial_output_dtype: torch.dtype, + serial_grad_dtypes: dict[str, torch.dtype], + serial_param_grad_dtypes: dict[str, torch.dtype], +): + """Parallel worker for bf16 autocast dtype test on AtomTransformer.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + dtype = torch.float32 + + module_serial = SerialAtomTransformerBoltz2( + attn_window_queries=W, + attn_window_keys=H, + depth=depth, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + ) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device, dtype=dtype).train() + module_serial.apply(SetModuleInfValues()) + + module = DistributedAtomTransformer( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + placements_single = ( + (Shard(0), Shard(1), Replicate()) if manager.device_mesh_subgroups.ndim == 3 else (Shard(0), Shard(1)) + ) + placements_pair = ( + (Shard(0), Shard(1), Replicate()) if manager.device_mesh_subgroups.ndim == 3 else (Shard(0), Shard(1)) + ) + + q_dt = distribute_tensor( + q_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + c_dt = distribute_tensor( + c_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + p_dt = distribute_tensor( + p_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_pair, + ).requires_grad_(True) + mask_dt = distribute_tensor( + mask_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(False) + + with torch.autocast("cuda", dtype=torch.bfloat16): + out_dt = module(q=q_dt, c=c_dt, p=p_dt, mask=mask_dt, multiplicity=1, model_cache=None, pair_mask=None) + + torch.autograd.backward([out_dt], [torch.ones_like(out_dt)]) + + assert ( + out_dt.dtype == serial_output_dtype + ), f"out dtype mismatch: DTensor {out_dt.dtype} vs serial {serial_output_dtype}" + + for name, dt_tensor in [("q", q_dt), ("c", c_dt)]: + assert dt_tensor.grad is not None, f"{name} grad is None" + assert ( + dt_tensor.grad.dtype == serial_grad_dtypes[name] + ), f"{name} grad dtype mismatch: DTensor {dt_tensor.grad.dtype} vs serial {serial_grad_dtypes[name]}" + + for name, param in module.named_parameters(): + if name in serial_param_grad_dtypes and param.grad is not None: + grad_dtype = param.grad.full_tensor().dtype if hasattr(param.grad, "full_tensor") else param.grad.dtype + assert ( + grad_dtype == serial_param_grad_dtypes[name] + ), f"param '{name}' grad dtype mismatch: DTensor {grad_dtype} vs serial {serial_param_grad_dtypes[name]}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), True, "cuda", "ENV")], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +def test_atom_transformer_autocast_bf16(setup_env): + """Test DTensor AtomTransformer output dtypes under autocast bf16.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + dtype = torch.float32 + + B = 1 + K = 10 + W = 32 + H = 128 + val_init_min_max = (-0.2, 0.2) + N = K * W + multiplicity = 1 + + dim = 32 + dim_single_cond = dim + heads = 2 + depth = 2 + z_last_dim = heads * depth + + reference_module = SerialAtomTransformerBoltz2( + attn_window_queries=W, + attn_window_keys=H, + depth=depth, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + ).to(device=device_type, dtype=dtype) + reference_module.train() + + q = torch.empty((B * multiplicity, N, dim), device=device_type, dtype=dtype, requires_grad=True) + c = torch.empty((B * multiplicity, N, dim_single_cond), device=device_type, dtype=dtype, requires_grad=True) + mask = torch.ones((B, N), device=device_type, dtype=dtype) + p = torch.empty((B, K, W, H, z_last_dim), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([q, c, p], low=val_init_min_max[0], high=val_init_min_max[1]) + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = reference_module.state_dict() + + q_serial = q.detach().clone().requires_grad_(True) + c_serial = c.detach().clone().requires_grad_(True) + p_serial = p.detach().clone().requires_grad_(True) + + indexing_matrix = get_indexing_matrix_v2(K=K, W=W, H=H, device=device_type).to(dtype=dtype) + + with torch.autocast("cuda", dtype=torch.bfloat16): + out_serial = reference_module( + q=q_serial, + c=c_serial, + bias=p_serial, + to_keys=lambda x: single_to_keys_v2(x, indexing_matrix, W=W, H=H), + mask=mask.repeat_interleave(multiplicity, 0), + multiplicity=multiplicity, + ) + + torch.autograd.backward([out_serial], [torch.ones_like(out_serial)]) + + serial_output_dtype = out_serial.dtype + serial_grad_dtypes = {"q": q_serial.grad.dtype, "c": c_serial.grad.dtype} + serial_param_grad_dtypes = { + name: param.grad.dtype for name, param in reference_module.named_parameters() if param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_atom_transformer_autocast_bf16, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + depth, + heads, + dim, + dim_single_cond, + W, + H, + {k: v.detach().cpu() for k, v in layer_state_dict.items()}, + q.detach().cpu(), + c.detach().cpu(), + p.detach().cpu(), + mask.detach().cpu(), + serial_output_dtype, + serial_grad_dtypes, + serial_param_grad_dtypes, + ) diff --git a/tests/distributed/model/modules/test_dtensor_conditioned_transition_block.py b/tests/distributed/model/modules/test_dtensor_conditioned_transition_block.py new file mode 100644 index 000000000..02b40f0f7 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_conditioned_transition_block.py @@ -0,0 +1,214 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor ConditionedTransitionBlock module. + +Tests both Boltz-1x and Boltz-2 serial ConditionedTransitionBlock modules against +the unified DTensor implementation, verifying forward and backward equivalence. + +""" + +import pytest +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.transformers import ConditionedTransitionBlock as DTensorCTB +from boltz.model.modules.transformers import ConditionedTransitionBlock as CTBSerialBoltz1 +from boltz.model.modules.transformersv2 import ConditionedTransitionBlock as CTBSerialBoltz2 +from boltz.testing.utils import ( + assert_tensors_identical, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_conditioned_transition_block( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_module_version: str, + dim_single: int, + dim_single_cond: int, + B: int, + N: int, + layer_state_dict, + a_global_host: torch.Tensor, + s_global_host: torch.Tensor, + d_out_global_host: torch.Tensor, + out_expected_global_host: torch.Tensor, + d_a_expected_global_host: torch.Tensor, + d_s_expected_global_host: torch.Tensor, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Create serial module from state dict + CTBSerial = CTBSerialBoltz1 if serial_module_version == "boltz1" else CTBSerialBoltz2 + module_serial = CTBSerial(dim_single=dim_single, dim_single_cond=dim_single_cond) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device).train() + + # Create DTensor module from serial + module_dt = DTensorCTB( + conditioned_trans_block=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Placements + placements = (Shard(0), Shard(1), Replicate()) if manager.device_mesh_subgroups.ndim == 3 else (Shard(0), Shard(1)) + + a_dt = distribute_tensor( + a_global_host.to(device=manager.device), manager.device_mesh_subgroups, placements + ).requires_grad_(True) + s_dt = distribute_tensor( + s_global_host.to(device=manager.device), manager.device_mesh_subgroups, placements + ).requires_grad_(True) + + # Copies to verify inputs aren't modified + a_dt_copy = a_dt.detach().clone().requires_grad_(True) + s_dt_copy = s_dt.detach().clone().requires_grad_(True) + + # Forward pass + out_dt: DTensor = module_dt(a_dt, s_dt) + + # Ensure no input mutation + assert_tensors_identical(a_dt_copy.to_local(), a_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(s_dt_copy.to_local(), s_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Forward compare (full gather) + torch.testing.assert_close(out_dt.full_tensor().cpu(), out_expected_global_host) + + # Backward pass + d_out_dt = distribute_tensor(d_out_global_host.to(device=manager.device), manager.device_mesh_subgroups, placements) + out_dt.backward(d_out_dt) + + # Compare input gradients + torch.testing.assert_close(a_dt.grad.full_tensor().cpu(), d_a_expected_global_host) + torch.testing.assert_close(s_dt.grad.full_tensor().cpu(), d_s_expected_global_host) + + # Compare parameter gradients + for name, param in module_dt.named_parameters(): + assert param.grad is not None, f"Parameter {name} has no gradient" + expected_grad = expected_param_grads_global_host_dict[name] + torch.testing.assert_close( + param.grad.full_tensor().cpu(), + expected_grad, + msg=lambda m: f"Parameter gradient mismatch for {name}: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize("serial_module_version", ["boltz1", "boltz2"]) +def test_conditioned_transition_block(setup_env, serial_module_version: str): + """Test ConditionedTransitionBlock DTensor vs serial equivalence.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + # Module dimensions + dim_single = 64 + dim_single_cond = 32 + + # Data dimensions + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 8 + + seed_by_rank(0, seed=42) + + # Create serial module + CTBSerial = CTBSerialBoltz1 if serial_module_version == "boltz1" else CTBSerialBoltz2 + module_serial = CTBSerial(dim_single=dim_single, dim_single_cond=dim_single_cond) + module_serial = module_serial.train() + layer_state_dict = module_serial.state_dict() + + # Create input tensors + a_global = torch.randn(B, N, dim_single, requires_grad=True) + s_global = torch.randn(B, N, dim_single_cond, requires_grad=True) + + # Serial forward pass + out_serial = module_serial(a_global, s_global) + + # Create upstream gradient + d_out = torch.randn_like(out_serial) + + # Serial backward pass + out_serial.backward(d_out) + + # Collect expected results + out_expected = out_serial.detach().clone().cpu() + d_a_expected = a_global.grad.detach().clone().cpu() + d_s_expected = s_global.grad.detach().clone().cpu() + + expected_param_grads = {} + for name, param in module_serial.named_parameters(): + assert param.grad is not None, f"Serial parameter {name} has no gradient" + expected_param_grads[name] = param.grad.detach().clone().cpu() + + # Launch parallel test + spawn_multiprocessing( + parallel_assert_conditioned_transition_block, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_module_version, + dim_single, + dim_single_cond, + B, + N, + layer_state_dict, + a_global.detach().clone().cpu(), + s_global.detach().clone().cpu(), + d_out.detach().clone().cpu(), + out_expected, + d_a_expected, + d_s_expected, + expected_param_grads, + ) diff --git a/tests/distributed/model/modules/test_dtensor_confidence_utils.py b/tests/distributed/model/modules/test_dtensor_confidence_utils.py new file mode 100644 index 000000000..3541acb46 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_confidence_utils.py @@ -0,0 +1,731 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor compute_aggregated_metric and compute_ptms functions. + +Ported from boltz-1x-cp with import path updates for the Boltz-2 branch. +Tests verify both forward and backward passes across different device mesh +configurations and placements. +""" + +from math import gcd + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.data.feature.featurizer import BoltzFeaturizer +from boltz.data.module.inference import load_input +from boltz.data.tokenize.boltz import BoltzTokenizer +from boltz.distributed.comm import TransposeComm +from boltz.distributed.data.feature.featurizer_utils import get_num_atoms_tokens +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.confidence_utils import ( + CHAIN_IPTM_SENTINEL, + compute_aggregated_metric, + compute_ptms, +) +from boltz.model.layers.confidence_utils import ( + compute_aggregated_metric as serial_compute_aggregated_metric, +) +from boltz.model.layers.confidence_utils import ( + compute_ptms as serial_compute_ptms, +) +from boltz.testing.utils import ( + assert_tensors_identical, + distribute_atom_features, + init_tensors_uniform, + random_features, + spawn_multiprocessing, +) + + +def _assert_nontrivial_metric(metric: torch.Tensor, metric_name: str) -> None: + """Guard against degenerate metrics that can make parity checks trivial.""" + metric_cpu = metric.detach().cpu().to(torch.float32) + if metric_cpu.numel() == 0: + raise AssertionError(f"{metric_name} is empty") + if not torch.isfinite(metric_cpu).all(): + raise AssertionError(f"{metric_name} contains non-finite values") + if torch.all(metric_cpu == 0): + raise AssertionError( + f"{metric_name} is trivially all zeros (likely empty effective mask support in compute_ptms)" + ) + if torch.all(metric_cpu == 1): + raise AssertionError(f"{metric_name} is trivially all ones") + if metric_cpu.numel() > 1 and torch.allclose(metric_cpu, metric_cpu.reshape(-1)[0]): + raise AssertionError(f"{metric_name} is a trivial constant value: {metric_cpu.reshape(-1)[0].item():.6f}") + + +# --------------------------------------------------------------------------- +# compute_aggregated_metric tests +# --------------------------------------------------------------------------- + + +def parallel_assert_compute_aggregated_metric( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_logits_global_host, + output_expected_global_host, + d_output_expected_global_host, + d_input_logits_expected_global_host, + end: float, +): + """Compare DTensor compute_aggregated_metric with serial reference.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + input_logits_dtensor = distribute_tensor( + input_logits_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + ).requires_grad_(True) + + output_placements = placements + d_output_expected_dtensor = distribute_tensor( + d_output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=output_placements, + ) + output_expected_dtensor = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=output_placements, + src_data_rank=None, + ) + d_input_logits_expected_dtensor = distribute_tensor( + d_input_logits_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements, + src_data_rank=None, + ) + + input_logits_dtensor_copy = input_logits_dtensor.detach().clone().requires_grad_(True) + + output_dtensor_result = compute_aggregated_metric(input_logits_dtensor, end=end) + + assert_tensors_identical( + input_logits_dtensor_copy.to_local(), + input_logits_dtensor.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + assert ( + output_dtensor_result.shape == output_expected_dtensor.shape + ), f"Output shape mismatch: {output_dtensor_result.shape} vs {output_expected_dtensor.shape}" + assert ( + output_dtensor_result.stride() == output_expected_dtensor.stride() + ), f"Output stride mismatch: {output_dtensor_result.stride()} vs {output_expected_dtensor.stride()}" + torch.testing.assert_close( + output_dtensor_result.to_local(), + output_expected_dtensor.to_local(), + ) + + d_output_expected_dtensor_copy = d_output_expected_dtensor.detach().clone() + output_dtensor_result.backward(d_output_expected_dtensor) + + assert_tensors_identical( + d_output_expected_dtensor_copy.to_local(), + d_output_expected_dtensor.to_local(), + ) + + assert input_logits_dtensor.grad is not None, "Input gradient should not be None" + assert ( + input_logits_dtensor.grad.shape == d_input_logits_expected_dtensor.shape + ), f"Gradient shape mismatch: {input_logits_dtensor.grad.shape} vs {d_input_logits_expected_dtensor.shape}" + assert ( + input_logits_dtensor.grad.stride() == d_input_logits_expected_dtensor.stride() + ), f"Gradient stride mismatch: {input_logits_dtensor.grad.stride()} vs {d_input_logits_expected_dtensor.stride()}" + torch.testing.assert_close( + input_logits_dtensor.grad.to_local(), + d_input_logits_expected_dtensor.to_local(), + ) + + output_global_result_host = output_dtensor_result.full_tensor().cpu() + d_input_logits_global_result_host = input_logits_dtensor.grad.full_tensor().cpu() + + torch.testing.assert_close( + output_global_result_host, + output_expected_global_host, + ) + torch.testing.assert_close( + d_input_logits_global_result_host, + d_input_logits_expected_global_host, + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp2x2"], +) +@pytest.mark.parametrize( + "placements,input_shape_type", + [ + ((Shard(0), Shard(1), Shard(2)), "pair"), + ((Shard(0), Shard(1), Replicate()), "single"), + ], + ids=["pair-shard-all", "single-shard-B-N"], +) +@pytest.mark.parametrize("end", [1.0, 32.0], ids=["end:1.0", "end:32.0"]) +def test_compute_aggregated_metric_parallel(setup_env, placements, input_shape_type, end): + """Test compute_aggregated_metric with DTensor across multiple configurations.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 10 + + num_bins = 50 if end == 1.0 else 64 + + if input_shape_type == "pair": + input_shape = (B, N, N, num_bins) + else: + input_shape = (B, N, num_bins) + + init_min, init_max = -1.0, 1.0 + output_shape = input_shape[:-1] + input_logits_global = torch.empty(input_shape, requires_grad=True, device=device_type) + d_output_expected_global = torch.empty(output_shape, device=device_type) + torch.manual_seed(42) + init_tensors_uniform([input_logits_global, d_output_expected_global], low=init_min, high=init_max) + + input_logits_global_host = input_logits_global.detach().clone().cpu() + output_expected_global = serial_compute_aggregated_metric(input_logits_global, end=end) + output_expected_global_host = output_expected_global.detach().clone().cpu() + d_output_expected_global_host = d_output_expected_global.detach().clone().cpu() + output_expected_global.backward(d_output_expected_global) + + spawn_multiprocessing( + parallel_assert_compute_aggregated_metric, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + placements, + input_logits_global_host, + output_expected_global_host, + d_output_expected_global_host, + input_logits_global.grad.detach().clone().cpu(), + end, + ) + + +def parallel_assert_compute_aggregated_metric_error_cases( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, +): + """Test error cases for compute_aggregated_metric.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + B = 2 * grid_group_sizes["dp"] + N = grid_group_sizes["cp"][0] * 10 + num_bins = 50 + + regular_tensor = torch.empty((B, N, num_bins), device=manager.device, requires_grad=True) + input_tensor = torch.empty((B, N, num_bins), device=manager.device, requires_grad=True) + torch.manual_seed(42) + init_tensors_uniform([regular_tensor, input_tensor], low=-1.0, high=1.0) + + with pytest.raises(TypeError, match="Expected DTensor"): + compute_aggregated_metric(regular_tensor, end=1.0) + + sharded_bins_dtensor = distribute_tensor( + input_tensor, + device_mesh=manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Shard(2)), + ) + with pytest.raises(ValueError, match="bins dimension.*must not be sharded"): + compute_aggregated_metric(sharded_bins_dtensor, end=1.0) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp1-cp2x2"], +) +def test_compute_aggregated_metric_error_cases(setup_env): + """Test error cases for compute_aggregated_metric function.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + parallel_assert_compute_aggregated_metric_error_cases, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +def parallel_assert_compute_aggregated_metric_no_grad( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, +): + """Test compute_aggregated_metric with requires_grad=False.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + B = 2 * grid_group_sizes["dp"] + N = grid_group_sizes["cp"][0] * 10 + num_bins = 50 + + torch.manual_seed(42) + input_tensor = torch.empty((B, N, num_bins), device=manager.device, requires_grad=False) + init_tensors_uniform([input_tensor], low=-1.0, high=1.0) + input_dtensor = distribute_tensor( + input_tensor, + device_mesh=manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Replicate()), + ) + + output = compute_aggregated_metric(input_dtensor, end=1.0) + + assert output.shape == (B, N), f"Expected shape {(B, N)}, got {output.shape}" + assert not output.requires_grad, "Output should not require grad when input doesn't" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp2x2"], +) +def test_compute_aggregated_metric_no_grad(setup_env): + """Test compute_aggregated_metric with requires_grad=False.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + parallel_assert_compute_aggregated_metric_no_grad, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +# --------------------------------------------------------------------------- +# compute_ptms tests +# --------------------------------------------------------------------------- + + +def parallel_assert_compute_ptms( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + input_logits_global_host, + x_preds_global_host, + feats_global_host, + multiplicity: int, + assert_nontrivial: bool, +): + """Compare DTensor compute_ptms with serial reference on this rank's DP chunk only.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device_mesh = manager.device_mesh_subgroups + transpose_comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + placements_pair = (Shard(0), Shard(1), Shard(2)) + placements_single = (Shard(0), Shard(1), Replicate()) + + size_batch = feats_global_host["atom_pad_mask"].shape[0] + dp_size = manager.group["dp"].size() + local_batch_size = size_batch // dp_size + dp_rank = manager.group_rank["dp"] + dp_idx_str = dp_rank * local_batch_size + dp_idx_end = dp_idx_str + local_batch_size + + chunk_logits = input_logits_global_host.to(manager.device)[dp_idx_str * multiplicity : dp_idx_end * multiplicity] + chunk_x_preds = x_preds_global_host.to(manager.device)[dp_idx_str * multiplicity : dp_idx_end * multiplicity] + chunk_feats = {k: v.to(manager.device)[dp_idx_str:dp_idx_end].clone() for k, v in feats_global_host.items()} + ( + expected_ptm_chunk, + expected_iptm_chunk, + expected_ligand_iptm_chunk, + expected_protein_iptm_chunk, + expected_chain_pair_chunk, + ) = serial_compute_ptms(chunk_logits, chunk_x_preds, chunk_feats, multiplicity) + assert not expected_ptm_chunk.requires_grad + assert not expected_iptm_chunk.requires_grad + assert not expected_ligand_iptm_chunk.requires_grad + assert not expected_protein_iptm_chunk.requires_grad + for _idx1, chain_dict in expected_chain_pair_chunk.items(): + for _idx2, t in chain_dict.items(): + assert not t.requires_grad + if assert_nontrivial: + _assert_nontrivial_metric(expected_iptm_chunk.cpu(), "expected_iptm_chunk") + _assert_nontrivial_metric(expected_ligand_iptm_chunk.cpu(), "expected_ligand_iptm_chunk") + _assert_nontrivial_metric(expected_protein_iptm_chunk.cpu(), "expected_protein_iptm_chunk") + else: + if not torch.all(expected_iptm_chunk == 0): + _assert_nontrivial_metric(expected_iptm_chunk.cpu(), "expected_iptm_chunk") + if not torch.all(expected_ligand_iptm_chunk == 0): + _assert_nontrivial_metric(expected_ligand_iptm_chunk.cpu(), "expected_ligand_iptm_chunk") + if not torch.all(expected_protein_iptm_chunk == 0): + _assert_nontrivial_metric(expected_protein_iptm_chunk.cpu(), "expected_protein_iptm_chunk") + + logits_dtensor = distribute_tensor( + input_logits_global_host.to(manager.device), + device_mesh=device_mesh, + placements=placements_pair, + ) + x_preds_unflat = x_preds_global_host.unflatten(0, (size_batch, multiplicity)) + inputs_atom = { + "atom_counts_per_token": feats_global_host["atom_counts_per_token"].to(dtype=torch.int64), + "atom_to_token": feats_global_host["atom_to_token"].to(dtype=x_preds_global_host.dtype), + "atom_pad_mask": feats_global_host["atom_pad_mask"].to(dtype=x_preds_global_host.dtype), + "atom_resolved_mask": feats_global_host["atom_resolved_mask"].to(dtype=x_preds_global_host.dtype), + "frames_idx": feats_global_host["frames_idx"].to(dtype=torch.int64), + "x_preds_0": x_preds_unflat[:, 0].to(dtype=x_preds_global_host.dtype), + } + for i_mul in range(1, multiplicity): + inputs_atom[f"x_preds_{i_mul}"] = x_preds_unflat[:, i_mul].to(dtype=x_preds_global_host.dtype) + + placements_cp = { + "atom_counts_per_token": (Shard(0), Replicate()), + "atom_to_token": (Shard(0), Replicate()), + "atom_pad_mask": (Shard(0), Replicate()), + "atom_resolved_mask": (Shard(0), Replicate()), + "frames_idx": (Shard(1), Replicate()), + "x_preds_0": (Shard(0), Replicate()), + } + placements_dp_cp = { + "atom_to_token": (Shard(0), Shard(1), Replicate()), + "atom_pad_mask": (Shard(0), Shard(1), Replicate()), + "atom_resolved_mask": (Shard(0), Shard(1), Replicate()), + "frames_idx": (Shard(0), Shard(1), Replicate()), + "x_preds_0": (Shard(0), Shard(1), Replicate()), + } + for i_mul in range(1, multiplicity): + placements_cp[f"x_preds_{i_mul}"] = (Shard(0), Replicate()) + placements_dp_cp[f"x_preds_{i_mul}"] = (Shard(0), Shard(1), Replicate()) + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + device_mesh, + manager.group["cp"], + multiplicities={"x_preds": multiplicity}, + ) + + x_preds_dtensor = feats_atom["x_preds"] + feats_dtensor = { + "frames_idx": feats_atom["frames_idx"], + "asym_id": distribute_tensor( + feats_global_host["asym_id"].to(manager.device), + device_mesh=device_mesh, + placements=placements_single, + ), + "atom_to_token": feats_atom["atom_to_token"], + "atom_pad_mask": feats_atom["atom_pad_mask"], + "atom_resolved_mask": feats_atom["atom_resolved_mask"], + "mol_type": distribute_tensor( + feats_global_host["mol_type"].to(manager.device), + device_mesh=device_mesh, + placements=placements_single, + ), + "token_pad_mask": distribute_tensor( + feats_global_host["token_pad_mask"].to(manager.device), + device_mesh=device_mesh, + placements=placements_single, + ), + } + + ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm = compute_ptms( + logits_dtensor, + x_preds_dtensor, + feats_dtensor, + multiplicity, + transpose_comm, + ) + assert not ptm.requires_grad + assert not iptm.requires_grad + assert not ligand_iptm.requires_grad + assert not protein_iptm.requires_grad + for _idx1, chain_dict in chain_pair_iptm.items(): + for _idx2, dt in chain_dict.items(): + assert not dt.requires_grad + + torch.testing.assert_close( + ptm.to_local().cpu().to(dtype=expected_ptm_chunk.dtype), + expected_ptm_chunk.cpu(), + ) + torch.testing.assert_close( + iptm.to_local().cpu().to(dtype=expected_iptm_chunk.dtype), + expected_iptm_chunk.cpu(), + ) + torch.testing.assert_close( + ligand_iptm.to_local().cpu().to(dtype=expected_ligand_iptm_chunk.dtype), + expected_ligand_iptm_chunk.cpu(), + ) + torch.testing.assert_close( + protein_iptm.to_local().cpu().to(dtype=expected_protein_iptm_chunk.dtype), + expected_protein_iptm_chunk.cpu(), + ) + + for idx1, chain_dict in expected_chain_pair_chunk.items(): + for idx2, expected_value in chain_dict.items(): + torch.testing.assert_close( + chain_pair_iptm[idx1][idx2].to_local().cpu().to(dtype=expected_value.dtype), + expected_value.cpu(), + ) + + for idx1, chain_dict in chain_pair_iptm.items(): + for idx2, dt in chain_dict.items(): + if idx1 not in expected_chain_pair_chunk or idx2 not in expected_chain_pair_chunk.get(idx1, {}): + assert torch.all( + dt.to_local() == CHAIN_IPTM_SENTINEL + ), f"Extra chain pair ({idx1}, {idx2}) should be sentinel {CHAIN_IPTM_SENTINEL}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp2x2"], +) +@pytest.mark.parametrize("multiplicity", [1, 2], ids=["multiplicity=1", "multiplicity=2"]) +@pytest.mark.parametrize("seed", [0, 42], ids=["seed=0", "seed=42"]) +def test_compute_ptms_parallel(setup_env, multiplicity, seed): + """Test compute_ptms with DTensor using random features.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + size_ring = grid_group_sizes["cp"][0] + batch_size = grid_group_sizes["dp"] + n_tokens_per_rank = 20 + n_tokens = size_ring * n_tokens_per_rank + max_atoms_per_token = 18 + n_atoms_per_rank = n_tokens_per_rank * max_atoms_per_token + n_atoms = size_ring * n_atoms_per_rank + num_bins = 64 + + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + logits = torch.randn((batch_size * multiplicity, n_tokens, n_tokens, num_bins), device=device_type, generator=rng) + x_preds = torch.randn((batch_size * multiplicity, n_atoms, 3), device=device_type, generator=rng) + rng_features = torch.Generator(device=x_preds.device) + rng_features.manual_seed(seed) + + feats = random_features( + size_batch=batch_size, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, max_atoms_per_token), + device=x_preds.device, + float_value_range=(-1.0, 1.0), + selected_keys=[ + "asym_id", + "atom_to_token", + "atom_pad_mask", + "atom_resolved_mask", + "atom_counts_per_token", + "mol_type", + "token_pad_mask", + "frames_idx", + ], + rng=rng_features, + ) + + spawn_multiprocessing( + parallel_assert_compute_ptms, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + logits.detach().cpu(), + x_preds.detach().cpu(), + {k: v.detach().cpu() for k, v in feats.items()}, + multiplicity, + True, + ) + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp1-cp2x2"], +) +@pytest.mark.parametrize("multiplicity", [1, 2], ids=["multiplicity=1", "multiplicity=2"]) +@pytest.mark.parametrize("seed", [0, 42], ids=["seed=0", "seed=42"]) +def test_compute_ptms_real_data_parallel( + setup_env, + multiplicity, + seed, + create_preprocessed_handle_boltz1_v1, +): + """Test compute_ptms with DTensor using real Boltz-1 data.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + processed = create_preprocessed_handle_boltz1_v1 + record = processed.manifest.records[0] + input_data = load_input(record, processed.targets_dir, processed.msa_dir) + tokenized = BoltzTokenizer().tokenize(input_data) + n_atoms_raw, n_tokens_raw = get_num_atoms_tokens(tokenized) + + ring = grid_group_sizes["cp"][0] + atoms_per_window = 32 + atom_lcm = ring * atoms_per_window // gcd(ring, atoms_per_window) + max_atoms = ((n_atoms_raw + atom_lcm - 1) // atom_lcm) * atom_lcm + max_tokens = ((n_tokens_raw + ring - 1) // ring) * ring + max_seqs = ring + + featurizer = BoltzFeaturizer() + feats_single = featurizer.process( + tokenized, + training=False, + max_atoms=max_atoms, + max_tokens=max_tokens, + max_seqs=max_seqs, + pad_to_max_seqs=True, + ) + if not isinstance(feats_single, dict): + raise TypeError("Expected non-sharded feature dict from BoltzFeaturizer.process") + selected_keys = [ + "asym_id", + "atom_to_token", + "atom_pad_mask", + "atom_resolved_mask", + "mol_type", + "token_pad_mask", + "frames_idx", + ] + feats_single = {k: feats_single[k] for k in selected_keys} + # v1 featurizer doesn't emit atom_counts_per_token; derive from one-hot atom_to_token + feats_single["atom_counts_per_token"] = feats_single["atom_to_token"].sum(dim=0).to(torch.int64) + + batch_size = grid_group_sizes["dp"] + feats = {k: v.unsqueeze(0).repeat_interleave(batch_size, dim=0) for k, v in feats_single.items()} + + n_tokens = feats["token_pad_mask"].shape[1] + n_atoms = feats["atom_pad_mask"].shape[1] + num_bins = 64 + rng = torch.Generator(device=device_type) + rng.manual_seed(seed) + logits = torch.randn((batch_size * multiplicity, n_tokens, n_tokens, num_bins), device=device_type, generator=rng) + x_preds = torch.randn((batch_size * multiplicity, n_atoms, 3), device=device_type, generator=rng) + + spawn_multiprocessing( + parallel_assert_compute_ptms, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + logits.detach().cpu(), + x_preds.detach().cpu(), + {k: v.detach().cpu() for k, v in feats.items()}, + multiplicity, + False, + ) diff --git a/tests/distributed/model/modules/test_dtensor_confidencev2.py b/tests/distributed/model/modules/test_dtensor_confidencev2.py new file mode 100644 index 000000000..70708074e --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_confidencev2.py @@ -0,0 +1,1132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for the distributed v2 ConfidenceHeads and ConfidenceModule (DTensor). + +Checks that the distributed ConfidenceHeads and ConfidenceModule produce outputs +and gradients bit-for-bit identical to the serial v2 versions across: + - use_separate_heads=False (single shared PAE/PDE head) + - use_separate_heads=True (separate intra/inter-chain PAE and PDE heads) + - dtype=float64 and float32 + +PTM/iPTM features (frames_idx, atom_to_token, atom_pad_mask, atom_resolved_mask) +are included so that compute_ptms is exercised end-to-end. +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.confidence_utils import CHAIN_IPTM_SENTINEL +from boltz.distributed.model.modules.confidencev2 import ConfidenceHeads as DTensorConfidenceHeadsV2 +from boltz.distributed.model.modules.confidencev2 import ConfidenceModule as DTensorConfidenceModuleV2 +from boltz.model.modules.confidencev2 import ConfidenceHeads as SerialConfidenceHeadsV2 +from boltz.model.modules.confidencev2 import ConfidenceModule as SerialConfidenceModuleV2 +from boltz.testing.utils import ( + assert_tensors_identical, + create_boltz2_model_init_params, + distribute_atom_features, + get_feature_placements, + init_module_params_glorot, + init_tensors_uniform, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_test_dtensor_confidence_heads_v2( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + # config + dtype, + serial_state_dict, + confidence_heads_kwargs, + multiplicity, + # input tensors + s_global_host, + z_global_host, + x_pred_global_host, + d_global_host, + feats_global_host, + pred_distogram_logits_global_host, + # reference serial outputs + serial_output_feats_host, + # reference serial input gradients + s_grad_host, + z_grad_host, + # upstream gradient tensors + d_plddt_logits_host, + d_pde_logits_host, + d_resolved_logits_host, + d_pae_logits_host, + # reference serial parameter gradients + serial_param_grads_host, +): + """Parallel worker: distributes inputs, runs DTensor ConfidenceHeads, compares.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device_mesh = manager.device_mesh_subgroups + + seed_by_rank(0, 42) + + # Build serial module on device (for wrapping into distributed module) + serial_module = SerialConfidenceHeadsV2(**confidence_heads_kwargs) + serial_module = serial_module.to(device=manager.device, dtype=dtype).train() + serial_module.load_state_dict(serial_state_dict) + + cp_group = manager.group["cp"] + layout_group_cp = manager.layout_subgroups["cp"] + transpose_comm = TransposeComm(cp_group, layout_group_cp) + + module = DTensorConfidenceHeadsV2( + layer=serial_module, + device_mesh=device_mesh, + transpose_comm=transpose_comm, + ) + module = module.to(device=manager.device, dtype=dtype).train() + + # ----- distribute inputs ----- + # s: (B*mult, N, D_s) → (Shard(0), Shard(1), Replicate()) + s_dtensor = distribute_tensor( + s_global_host.to(device=manager.device, dtype=dtype).requires_grad_(True), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + + # z: (B*mult, N, N, D_z) → (Shard(0), Shard(1), Shard(2)) + z_dtensor = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype).requires_grad_(True), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Shard(2)), + ) + + # d: (B*mult, N, N) → (Shard(0), Shard(1), Shard(2)) + d_dtensor = distribute_tensor( + d_global_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Shard(2)), + ) + + # Distribute atom-level features via distribute_atom_features + placements_single_repr = (Shard(0), Shard(1), Replicate()) + placements_cp_single_repr = (Shard(0), Replicate()) + + special_atom_features = {"atom_pad_mask", "atom_to_token", "atom_resolved_mask", "frames_idx"} + atom_inputs = { + key: feats_global_host[key].to(device=manager.device) + for key in special_atom_features + if key in feats_global_host + } + + base_batch = s_global_host.shape[0] // multiplicity + x_pred_reshaped = x_pred_global_host.reshape(base_batch, multiplicity, *x_pred_global_host.shape[1:]) + for mul_idx in range(multiplicity): + atom_inputs[f"x_pred_{mul_idx}"] = x_pred_reshaped[:, mul_idx].to(manager.device, dtype=dtype) + + if "atom_counts_per_token" in feats_global_host: + atom_inputs["atom_counts_per_token"] = feats_global_host["atom_counts_per_token"].to( + manager.device, dtype=torch.int64 + ) + else: + atom_inputs["atom_counts_per_token"] = ( + feats_global_host["atom_to_token"].sum(dim=1).to(manager.device, dtype=torch.int64) + ) + + atom_placements_cp = {key: placements_cp_single_repr for key in atom_inputs} + atom_placements_cp["frames_idx"] = (Shard(1), Replicate()) + atom_placements_dp_cp = {key: placements_single_repr for key in atom_inputs if key != "atom_counts_per_token"} + atom_placements_dp_cp["frames_idx"] = (Shard(0), Shard(1), Replicate()) + atom_feats_dtensor = distribute_atom_features( + inputs=atom_inputs, + placements_cp=atom_placements_cp, + placements_dp_cp=atom_placements_dp_cp, + device_mesh=device_mesh, + cp_group=manager.group["cp"], + multiplicities={"x_pred": multiplicity}, + ) + + x_pred_dtensor = atom_feats_dtensor["x_pred"] + + feats_dtensor = { + "token_pad_mask": distribute_tensor( + feats_global_host["token_pad_mask"].to(device=manager.device), + device_mesh=device_mesh, + placements=placements_single_repr, + ), + "asym_id": distribute_tensor( + feats_global_host["asym_id"].to(device=manager.device), + device_mesh=device_mesh, + placements=placements_single_repr, + ), + "mol_type": distribute_tensor( + feats_global_host["mol_type"].to(device=manager.device), + device_mesh=device_mesh, + placements=placements_single_repr, + ), + } + for key in special_atom_features: + if key in atom_feats_dtensor: + feats_dtensor[key] = atom_feats_dtensor[key] + + pred_distogram_logits_dtensor = distribute_tensor( + pred_distogram_logits_global_host.to(manager.device), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Shard(2)), + ) + + # Keep copies to verify inputs are not mutated + s_dtensor_copy = s_dtensor.clone() + z_dtensor_copy = z_dtensor.clone() + + # ----- distribute upstream gradients ----- + d_plddt_logits_dtensor = distribute_tensor( + d_plddt_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + d_pde_logits_dtensor = distribute_tensor( + d_pde_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Shard(2)), + ) + d_resolved_logits_dtensor = distribute_tensor( + d_resolved_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + d_pae_logits_dtensor = distribute_tensor( + d_pae_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Shard(2)), + ) + + # ----- forward ----- + output_dtensor = module( + s=s_dtensor, + z=z_dtensor, + x_pred=x_pred_dtensor, + d=d_dtensor, + feats=feats_dtensor, + pred_distogram_logits=pred_distogram_logits_dtensor, + multiplicity=multiplicity, + ) + + # Compare all outputs against serial reference + dp_rank = manager.group_rank["dp"] + dp_size = len(manager.group_ranks["dp"]) + for key in output_dtensor: + assert key in serial_output_feats_host, f"DTensor output key '{key}' missing from serial reference" + dtensor_val = output_dtensor[key] + if key == "pair_chains_iptm": + serial_pciptm = serial_output_feats_host[key] + if isinstance(dtensor_val, dict): + for idx1, chain_dict in dtensor_val.items(): + for idx2, dt_val in chain_dict.items(): + local_val = dt_val.to_local().cpu() + is_sentinel = torch.all(local_val == CHAIN_IPTM_SENTINEL) + if idx1 in serial_pciptm and idx2 in serial_pciptm.get(idx1, {}): + serial_full = serial_pciptm[idx1][idx2] + chunk_size = serial_full.shape[0] // dp_size + serial_local = serial_full[dp_rank * chunk_size : (dp_rank + 1) * chunk_size] + if is_sentinel: + assert torch.all(serial_local.abs() < 1e-5), ( + f"Chain pair ({idx1}, {idx2}): distributed returned sentinel but " + f"serial has non-zero values {serial_local} for this DP rank" + ) + else: + torch.testing.assert_close( + local_val, + serial_local, + msg=f"Chain pair ({idx1}, {idx2}) mismatch on DP rank {dp_rank}", + ) + else: + assert is_sentinel, ( + f"Extra chain pair ({idx1}, {idx2}) should be sentinel " + f"{CHAIN_IPTM_SENTINEL}, got {local_val}" + ) + else: + assert isinstance(serial_pciptm, torch.Tensor), ( + f"pair_chains_iptm: distributed returned DTensor (compute_ptms fallback) " + f"but serial returned {type(serial_pciptm)}" + ) + torch.testing.assert_close( + dtensor_val.full_tensor().cpu(), + serial_pciptm, + msg="pair_chains_iptm fallback mismatch", + ) + continue + torch.testing.assert_close( + dtensor_val.full_tensor().cpu(), + serial_output_feats_host[key], + msg=f"Mismatch for output key '{key}'", + ) + + # ----- backward ----- + torch.autograd.backward( + [ + output_dtensor["plddt_logits"], + output_dtensor["pde_logits"], + output_dtensor["resolved_logits"], + output_dtensor["pae_logits"], + ], + [ + d_plddt_logits_dtensor, + d_pde_logits_dtensor, + d_resolved_logits_dtensor, + d_pae_logits_dtensor, + ], + ) + + # Verify inputs were not mutated by the forward pass + assert_tensors_identical( + s_dtensor.to_local().cpu(), + s_dtensor_copy.to_local().cpu(), + check_grad=False, + check_grad_fn=False, + msg="s_dtensor was mutated during forward", + ) + assert_tensors_identical( + z_dtensor.to_local().cpu(), + z_dtensor_copy.to_local().cpu(), + check_grad=False, + check_grad_fn=False, + msg="z_dtensor was mutated during forward", + ) + + # Compare input gradients + torch.testing.assert_close(s_dtensor.grad.full_tensor().cpu(), s_grad_host, msg="s gradient mismatch") + torch.testing.assert_close(z_dtensor.grad.full_tensor().cpu(), z_grad_host, msg="z gradient mismatch") + + # Compare parameter gradients + result_param_grads = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in serial_param_grads_host: + raise ValueError( + f"Parameter '{name}' has a gradient in the distributed module " f"but not in the serial reference" + ) + result_param_grads[name] = param.grad + + for name, expected in serial_param_grads_host.items(): + assert name in result_param_grads, f"Parameter '{name}' gradient missing in distributed module" + torch.testing.assert_close( + result_param_grads[name].full_tensor().cpu(), + expected, + msg=f"Parameter gradient mismatch for '{name}'", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +# NOTE: the use_separate_heads=False codepath looks like it's not run in confidencev2.py, but add test parameterization +@pytest.mark.parametrize( + ("setup_env", "dtype", "use_separate_heads"), + ( + params_test := [ + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + ] + ), + indirect=["setup_env"], + ids=[ + (f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, device:{x[0][2]}, init:{x[0][3]}, " f"dtype:{x[1]}, separate_heads:{x[2]}") + for x in params_test + ], +) +def test_dtensor_confidence_heads_v2(setup_env, dtype, use_separate_heads): + """Test that DTensor ConfidenceHeadsV2 matches serial ConfidenceHeadsV2. + + Covers: + * Forward pass parity for all logit outputs, aggregated metrics (pLDDT, PDE, PAE, + complex_plddt, complex_iplddt, complex_pde, complex_ipde). + * Backward pass parity for input gradients (s, z) and parameter gradients. + * Non-mutation of inputs during forward. + * Both the single-head (use_separate_heads=False) and the intra/inter-chain + separated head (use_separate_heads=True) configurations. + + PTM/iPTM keys are present in both outputs and are verified to match. + + Parameters + ---------- + setup_env : tuple + Grid group sizes, world size, device type, backend, environment variables per rank. + dtype : torch.dtype + Tensor dtype for forward and backward passes. + use_separate_heads : bool + Whether to use separate intra/inter-chain PAE and PDE projection heads. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, only {torch.cuda.device_count()} available") + + seed_by_rank(0) + + B = grid_group_sizes["dp"] + size_cp = grid_group_sizes["cp"][0] + multiplicity = 2 + + # N_tokens and N_atoms must be divisible by CP size for even sharding + N_tokens = 16 * size_cp + N_atoms = N_tokens * 4 + + selected_keys = [ + "mol_type", + "asym_id", + "token_pad_mask", + "frames_idx", + "atom_to_token", + "atom_resolved_mask", + "atom_pad_mask", + ] + + feats_host = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=1, + atom_counts_per_token_range=(1, 4), + device=torch.device("cpu"), + float_value_range=(-0.2, 0.2), + selected_keys=selected_keys, + ) + feats_host = {k: (v.to(dtype=dtype) if v.is_floating_point() else v) for k, v in feats_host.items()} + feats_device = {k: v.to(device=device_type) for k, v in feats_host.items()} + + boltz2_params = create_boltz2_model_init_params(use_large_model=True) + token_s = boltz2_params["token_s"] + token_z = boltz2_params["token_z"] + confidence_args = boltz2_params["confidence_model_args"]["confidence_args"] + num_distogram_bins = boltz2_params["confidence_model_args"]["num_dist_bins"] + val_init_range = 0.15 + + confidence_heads_kwargs = { + "token_s": token_s, + "token_z": token_z, + **confidence_args, + "token_level_confidence": True, + "use_separate_heads": use_separate_heads, + } + + # Input tensors: s and z have gradient tracking for backward verification. + s = torch.empty(B * multiplicity, N_tokens, token_s, device=device_type, dtype=dtype, requires_grad=True) + z = torch.empty(B * multiplicity, N_tokens, N_tokens, token_z, device=device_type, dtype=dtype, requires_grad=True) + x_pred = torch.empty(B * multiplicity, N_atoms, 3, device=device_type, dtype=dtype) + d = torch.empty(B * multiplicity, N_tokens, N_tokens, device=device_type, dtype=dtype) + pred_distogram_logits = torch.empty(B, N_tokens, N_tokens, num_distogram_bins, device=device_type, dtype=dtype) + + init_tensors_uniform([s, z, x_pred, d, pred_distogram_logits], low=-val_init_range, high=val_init_range) + + s_global_host = s.detach().clone().cpu() + z_global_host = z.detach().clone().cpu() + x_pred_global_host = x_pred.detach().clone().cpu() + d_global_host = d.detach().clone().cpu() + pred_distogram_logits_global_host = pred_distogram_logits.detach().clone().cpu() + + # ----- serial module ----- + serial_module = SerialConfidenceHeadsV2(**confidence_heads_kwargs) + serial_module = serial_module.to(device=device_type, dtype=dtype).train() + init_module_params_glorot(serial_module, gain=val_init_range) + + serial_state_dict = serial_module.state_dict() + + # ----- serial forward ----- + serial_output = serial_module( + s=s, + z=z, + x_pred=x_pred, + d=d, + feats=feats_device, + pred_distogram_logits=pred_distogram_logits, + multiplicity=multiplicity, + ) + + # Upstream gradients for backward + d_plddt_logits = torch.rand_like(serial_output["plddt_logits"], device=device_type) + d_pde_logits = torch.rand_like(serial_output["pde_logits"], device=device_type) + d_resolved_logits = torch.rand_like(serial_output["resolved_logits"], device=device_type) + d_pae_logits = torch.rand_like(serial_output["pae_logits"], device=device_type) + + d_plddt_logits_host = d_plddt_logits.detach().clone().cpu() + d_pde_logits_host = d_pde_logits.detach().clone().cpu() + d_resolved_logits_host = d_resolved_logits.detach().clone().cpu() + d_pae_logits_host = d_pae_logits.detach().clone().cpu() + + torch.autograd.backward( + [ + serial_output["plddt_logits"], + serial_output["pde_logits"], + serial_output["resolved_logits"], + serial_output["pae_logits"], + ], + [d_plddt_logits, d_pde_logits, d_resolved_logits, d_pae_logits], + ) + + # Save all serial outputs as CPU tensors; pair_chains_iptm handled separately + def _to_cpu(val): + if isinstance(val, torch.Tensor): + return val.detach().clone().cpu() + if isinstance(val, dict): + return {k: _to_cpu(v) for k, v in val.items()} + return val + + serial_output_feats_host = {k: _to_cpu(v) for k, v in serial_output.items()} + + # Verify the serial module has non-zero, non-NaN outputs (guard against vacuous pass) + assert not torch.isnan(serial_output["plddt_logits"]).any(), "serial plddt_logits contains NaN" + assert not torch.isnan(serial_output["pde_logits"]).any(), "serial pde_logits contains NaN" + assert not torch.isnan(serial_output["pae_logits"]).any(), "serial pae_logits contains NaN" + assert serial_output["plddt"].abs().max() > 0, "serial plddt is all-zero (vacuous)" + assert serial_output["complex_plddt"].abs().max() > 0, "serial complex_plddt is all-zero (vacuous)" + assert serial_output["ptm"].abs().max() > 0, "serial ptm is all-zero (vacuous)" + assert isinstance(serial_output["pair_chains_iptm"], dict), ( + f"serial pair_chains_iptm should be a dict (from compute_ptms), " + f"got {type(serial_output['pair_chains_iptm'])} — compute_ptms likely failed silently" + ) + + s_grad_host = s.grad.detach().clone().cpu() + z_grad_host = z.grad.detach().clone().cpu() + + # Verify non-zero gradients (guard against vacuous backward pass) + assert s_grad_host.abs().max() > 0, "serial s gradient is all-zero (vacuous)" + assert z_grad_host.abs().max() > 0, "serial z gradient is all-zero (vacuous)" + + serial_param_grads_host = { + name: param.grad.detach().clone().cpu() + for name, param in serial_module.named_parameters() + if param.grad is not None + } + assert len(serial_param_grads_host) > 0, "No serial parameter gradients found (vacuous)" + + # ----- parallel distributed test ----- + spawn_multiprocessing( + parallel_test_dtensor_confidence_heads_v2, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + serial_state_dict, + confidence_heads_kwargs, + multiplicity, + s_global_host, + z_global_host, + x_pred_global_host, + d_global_host, + feats_host, + pred_distogram_logits_global_host, + serial_output_feats_host, + s_grad_host, + z_grad_host, + d_plddt_logits_host, + d_pde_logits_host, + d_resolved_logits_host, + d_pae_logits_host, + serial_param_grads_host, + ) + + +# ============================================================================== +# ConfidenceModule v2 test +# ============================================================================== + + +def parallel_test_dtensor_confidence_module_v2( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + # config + dtype, + serial_state_dict, + confidence_module_kwargs, + multiplicity, + run_sequentially, + # input tensors + s_inputs_global_host, + s_global_host, + z_global_host, + x_pred_global_host, + feats_global_host, + pred_distogram_logits_global_host, + # reference serial outputs + serial_output_feats_host, + # reference serial input gradients + s_inputs_grad_host, + s_grad_host, + z_grad_host, + # upstream gradient tensors + d_plddt_logits_host, + d_pde_logits_host, + d_resolved_logits_host, + d_pae_logits_host, + # reference serial parameter gradients + serial_param_grads_host, +): + """Parallel worker: distributes inputs, runs DTensor ConfidenceModule v2, compares.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device_mesh = manager.device_mesh_subgroups + + seed_by_rank(0, 42) + + # Build serial module on device (for wrapping into distributed module) + serial_module = SerialConfidenceModuleV2(**confidence_module_kwargs) + serial_module = serial_module.to(device=manager.device, dtype=dtype).train() + serial_module.load_state_dict(serial_state_dict) + + cp_group = manager.group["cp"] + layout_group_cp = manager.layout_subgroups["cp"] + transpose_comm = TransposeComm(cp_group, layout_group_cp) + + module = DTensorConfidenceModuleV2( + module=serial_module, + dist_manager=manager, + transpose_comm=transpose_comm, + ) + module = module.to(device=manager.device, dtype=dtype).train() + + # ----- distribute inputs ----- + placements_single_repr = (Shard(0), Shard(1), Replicate()) + placements_pair_repr = (Shard(0), Shard(1), Shard(2)) + placements_cp_single_repr = (Shard(0), Replicate()) + + s_inputs_dtensor = distribute_tensor( + s_inputs_global_host.to(device=manager.device, dtype=dtype).requires_grad_(True), + device_mesh=device_mesh, + placements=placements_single_repr, + ) + s_dtensor = distribute_tensor( + s_global_host.to(device=manager.device, dtype=dtype).requires_grad_(True), + device_mesh=device_mesh, + placements=placements_single_repr, + ) + z_dtensor = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype).requires_grad_(True), + device_mesh=device_mesh, + placements=placements_pair_repr, + ) + + # Distribute atom-level features (token_to_rep_atom, x_pred, PTM-related keys) + special_atom_features = {"token_to_rep_atom", "atom_pad_mask", "atom_to_token", "atom_resolved_mask", "frames_idx"} + atom_inputs = { + key: feats_global_host[key].to(device=manager.device) + for key in special_atom_features + if key in feats_global_host + } + + base_batch = s_inputs_global_host.shape[0] + x_pred_reshaped = x_pred_global_host.reshape(base_batch, multiplicity, *x_pred_global_host.shape[1:]) + for mul_idx in range(multiplicity): + atom_inputs[f"x_pred_{mul_idx}"] = x_pred_reshaped[:, mul_idx].to(manager.device, dtype=dtype) + + atom_inputs["atom_counts_per_token"] = feats_global_host["atom_counts_per_token"].to( + manager.device, dtype=torch.int64 + ) + + atom_placements_cp = {key: placements_cp_single_repr for key in atom_inputs} + atom_placements_cp["frames_idx"] = (Shard(1), Replicate()) + atom_placements_dp_cp = {key: placements_single_repr for key in atom_inputs if key != "atom_counts_per_token"} + atom_placements_dp_cp["frames_idx"] = (Shard(0), Shard(1), Replicate()) + atom_feats_dtensor = distribute_atom_features( + inputs=atom_inputs, + placements_cp=atom_placements_cp, + placements_dp_cp=atom_placements_dp_cp, + device_mesh=device_mesh, + cp_group=manager.group["cp"], + multiplicities={"x_pred": multiplicity}, + ) + + x_pred_dtensor = atom_feats_dtensor["x_pred"] + + feature_placements = get_feature_placements() + + single_repr_keys = [ + "token_pad_mask", + "asym_id", + "mol_type", + "residue_index", + "entity_id", + "token_index", + "sym_id", + "cyclic_period", + ] + feats_dtensor = { + key: distribute_tensor( + feats_global_host[key].to(device=manager.device), + device_mesh=device_mesh, + placements=feature_placements["token_features"][key], + ) + for key in single_repr_keys + if key in feats_global_host + } + + pair_repr_keys = ["token_bonds", "type_bonds", "token_pair_pad_mask", "contact_conditioning", "contact_threshold"] + for key in pair_repr_keys: + if key in feats_global_host: + feats_dtensor[key] = distribute_tensor( + feats_global_host[key].to(device=manager.device), + device_mesh=device_mesh, + placements=feature_placements["token_features"][key], + ) + + feats_dtensor["token_to_rep_atom"] = atom_feats_dtensor["token_to_rep_atom"] + for key in special_atom_features - {"token_to_rep_atom"}: + if key in atom_feats_dtensor: + feats_dtensor[key] = atom_feats_dtensor[key] + + pred_distogram_logits_dtensor = distribute_tensor( + pred_distogram_logits_global_host.to(manager.device), + device_mesh=device_mesh, + placements=placements_pair_repr, + ) + + # Keep copies to verify inputs are not mutated + s_inputs_dtensor_copy = s_inputs_dtensor.clone() + s_dtensor_copy = s_dtensor.clone() + z_dtensor_copy = z_dtensor.clone() + + # ----- distribute upstream gradients ----- + d_plddt_logits_dtensor = distribute_tensor( + d_plddt_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=placements_single_repr, + ) + d_pde_logits_dtensor = distribute_tensor( + d_pde_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=placements_pair_repr, + ) + d_resolved_logits_dtensor = distribute_tensor( + d_resolved_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=placements_single_repr, + ) + d_pae_logits_dtensor = distribute_tensor( + d_pae_logits_host.to(device=manager.device, dtype=dtype), + device_mesh=device_mesh, + placements=placements_pair_repr, + ) + + # ----- forward ----- + output_dtensor = module( + s_inputs=s_inputs_dtensor, + s=s_dtensor, + z=z_dtensor, + x_pred=x_pred_dtensor, + feats=feats_dtensor, + pred_distogram_logits=pred_distogram_logits_dtensor, + multiplicity=multiplicity, + run_sequentially=run_sequentially, + ) + + # Compare all outputs against serial reference + dp_rank = manager.group_rank["dp"] + dp_size = len(manager.group_ranks["dp"]) + for key in output_dtensor: + assert key in serial_output_feats_host, f"DTensor output key '{key}' missing from serial reference" + dtensor_val = output_dtensor[key] + if key == "pair_chains_iptm": + serial_pciptm = serial_output_feats_host[key] + if isinstance(dtensor_val, dict): + for idx1, chain_dict in dtensor_val.items(): + for idx2, dt_val in chain_dict.items(): + local_val = dt_val.to_local().cpu() + is_sentinel = torch.all(local_val == CHAIN_IPTM_SENTINEL) + if idx1 in serial_pciptm and idx2 in serial_pciptm.get(idx1, {}): + serial_full = serial_pciptm[idx1][idx2] + chunk_size = serial_full.shape[0] // dp_size + serial_local = serial_full[dp_rank * chunk_size : (dp_rank + 1) * chunk_size] + if is_sentinel: + assert torch.all(serial_local.abs() < 1e-5), ( + f"Chain pair ({idx1}, {idx2}): distributed returned sentinel but " + f"serial has non-zero values {serial_local} for this DP rank" + ) + else: + torch.testing.assert_close( + local_val, + serial_local, + msg=f"Chain pair ({idx1}, {idx2}) mismatch on DP rank {dp_rank}", + ) + else: + assert is_sentinel, ( + f"Extra chain pair ({idx1}, {idx2}) should be sentinel " + f"{CHAIN_IPTM_SENTINEL}, got {local_val}" + ) + else: + assert isinstance(serial_pciptm, torch.Tensor), ( + f"pair_chains_iptm: distributed returned DTensor (compute_ptms fallback) " + f"but serial returned {type(serial_pciptm)}" + ) + torch.testing.assert_close( + dtensor_val.full_tensor().cpu(), + serial_pciptm, + msg="pair_chains_iptm fallback mismatch", + ) + continue + torch.testing.assert_close( + dtensor_val.full_tensor().cpu(), + serial_output_feats_host[key], + msg=f"Mismatch for output key '{key}'", + ) + + # ----- backward ----- + torch.autograd.backward( + [ + output_dtensor["plddt_logits"], + output_dtensor["pde_logits"], + output_dtensor["resolved_logits"], + output_dtensor["pae_logits"], + ], + [ + d_plddt_logits_dtensor, + d_pde_logits_dtensor, + d_resolved_logits_dtensor, + d_pae_logits_dtensor, + ], + ) + + # Verify inputs were not mutated by the forward pass + assert_tensors_identical( + s_inputs_dtensor.to_local().cpu(), + s_inputs_dtensor_copy.to_local().cpu(), + check_grad=False, + check_grad_fn=False, + msg="s_inputs_dtensor was mutated during forward", + ) + assert_tensors_identical( + s_dtensor.to_local().cpu(), + s_dtensor_copy.to_local().cpu(), + check_grad=False, + check_grad_fn=False, + msg="s_dtensor was mutated during forward", + ) + assert_tensors_identical( + z_dtensor.to_local().cpu(), + z_dtensor_copy.to_local().cpu(), + check_grad=False, + check_grad_fn=False, + msg="z_dtensor was mutated during forward", + ) + + # Compare input gradients + torch.testing.assert_close( + s_inputs_dtensor.grad.full_tensor().cpu(), s_inputs_grad_host, msg="s_inputs gradient mismatch" + ) + torch.testing.assert_close(s_dtensor.grad.full_tensor().cpu(), s_grad_host, msg="s gradient mismatch") + torch.testing.assert_close(z_dtensor.grad.full_tensor().cpu(), z_grad_host, msg="z gradient mismatch") + + # Compare parameter gradients + result_param_grads = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in serial_param_grads_host: + raise ValueError( + f"Parameter '{name}' has a gradient in the distributed module " f"but not in the serial reference" + ) + result_param_grads[name] = param.grad + + for name, expected in serial_param_grads_host.items(): + assert name in result_param_grads, f"Parameter '{name}' gradient missing in distributed module" + torch.testing.assert_close( + result_param_grads[name].full_tensor().cpu(), + expected, + msg=f"Parameter gradient mismatch for '{name}'", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + ("setup_env", "dtype", "use_separate_heads", "run_sequentially", "multiplicity"), + ( + params_test_module := [ + # multiplicity=2 exercises resolved_mask per-sample indexing paths + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, False, False, 2), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True, False, 2), + # multiplicity=2 + run_sequentially needs B=1 (serial constraint), so dp=1 + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64, False, True, 2), + ] + ), + indirect=["setup_env"], + ids=[ + ( + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, device:{x[0][2]}, init:{x[0][3]}, " + f"dtype:{x[1]}, separate_heads:{x[2]}, sequential:{x[3]}, mult:{x[4]}" + ) + for x in params_test_module + ], +) +def test_dtensor_confidence_module_v2(setup_env, dtype, use_separate_heads, run_sequentially, multiplicity): + """Test that DTensor ConfidenceModuleV2 matches serial ConfidenceModuleV2. + + Covers: + * Forward pass parity for all logit outputs and aggregated metrics (pLDDT, PDE, + PAE, complex_plddt, complex_iplddt, complex_pde, complex_ipde). + * Backward pass parity for input gradients (s_inputs, s, z) and parameter gradients. + * Non-mutation of inputs during forward. + * Both the single-head and the intra/inter-chain separated head configurations. + * Pairformer stack, distogram embedding, outer-sum s-to-z pair update. + * multiplicity=1 and multiplicity=2 (exercises resolved_mask per-sample indexing). + + PTM/iPTM and pair_chains_iptm outputs are compared between serial and distributed. + + Parameters + ---------- + setup_env : tuple + Grid group sizes, world size, device type, backend, environment variables per rank. + dtype : torch.dtype + Tensor dtype for forward and backward passes. + use_separate_heads : bool + Whether to use separate intra/inter-chain PAE and PDE projection heads. + multiplicity : int + Number of diffusion samples per batch element. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if dtype == torch.float32: + pytest.xfail("float32 dtype for logits has numerical stability issues and requires higher tolerances") + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, only {torch.cuda.device_count()} available") + + seed_by_rank(42) + + B = grid_group_sizes["dp"] + size_cp = grid_group_sizes["cp"][0] + + N_tokens = 16 * size_cp + N_atoms = N_tokens * 4 + + boltz2_params = create_boltz2_model_init_params(use_large_model=True) + token_s = boltz2_params["token_s"] + token_z = boltz2_params["token_z"] + num_distogram_bins = boltz2_params["confidence_model_args"]["num_dist_bins"] + + confidence_model_args = boltz2_params["confidence_model_args"].copy() + confidence_model_args["confidence_args"] = { + **confidence_model_args["confidence_args"], + "use_separate_heads": use_separate_heads, + } + confidence_module_kwargs = { + "token_s": token_s, + "token_z": token_z, + "pairformer_args": boltz2_params["pairformer_args"], + "token_level_confidence": True, + "bond_type_feature": boltz2_params["bond_type_feature"], + **confidence_model_args, + } + + selected_keys = [ + "mol_type", + "asym_id", + "token_pad_mask", + "token_pair_pad_mask", + "token_to_rep_atom", + "atom_counts_per_token", + "frames_idx", + "atom_to_token", + "atom_pad_mask", + "atom_resolved_mask", + "residue_index", + "entity_id", + "token_index", + "sym_id", + "cyclic_period", + "token_bonds", + "type_bonds", + "contact_conditioning", + "contact_threshold", + ] + + feats_host = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=1, + atom_counts_per_token_range=(1, 4), + device=torch.device("cpu"), + float_value_range=(-0.2, 0.2), + selected_keys=selected_keys, + ) + feats_host = {k: (v.to(dtype=dtype) if v.is_floating_point() else v) for k, v in feats_host.items()} + feats_device = {k: v.to(device=device_type) for k, v in feats_host.items()} + + val_init_range = 0.002 + + # Input tensors + s_inputs = torch.empty(B, N_tokens, token_s, device=device_type, dtype=dtype, requires_grad=True) + s = torch.empty(B, N_tokens, token_s, device=device_type, dtype=dtype, requires_grad=True) + z = torch.empty(B, N_tokens, N_tokens, token_z, device=device_type, dtype=dtype, requires_grad=True) + x_pred = torch.empty(B * multiplicity, N_atoms, 3, device=device_type, dtype=dtype) + pred_distogram_logits = torch.empty(B, N_tokens, N_tokens, num_distogram_bins, device=device_type, dtype=dtype) + + init_tensors_uniform([s_inputs, s, z, pred_distogram_logits], low=-val_init_range, high=val_init_range) + # x_pred needs a wider range so that inter-atom distances exceed + # the collinear-mask overlap threshold (0.01) in compute_ptms. + init_tensors_uniform([x_pred], low=-10.0, high=10.0) + + s_inputs_global_host = s_inputs.detach().clone().cpu() + s_global_host = s.detach().clone().cpu() + z_global_host = z.detach().clone().cpu() + x_pred_global_host = x_pred.detach().clone().cpu() + pred_distogram_logits_global_host = pred_distogram_logits.detach().clone().cpu() + + # ----- serial module ----- + serial_module = SerialConfidenceModuleV2(**confidence_module_kwargs) + serial_module = serial_module.to(device=device_type, dtype=dtype).train() + init_module_params_glorot(serial_module) + + serial_state_dict = serial_module.state_dict() + + # ----- serial forward ----- + serial_output = serial_module( + s_inputs=s_inputs, + s=s, + z=z, + x_pred=x_pred, + feats=feats_device, + pred_distogram_logits=pred_distogram_logits, + multiplicity=multiplicity, + run_sequentially=run_sequentially, + ) + + # Upstream gradients for backward + d_plddt_logits = torch.rand_like(serial_output["plddt_logits"], device=device_type) + d_pde_logits = torch.rand_like(serial_output["pde_logits"], device=device_type) + d_resolved_logits = torch.rand_like(serial_output["resolved_logits"], device=device_type) + d_pae_logits = torch.rand_like(serial_output["pae_logits"], device=device_type) + + d_plddt_logits_host = d_plddt_logits.detach().clone().cpu() + d_pde_logits_host = d_pde_logits.detach().clone().cpu() + d_resolved_logits_host = d_resolved_logits.detach().clone().cpu() + d_pae_logits_host = d_pae_logits.detach().clone().cpu() + + torch.autograd.backward( + [ + serial_output["plddt_logits"], + serial_output["pde_logits"], + serial_output["resolved_logits"], + serial_output["pae_logits"], + ], + [d_plddt_logits, d_pde_logits, d_resolved_logits, d_pae_logits], + ) + + def _to_cpu(val): + if isinstance(val, torch.Tensor): + return val.detach().clone().cpu() + if isinstance(val, dict): + return {k: _to_cpu(v) for k, v in val.items()} + return val + + serial_output_feats_host = {k: _to_cpu(v) for k, v in serial_output.items()} + + # Verify the serial module has non-zero, non-NaN outputs (guard against vacuous pass) + assert not torch.isnan(serial_output["plddt_logits"]).any(), "serial plddt_logits contains NaN" + assert not torch.isnan(serial_output["pde_logits"]).any(), "serial pde_logits contains NaN" + assert not torch.isnan(serial_output["pae_logits"]).any(), "serial pae_logits contains NaN" + assert serial_output["plddt"].abs().max() > 0, "serial plddt is all-zero (vacuous)" + assert serial_output["complex_plddt"].abs().max() > 0, "serial complex_plddt is all-zero (vacuous)" + assert serial_output["ptm"].abs().max() > 0, "serial ptm is all-zero (vacuous)" + assert isinstance(serial_output["pair_chains_iptm"], dict), ( + f"serial pair_chains_iptm should be a dict (from compute_ptms), " + f"got {type(serial_output['pair_chains_iptm'])} — compute_ptms likely failed silently" + ) + + s_inputs_grad_host = s_inputs.grad.detach().clone().cpu() + s_grad_host = s.grad.detach().clone().cpu() + z_grad_host = z.grad.detach().clone().cpu() + + # Verify non-zero gradients (guard against vacuous backward pass) + assert s_inputs_grad_host.abs().max() > 0, "serial s_inputs gradient is all-zero (vacuous)" + assert s_grad_host.abs().max() > 0, "serial s gradient is all-zero (vacuous)" + assert z_grad_host.abs().max() > 0, "serial z gradient is all-zero (vacuous)" + + serial_param_grads_host = { + name: param.grad.detach().clone().cpu() + for name, param in serial_module.named_parameters() + if param.grad is not None + } + assert len(serial_param_grads_host) > 0, "No serial parameter gradients found (vacuous)" + + # ----- parallel distributed test ----- + spawn_multiprocessing( + parallel_test_dtensor_confidence_module_v2, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + serial_state_dict, + confidence_module_kwargs, + multiplicity, + run_sequentially, + s_inputs_global_host, + s_global_host, + z_global_host, + x_pred_global_host, + feats_host, + pred_distogram_logits_global_host, + serial_output_feats_host, + s_inputs_grad_host, + s_grad_host, + z_grad_host, + d_plddt_logits_host, + d_pde_logits_host, + d_resolved_logits_host, + d_pae_logits_host, + serial_param_grads_host, + ) diff --git a/tests/distributed/model/modules/test_dtensor_diffusion.py b/tests/distributed/model/modules/test_dtensor_diffusion.py new file mode 100644 index 000000000..9961a7899 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_diffusion.py @@ -0,0 +1,2368 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor AtomDiffusion with window batching. + +Tests the DTensor AtomDiffusion (preconditioned_network_forward, training +forward) against V1 and V2 serial references, verifying forward and backward +numerical equivalence. + +Parametrized on ``internalized_conditioning``: +- True (V1): module owns pairwise_conditioner, forward takes z_trunk + relative_position_encoding +- False (V2): forward takes pre-computed diffusion_conditioning dict + +Both V1 and V2 use float64 with default tolerance for exact comparison. + +Adapted from Boltz-1x CP tests (test_dtensor_diffusion.py and +test_dtensor_diffusion_precond.py). +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +import boltz.distributed.model.modules.diffusion as distributed_diffusion_module +import boltz.model.modules.diffusion as serial_diffusion_v1_module +import boltz.model.modules.diffusionv2 as serial_diffusion_v2_module +from boltz.data import const as boltz_const +from boltz.distributed.comm import AttentionPairBiasComm, TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.elementwise_op import ElementwiseOp, scalar_tensor_op +from boltz.distributed.model.modules.diffusion import AtomDiffusion as DistributedAtomDiffusion +from boltz.distributed.model.modules.diffusion_conditioning import ( + DiffusionConditioning as DistributedDiffusionConditioning, +) +from boltz.model.modules.diffusion import AtomDiffusion as SerialAtomDiffusionV1 +from boltz.model.modules.diffusion_conditioning import DiffusionConditioning as SerialDiffusionConditioning +from boltz.model.modules.diffusionv2 import AtomDiffusion as SerialAtomDiffusionV2 +from boltz.testing.utils import ( + SetModuleInfValues, + assert_tensors_identical, + distribute_atom_features, + get_feature_placements, + get_param_by_key, + init_module_params_glorot, + init_module_params_uniform, + init_tensors_uniform, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +# Atom features needed +_selected_atom_keys = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + "atom_counts_per_token", + "atom_resolved_mask", + "plddt", +} +# Token features needed +_selected_token_keys = {"token_pad_mask"} + +_selected_model_io_keys = { + "r_noisy_expected", + "r_update_expected", + "d_r_update_expected", + "d_r_noisy_expected", + "noise", + "denoised_atom_coords", + "d_denoised_atom_coords", + "aligned_true_atom_coords", +} + +_placements = get_feature_placements( + token_keys={"mol_type"}, + msa_keys=set(), + atom_keys=_selected_atom_keys, + model_io_keys=_selected_model_io_keys, + model_io_fp32_keys=set(), +) +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_atom_features = _placements["atom_features"] +_placements_model_io = _placements["model_io"] +_placements_cp_model_io = _placements["cp_model_io"] + +# ====================================================================== +# Test 1: preconditioned_network_forward +# ====================================================================== + + +def parallel_assert_precond_forward( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + internalized_conditioning: bool, + dtype: torch.dtype, + multiplicity: int, + # Module states + atom_diffusion_state_dict, + conditioning_state_dict, # None for internalized + score_model_kwargs: dict, + conditioning_kwargs: dict | None, # None for internalized + W: int, + H: int, + # Inputs + feats_global_host: dict[str, torch.Tensor], + s_inputs_global_host: torch.Tensor, + s_trunk_global_host: torch.Tensor, + z_trunk_global_host: torch.Tensor, + rel_pos_enc_global_host: torch.Tensor, + noised_atom_coords_global_host: torch.Tensor, + sigma_global_host: torch.Tensor, + # Expected outputs + denoised_coords_expected_global_host: torch.Tensor, + # Upstream grad + d_denoised_coords_global_host: torch.Tensor, + # Expected input grads + d_s_inputs_expected_global_host: torch.Tensor, + d_s_trunk_expected_global_host: torch.Tensor, + d_noised_coords_expected_global_host: torch.Tensor, + # Expected param grads + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + """Parallel assertion for AtomDiffusion.preconditioned_network_forward().""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Re-create serial module to initialize DTensor version + if internalized_conditioning: + serial_atom_diffusion = SerialAtomDiffusionV1(score_model_args=score_model_kwargs) + else: + serial_atom_diffusion = SerialAtomDiffusionV2(score_model_args=score_model_kwargs) + serial_atom_diffusion = serial_atom_diffusion.to(device=manager.device, dtype=dtype) + serial_atom_diffusion.load_state_dict(atom_diffusion_state_dict) + serial_atom_diffusion = serial_atom_diffusion.train() + + # Create ring_comm for token-level transformer + ring_comm = AttentionPairBiasComm( + manager.group["cp"], + manager.layout_subgroups["cp"], + manager.subgroups["cp"][0], + manager.subgroups["cp"][1], + ) + + # Create DTensor module + module = DistributedAtomDiffusion( + layer=serial_atom_diffusion, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + ).train() + + # ------------------------------------------------------------------ + # Distribute token-level tensors (common to both paths) + # ------------------------------------------------------------------ + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + + token_pad_mask_dt = distribute_tensor( + feats_global_host["token_pad_mask"].to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + + s_inputs_dt = distribute_tensor( + s_inputs_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + + placements_scalar = (Shard(0), Replicate(), Replicate()) + sigma_dt = distribute_tensor( + sigma_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_scalar, + ) + + z_trunk_device = z_trunk_global_host.to(device=manager.device, dtype=dtype) + rel_pos_enc_device = rel_pos_enc_global_host.to(device=manager.device, dtype=dtype) + + # ------------------------------------------------------------------ + # Distribute atom features, atom-level I/O, and conditioning + # ------------------------------------------------------------------ + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + + # All atom-level I/O must share the same intersperse-padded atom ordering. + # Both V1 and V2 pass unpacked feats — DiffusionModule packs internally. + size_batch = feats_global_host["atom_pad_mask"].shape[0] + io_tensors = { + "r_noisy_expected": noised_atom_coords_global_host, + "r_update_expected": denoised_coords_expected_global_host, + "d_r_update_expected": d_denoised_coords_global_host, + "d_r_noisy_expected": d_noised_coords_expected_global_host, + } + for base_name, tensor_host in io_tensors.items(): + unflat = tensor_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"{base_name}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + io_keys_used = set(io_tensors.keys()) + placements_cp_io_mul = { + f"{k}_{i_mul}": _placements_cp_model_io[k] for k in io_keys_used for i_mul in range(multiplicity) + } + placements_io_mul = {f"{k}_{i_mul}": _placements_model_io[k] for k in io_keys_used for i_mul in range(multiplicity)} + multiplicities = dict.fromkeys(io_keys_used, multiplicity) + + feats_and_io = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features | placements_cp_io_mul, + _placements_atom_features | placements_io_mul, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities=multiplicities, + ) + noised_atom_coords_dt = feats_and_io.pop("r_noisy_expected").requires_grad_(True) + denoised_expected_dt = feats_and_io.pop("r_update_expected") + d_denoised_expected_dt = feats_and_io.pop("d_r_update_expected") + d_noised_expected_dt = feats_and_io.pop("d_r_noisy_expected") + feats_dt = feats_and_io + feats_dt["token_pad_mask"] = token_pad_mask_dt + + z_trunk_dt = distribute_tensor(z_trunk_device, manager.device_mesh_subgroups, placements_pair) + rel_pos_enc_dt = distribute_tensor(rel_pos_enc_device, manager.device_mesh_subgroups, placements_pair) + + # ------------------------------------------------------------------ + # Build network_condition_kwargs (depends on internalized_conditioning) + # ------------------------------------------------------------------ + if internalized_conditioning: + # V1: pass z_trunk and rel_pos_enc directly + network_condition_kwargs = { + "s_inputs": s_inputs_dt, + "s_trunk": s_trunk_dt, + "feats": feats_dt, + "multiplicity": multiplicity, + "z_trunk": z_trunk_dt, + "relative_position_encoding": rel_pos_enc_dt, + } + else: + # V2: DTensor DiffusionConditioning takes unpacked feats (packs internally) + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs) + serial_conditioning = serial_conditioning.to(device=manager.device, dtype=dtype) + serial_conditioning.load_state_dict(conditioning_state_dict) + serial_conditioning = serial_conditioning.eval() + + dtensor_conditioning = DistributedDiffusionConditioning( + layer=serial_conditioning, + device_mesh=manager.device_mesh_subgroups, + ).eval() + + with torch.no_grad(): + q_cond_dt, c_cond_dt, atom_enc_bias_dt, atom_dec_bias_dt, token_trans_bias_dt = dtensor_conditioning( + s_trunk=s_trunk_dt.detach(), + z_trunk=z_trunk_dt, + relative_position_encoding=rel_pos_enc_dt, + feats=feats_dt, + ) + + diff_cond_dt = { + "q": q_cond_dt, + "c": c_cond_dt, + "atom_enc_bias": atom_enc_bias_dt, + "atom_dec_bias": atom_dec_bias_dt, + "token_trans_bias": token_trans_bias_dt, + } + + network_condition_kwargs = { + "s_inputs": s_inputs_dt, + "s_trunk": s_trunk_dt, + "feats": feats_dt, + "multiplicity": multiplicity, + "diffusion_conditioning": diff_cond_dt, + } + + # ------------------------------------------------------------------ + # Forward pass: preconditioned_network_forward + # ------------------------------------------------------------------ + precond_result = module.preconditioned_network_forward( + noised_atom_coords_dt, + sigma_dt, + network_condition_kwargs, + ) + + if internalized_conditioning: + denoised_coords_dt = precond_result[0] + else: + denoised_coords_dt = precond_result + + # ------------------------------------------------------------------ + # Forward comparison (both V1 and V2 outputs are in intersperse-padded layout) + # ------------------------------------------------------------------ + torch.testing.assert_close(denoised_coords_dt.full_tensor(), denoised_expected_dt.full_tensor()) + + # ------------------------------------------------------------------ + # Backward pass (upstream grad in intersperse-padded layout) + # ------------------------------------------------------------------ + denoised_coords_dt.backward(d_denoised_expected_dt) + + # Check input gradients (token-level) + torch.testing.assert_close( + s_inputs_dt.grad.full_tensor(), + d_s_inputs_expected_global_host.to(device=manager.device, dtype=dtype), + ) + torch.testing.assert_close( + s_trunk_dt.grad.full_tensor(), + d_s_trunk_expected_global_host.to(device=manager.device, dtype=dtype), + ) + + # noised_atom_coords grad (atom-level, intersperse-padded layout) + torch.testing.assert_close(noised_atom_coords_dt.grad.full_tensor(), d_noised_expected_dt.full_tensor()) + + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + if grad_param is None: + continue + if hasattr(grad_param, "full_tensor"): + grad_global_host = grad_param.full_tensor().cpu() + else: + grad_global_host = grad_param.detach().cpu() + torch.testing.assert_close( + grad_global_host, + grad_expected_global.to(dtype=dtype), + msg=lambda m: f"Parameter gradient mismatch for {name}: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 4], ids=lambda x: f"mul:{x}") +@pytest.mark.parametrize("internalized_conditioning", [False, True], ids=["extern", "intern"]) +def test_preconditioned_network_forward(setup_env, multiplicity, internalized_conditioning: bool): + """Test AtomDiffusion.preconditioned_network_forward() with DTensor CP. + + Tests the core preconditioned forward pass which computes: + denoised = c_skip(sigma) * x + c_out(sigma) * score_model(c_in(sigma) * x, c_noise(sigma)) + + Parametrized on ``internalized_conditioning``: + - False (V2 / externalized): uses DiffusionConditioning to pre-compute q/c/bias + - True (V1 / internalized): passes z_trunk + relative_position_encoding directly + + Uses float64 for exact comparison. Verifies forward and backward numerical + equivalence against serial reference. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + dtype = torch.float64 + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + W = 32 + H = 128 + val_init_min_max = (-0.5, 0.5) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + token_s = 4 + token_z = 4 + atom_z = 8 + + atom_encoder_depth = 2 + atom_encoder_heads = 2 + token_transformer_depth = 2 + token_transformer_heads = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + conditioning_transition_layers = 1 + + # V1 includes atom_pad_mask in atom features; V2 does not + atom_feature_dim = 3 + 1 + (1 if internalized_conditioning else 0) + boltz_const.num_elements + 4 * 64 + + selected_keys = list(_selected_atom_keys | _selected_token_keys) + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + # V1 s_inputs has wider dim: + # input_dim = 2 * token_s + 2 * num_tokens + 1 + len(pocket_contact_info) + # s_inputs_dim = input_dim - token_s + # V2 s_inputs has token_s dim + if internalized_conditioning: + v1_input_dim = 2 * token_s + 2 * boltz_const.num_tokens + 1 + len(boltz_const.pocket_contact_info) + s_inputs_dim = v1_input_dim - token_s + else: + s_inputs_dim = token_s + + s_inputs = torch.empty((B, N_tokens, s_inputs_dim), device=device_type, dtype=dtype, requires_grad=True) + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z_trunk = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + rel_pos_enc = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + init_tensors_uniform([s_inputs, s_trunk, z_trunk, rel_pos_enc], low=val_init_min_max[0], high=val_init_min_max[1]) + + # noised_atom_coords: (B*M, N_atoms_actual, 3) + N_atoms_actual = feats["atom_pad_mask"].shape[1] + noised_atom_coords = torch.empty( + (B * multiplicity, N_atoms_actual, 3), device=device_type, dtype=dtype, requires_grad=True + ) + init_tensors_uniform([noised_atom_coords], low=val_init_min_max[0], high=val_init_min_max[1]) + + # ------------------------------------------------------------------ + # Build serial modules (depends on internalized_conditioning) + # ------------------------------------------------------------------ + if internalized_conditioning: + # V1: module owns pairwise_conditioner, no external conditioning + score_model_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + conditioning_kwargs = None + conditioning_state_dict = None + + serial_atom_diffusion = SerialAtomDiffusionV1( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ).to(device=device_type, dtype=dtype) + else: + # V2: uses DiffusionConditioning + score_model_kwargs = { + "token_s": token_s, + "atom_s": atom_s, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "conditioning_transition_layers": conditioning_transition_layers, + } + conditioning_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs).to(device=device_type, dtype=dtype) + serial_conditioning.train() + init_module_params_glorot(serial_conditioning, gain=0.5) + serial_conditioning.apply(SetModuleInfValues()) + conditioning_state_dict = serial_conditioning.state_dict() + + serial_atom_diffusion = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ).to(device=device_type, dtype=dtype) + + serial_atom_diffusion.train() + init_module_params_glorot(serial_atom_diffusion, gain=0.5) + serial_atom_diffusion.apply(SetModuleInfValues()) + atom_diffusion_state_dict = serial_atom_diffusion.state_dict() + + # Generate sigma from the module's noise distribution (deterministic via seed_by_rank) + sigma = serial_atom_diffusion.noise_distribution(B * multiplicity).to(device=device_type, dtype=dtype) + + # ------------------------------------------------------------------ + # Serial preconditioned_network_forward + # ------------------------------------------------------------------ + s_inputs_serial = s_inputs.detach().clone().requires_grad_(True) + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) + noised_coords_serial = noised_atom_coords.detach().clone().requires_grad_(True) + + if internalized_conditioning: + # V1: pass z_trunk + relative_position_encoding directly + serial_result = serial_atom_diffusion.preconditioned_network_forward( + noised_coords_serial, + sigma, + network_condition_kwargs={ + "s_inputs": s_inputs_serial, + "s_trunk": s_trunk_serial, + "z_trunk": z_trunk.detach(), + "relative_position_encoding": rel_pos_enc.detach(), + "feats": {k: v.detach().clone() for k, v in feats.items()}, + "multiplicity": multiplicity, + }, + ) + denoised_serial = serial_result[0] # (denoised_coords, token_a) + else: + # V2: pre-compute conditioning + with torch.no_grad(): + q_cond, c_cond, to_keys, atom_enc_bias_cond, atom_dec_bias_cond, token_trans_bias_cond = ( + serial_conditioning( + s_trunk=s_trunk.detach(), + z_trunk=z_trunk.detach(), + relative_position_encoding=rel_pos_enc.detach(), + feats={k: v.detach() for k, v in feats.items()}, + ) + ) + + diff_cond_serial = { + "q": q_cond.detach(), + "c": c_cond.detach(), + "to_keys": to_keys, + "atom_enc_bias": atom_enc_bias_cond.detach(), + "atom_dec_bias": atom_dec_bias_cond.detach(), + "token_trans_bias": token_trans_bias_cond.detach(), + } + + denoised_serial = serial_atom_diffusion.preconditioned_network_forward( + noised_coords_serial, + sigma, + network_condition_kwargs={ + "s_inputs": s_inputs_serial, + "s_trunk": s_trunk_serial, + "feats": {k: v.detach().clone() for k, v in feats.items()}, + "multiplicity": multiplicity, + "diffusion_conditioning": diff_cond_serial, + }, + ) + + # Upstream gradient + d_denoised = torch.empty_like(denoised_serial) + init_tensors_uniform([d_denoised], low=val_init_min_max[0], high=val_init_min_max[1]) + atom_mask_mul = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0).unsqueeze(-1) + d_denoised = d_denoised * atom_mask_mul + + denoised_serial.backward(d_denoised) + + expected_param_grads = { + name: param.grad.detach().cpu() + for name, param in serial_atom_diffusion.named_parameters() + if param.requires_grad and param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_precond_forward, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + internalized_conditioning, + dtype, + multiplicity, + {k: v.detach().cpu() for k, v in atom_diffusion_state_dict.items()}, + {k: v.detach().cpu() for k, v in conditioning_state_dict.items()} if conditioning_state_dict else None, + score_model_kwargs, + conditioning_kwargs, + W, + H, + {k: v.detach().cpu() for k, v in feats.items()}, + s_inputs.detach().cpu(), + s_trunk.detach().cpu(), + z_trunk.detach().cpu(), + rel_pos_enc.detach().cpu(), + noised_atom_coords.detach().cpu(), + sigma.detach().cpu(), + denoised_serial.detach().cpu(), + d_denoised.detach().cpu(), + s_inputs_serial.grad.detach().cpu(), + s_trunk_serial.grad.detach().cpu(), + noised_coords_serial.grad.detach().cpu(), + expected_param_grads, + ) + + +# ====================================================================== +# Test 2: AtomDiffusion.forward() (training forward) +# ====================================================================== + + +def parallel_assert_atom_diffusion_forward( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + internalized_conditioning: bool, + dtype: torch.dtype, + multiplicity: int, + # Module states + atom_diffusion_state_dict, + conditioning_state_dict, # None for internalized + score_model_kwargs: dict, + conditioning_kwargs: dict | None, # None for internalized + W: int, + H: int, + # Inputs + feats_global_host: dict[str, torch.Tensor], + s_inputs_global_host: torch.Tensor, + s_trunk_global_host: torch.Tensor, + z_trunk_global_host: torch.Tensor, + rel_pos_enc_global_host: torch.Tensor, + coords_global_host: torch.Tensor, + sigmas_global_host: torch.Tensor, + noise_global_host: torch.Tensor, + # Expected outputs + denoised_expected_global_host: torch.Tensor, + noised_expected_global_host: torch.Tensor | None, # V1 only + # Upstream grad + d_denoised_global_host: torch.Tensor, + # Expected input grads + d_s_inputs_expected_global_host: torch.Tensor, + d_s_trunk_expected_global_host: torch.Tensor, + # Expected param grads + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + """Parallel assertion for AtomDiffusion.forward() (training forward).""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Re-create serial module to initialize DTensor version + if internalized_conditioning: + serial_atom_diffusion = SerialAtomDiffusionV1( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ) + else: + serial_atom_diffusion = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ) + serial_atom_diffusion = serial_atom_diffusion.to(device=manager.device, dtype=dtype) + serial_atom_diffusion.load_state_dict(atom_diffusion_state_dict) + serial_atom_diffusion = serial_atom_diffusion.train() + + # Create ring_comm for token-level transformer + ring_comm = AttentionPairBiasComm( + manager.group["cp"], + manager.layout_subgroups["cp"], + manager.subgroups["cp"][0], + manager.subgroups["cp"][1], + ) + + # Create DTensor module + module = DistributedAtomDiffusion( + layer=serial_atom_diffusion, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + ).train() + + # ------------------------------------------------------------------ + # Distribute token-level tensors (common to both paths) + # ------------------------------------------------------------------ + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + placements_scalar = (Shard(0), Replicate(), Replicate()) + + token_pad_mask_dt = distribute_tensor( + feats_global_host["token_pad_mask"].to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + + s_inputs_dt = distribute_tensor( + s_inputs_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + + sigmas_dt = distribute_tensor( + sigmas_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_scalar, + ) + + z_trunk_device = z_trunk_global_host.to(device=manager.device, dtype=dtype) + rel_pos_enc_device = rel_pos_enc_global_host.to(device=manager.device, dtype=dtype) + + # ------------------------------------------------------------------ + # Distribute atom features, coords, noise, and conditioning + # ------------------------------------------------------------------ + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + + # All atom-level I/O must share the same intersperse-padded atom ordering. + # Both V1 and V2 pass unpacked feats — DiffusionModule packs internally. + size_batch = feats_global_host["atom_pad_mask"].shape[0] + io_tensors = { + "noise": noise_global_host, + "denoised_atom_coords": coords_global_host, + "d_denoised_atom_coords": d_denoised_global_host, + } + for base_name, tensor_host in io_tensors.items(): + unflat = tensor_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"{base_name}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + io_keys_used = set(io_tensors.keys()) + placements_cp_io_mul = { + f"{k}_{i_mul}": _placements_cp_model_io[k] for k in io_keys_used for i_mul in range(multiplicity) + } + placements_io_mul = {f"{k}_{i_mul}": _placements_model_io[k] for k in io_keys_used for i_mul in range(multiplicity)} + multiplicities = dict.fromkeys(io_keys_used, multiplicity) + + feats_and_io = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features | placements_cp_io_mul, + _placements_atom_features | placements_io_mul, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities=multiplicities, + ) + noise_dt = feats_and_io.pop("noise") + coords_dt = feats_and_io.pop("denoised_atom_coords") + d_denoised_expected_dt = feats_and_io.pop("d_denoised_atom_coords") + feats_dt = feats_and_io + feats_dt["token_pad_mask"] = token_pad_mask_dt + feats_dt["coords"] = coords_dt + + z_trunk_dt = distribute_tensor(z_trunk_device, manager.device_mesh_subgroups, placements_pair) + rel_pos_enc_dt = distribute_tensor(rel_pos_enc_device, manager.device_mesh_subgroups, placements_pair) + + # ------------------------------------------------------------------ + # Forward pass (with monkeypatched noise_distribution and create_distributed_randn) + # ------------------------------------------------------------------ + monkeypatch.setattr(module, "noise_distribution", lambda bs, dtype=None: sigmas_dt) + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", lambda *a, **kw: noise_dt) + + if internalized_conditioning: + # V1: pass z_trunk and rel_pos_enc directly + out_dt = module( + s_inputs=s_inputs_dt, + s_trunk=s_trunk_dt, + feats=feats_dt, + z_trunk=z_trunk_dt, + relative_position_encoding=rel_pos_enc_dt, + multiplicity=multiplicity, + ) + else: + # V2: DTensor DiffusionConditioning takes unpacked feats (packs internally) + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs) + serial_conditioning = serial_conditioning.to(device=manager.device, dtype=dtype) + serial_conditioning.load_state_dict(conditioning_state_dict) + serial_conditioning = serial_conditioning.eval() + + dtensor_conditioning = DistributedDiffusionConditioning( + layer=serial_conditioning, + device_mesh=manager.device_mesh_subgroups, + ).eval() + + with torch.no_grad(): + q_cond_dt, c_cond_dt, atom_enc_bias_dt, atom_dec_bias_dt, token_trans_bias_dt = dtensor_conditioning( + s_trunk=s_trunk_dt.detach(), + z_trunk=z_trunk_dt, + relative_position_encoding=rel_pos_enc_dt, + feats=feats_dt, + ) + + diff_cond_dt = { + "q": q_cond_dt, + "c": c_cond_dt, + "atom_enc_bias": atom_enc_bias_dt, + "atom_dec_bias": atom_dec_bias_dt, + "token_trans_bias": token_trans_bias_dt, + } + + out_dt = module( + s_inputs=s_inputs_dt, + s_trunk=s_trunk_dt, + feats=feats_dt, + diffusion_conditioning=diff_cond_dt, + multiplicity=multiplicity, + ) + + denoised_dt = out_dt["denoised_atom_coords"] + + # ------------------------------------------------------------------ + # Forward comparison: noised_atom_coords (V1 only) + # Extract real atoms via boolean mask from both layouts and compare. + # denoised_atom_coords comparison is skipped because center_random_augmentation + # processing cannot be trivially aligned between layouts; correctness is + # verified through the backward pass instead. + # ------------------------------------------------------------------ + if internalized_conditioning and noised_expected_global_host is not None: + noised_dt = out_dt["noised_atom_coords"] + noised_expected = noised_expected_global_host.to(device=manager.device, dtype=dtype) + dt_mask = feats_dt["atom_pad_mask"].full_tensor().repeat_interleave(multiplicity, 0).bool() + serial_mask = ( + feats_global_host["atom_pad_mask"].to(device=manager.device).repeat_interleave(multiplicity, 0).bool() + ) + torch.testing.assert_close(noised_dt.full_tensor()[dt_mask], noised_expected[serial_mask]) + + # ------------------------------------------------------------------ + # Backward pass (upstream grad in intersperse-padded layout) + # ------------------------------------------------------------------ + denoised_dt.backward(d_denoised_expected_dt) + + # Check input gradients (token-level) + torch.testing.assert_close( + s_inputs_dt.grad.full_tensor(), + d_s_inputs_expected_global_host.to(device=manager.device, dtype=dtype), + ) + torch.testing.assert_close( + s_trunk_dt.grad.full_tensor(), + d_s_trunk_expected_global_host.to(device=manager.device, dtype=dtype), + ) + + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + if grad_param is None: + continue + if hasattr(grad_param, "full_tensor"): + grad_global_host = grad_param.full_tensor().cpu() + else: + grad_global_host = grad_param.detach().cpu() + torch.testing.assert_close( + grad_global_host, + grad_expected_global.to(dtype=dtype), + msg=lambda m: f"Parameter gradient mismatch for {name}: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 4], ids=lambda x: f"mul:{x}") +@pytest.mark.parametrize("internalized_conditioning", [False, True], ids=["extern", "intern"]) +def test_atom_diffusion_forward(setup_env, multiplicity, internalized_conditioning: bool): + """Test AtomDiffusion.forward() (training forward) with DTensor CP. + + Tests the full training forward which includes: + - center_random_augmentation (with augmentation=False for determinism) + - Noise generation (provided externally for determinism) + - Sigma scheduling (provided externally for determinism) + - Preconditioned network forward + + Parametrized on ``internalized_conditioning``: + - False (V2 / externalized): uses DiffusionConditioning + - True (V1 / internalized): passes z_trunk + relative_position_encoding + + Uses float64 for exact comparison. Uses pre-generated sigmas and noise for deterministic serial vs DTensor + comparison of forward and backward numerical equivalence. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + dtype = torch.float64 + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + W = 32 + H = 128 + val_init_min_max = (-0.2, 0.2) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + token_s = 4 + token_z = 4 + atom_z = 8 + + atom_encoder_depth = 2 + atom_encoder_heads = 2 + token_transformer_depth = 2 + token_transformer_heads = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + conditioning_transition_layers = 1 + + # V1 includes atom_pad_mask in atom features; V2 does not + atom_feature_dim = 3 + 1 + (1 if internalized_conditioning else 0) + boltz_const.num_elements + 4 * 64 + + selected_keys = list(_selected_atom_keys | _selected_token_keys) + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + N_atoms_actual = feats["atom_pad_mask"].shape[1] + + # V1 s_inputs has wider dim + if internalized_conditioning: + v1_input_dim = 2 * token_s + 2 * boltz_const.num_tokens + 1 + len(boltz_const.pocket_contact_info) + s_inputs_dim = v1_input_dim - token_s + else: + s_inputs_dim = token_s + + s_inputs = torch.empty((B, N_tokens, s_inputs_dim), device=device_type, dtype=dtype, requires_grad=True) + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z_trunk = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + rel_pos_enc = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + init_tensors_uniform([s_inputs, s_trunk, z_trunk, rel_pos_enc], low=val_init_min_max[0], high=val_init_min_max[1]) + + # coords: (B*M, N_atoms, 3) for both V1 and V2 (DTensor format) + coords = torch.empty((B * multiplicity, N_atoms_actual, 3), device=device_type, dtype=dtype) + init_tensors_uniform([coords], low=val_init_min_max[0], high=val_init_min_max[1]) + + # ------------------------------------------------------------------ + # Build serial modules (depends on internalized_conditioning) + # ------------------------------------------------------------------ + if internalized_conditioning: + score_model_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + conditioning_kwargs = None + conditioning_state_dict = None + + serial_atom_diffusion = SerialAtomDiffusionV1( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ).to(device=device_type, dtype=dtype) + else: + score_model_kwargs = { + "token_s": token_s, + "atom_s": atom_s, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "conditioning_transition_layers": conditioning_transition_layers, + } + conditioning_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs).to(device=device_type, dtype=dtype) + serial_conditioning.train() + init_module_params_glorot(serial_conditioning, gain=0.5) + serial_conditioning.apply(SetModuleInfValues()) + conditioning_state_dict = serial_conditioning.state_dict() + + serial_atom_diffusion = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ).to(device=device_type, dtype=dtype) + + serial_atom_diffusion.train() + init_module_params_glorot(serial_atom_diffusion, gain=0.5) + serial_atom_diffusion.apply(SetModuleInfValues()) + atom_diffusion_state_dict = serial_atom_diffusion.state_dict() + + # Pre-generate sigmas and noise for deterministic comparison. + # These are injected via monkeypatching so serial and DTensor use identical values. + sigmas = serial_atom_diffusion.noise_distribution(B * multiplicity).to(device=device_type, dtype=dtype) + noise = torch.empty((B * multiplicity, N_atoms_actual, 3), device=device_type, dtype=dtype) + init_tensors_uniform([noise], low=val_init_min_max[0], high=val_init_min_max[1]) + + # ------------------------------------------------------------------ + # Serial forward (with monkeypatched noise_distribution and randn_like) + # ------------------------------------------------------------------ + serial_mod = serial_diffusion_v1_module if internalized_conditioning else serial_diffusion_v2_module + _monkeypatch = pytest.MonkeyPatch() + _monkeypatch.setattr(serial_atom_diffusion, "noise_distribution", lambda bs, dtype=None: sigmas) + _monkeypatch.setattr(serial_mod.torch, "randn_like", lambda t: noise.to(t)) + + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + s_inputs_serial = s_inputs.detach().clone().requires_grad_(True) + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) + + if internalized_conditioning: + # V1: coords need (B, M, N_atoms, 3) shape for serial forward + feats_serial["coords"] = coords.detach().clone().reshape(B, multiplicity, N_atoms_actual, 3) + + out_serial = serial_atom_diffusion( + s_inputs=s_inputs_serial, + s_trunk=s_trunk_serial, + z_trunk=z_trunk.detach(), + relative_position_encoding=rel_pos_enc.detach(), + feats=feats_serial, + multiplicity=multiplicity, + ) + else: + feats_serial["coords"] = coords.detach().clone() + + with torch.no_grad(): + q_cond, c_cond, to_keys, atom_enc_bias_cond, atom_dec_bias_cond, token_trans_bias_cond = ( + serial_conditioning( + s_trunk=s_trunk.detach(), + z_trunk=z_trunk.detach(), + relative_position_encoding=rel_pos_enc.detach(), + feats={k: v.detach() for k, v in feats.items()}, + ) + ) + + diff_cond_serial = { + "q": q_cond.detach(), + "c": c_cond.detach(), + "to_keys": to_keys, + "atom_enc_bias": atom_enc_bias_cond.detach(), + "atom_dec_bias": atom_dec_bias_cond.detach(), + "token_trans_bias": token_trans_bias_cond.detach(), + } + + out_serial = serial_atom_diffusion( + s_inputs=s_inputs_serial, + s_trunk=s_trunk_serial, + feats=feats_serial, + diffusion_conditioning=diff_cond_serial, + multiplicity=multiplicity, + ) + + _monkeypatch.undo() + + denoised_serial = out_serial["denoised_atom_coords"] + noised_serial = out_serial.get("noised_atom_coords") # V1 only + + # Upstream gradient + d_denoised = torch.empty_like(denoised_serial) + init_tensors_uniform([d_denoised], low=val_init_min_max[0], high=val_init_min_max[1]) + atom_mask_mul = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0).unsqueeze(-1) + d_denoised = d_denoised * atom_mask_mul + + denoised_serial.backward(d_denoised) + + expected_param_grads = { + name: param.grad.detach().cpu() + for name, param in serial_atom_diffusion.named_parameters() + if param.requires_grad and param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_atom_diffusion_forward, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + internalized_conditioning, + dtype, + multiplicity, + {k: v.detach().cpu() for k, v in atom_diffusion_state_dict.items()}, + {k: v.detach().cpu() for k, v in conditioning_state_dict.items()} if conditioning_state_dict else None, + score_model_kwargs, + conditioning_kwargs, + W, + H, + {k: v.detach().cpu() for k, v in feats.items()}, + s_inputs.detach().cpu(), + s_trunk.detach().cpu(), + z_trunk.detach().cpu(), + rel_pos_enc.detach().cpu(), + coords.detach().cpu(), + sigmas.detach().cpu(), + noise.detach().cpu(), + denoised_serial.detach().cpu(), + noised_serial.detach().cpu() if noised_serial is not None else None, + d_denoised.detach().cpu(), + s_inputs_serial.grad.detach().cpu(), + s_trunk_serial.grad.detach().cpu(), + expected_param_grads, + ) + + +# ====================================================================== +# Test 3: AtomDiffusion.sample() (inference sampling) +# ====================================================================== + + +def parallel_assert_atom_diffusion_sample( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + internalized_conditioning: bool, + dtype: torch.dtype, + multiplicity: int, + num_sampling_steps: int, + max_parallel_samples: int, + atom_diffusion_state_dict, + conditioning_state_dict, + score_model_kwargs: dict, + conditioning_kwargs: dict | None, + W: int, + H: int, + feats_global_host: dict[str, torch.Tensor], + s_inputs_global_host: torch.Tensor, + s_trunk_global_host: torch.Tensor, + z_trunk_global_host: torch.Tensor, + rel_pos_enc_global_host: torch.Tensor, + init_noise_global_host: torch.Tensor, + step_noise_list_global_host: list[torch.Tensor], + sample_coords_expected_global_host: torch.Tensor, +): + """Parallel assertion for AtomDiffusion.sample().""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if internalized_conditioning: + serial_atom_diffusion = SerialAtomDiffusionV1( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + alignment_reverse_diff=False, + num_sampling_steps=num_sampling_steps, + ) + else: + serial_atom_diffusion = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + alignment_reverse_diff=False, + num_sampling_steps=num_sampling_steps, + ) + serial_atom_diffusion = serial_atom_diffusion.to(device=manager.device, dtype=dtype) + serial_atom_diffusion.load_state_dict(atom_diffusion_state_dict) + serial_atom_diffusion = serial_atom_diffusion.eval() + + ring_comm = AttentionPairBiasComm( + manager.group["cp"], + manager.layout_subgroups["cp"], + manager.subgroups["cp"][0], + manager.subgroups["cp"][1], + ) + module = DistributedAtomDiffusion( + layer=serial_atom_diffusion, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + ).eval() + + # ------------------------------------------------------------------ + # Distribute token-level tensors + # ------------------------------------------------------------------ + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + + token_pad_mask_dt = distribute_tensor( + feats_global_host["token_pad_mask"].to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + s_inputs_dt = distribute_tensor( + s_inputs_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + + z_trunk_device = z_trunk_global_host.to(device=manager.device, dtype=dtype) + rel_pos_enc_device = rel_pos_enc_global_host.to(device=manager.device, dtype=dtype) + + # ------------------------------------------------------------------ + # Distribute atom features + noise tensors via distribute_atom_features. + # Noise tensors are distributed alongside atom features so that + # intersperse padding naturally places zeros at padding positions. + # ------------------------------------------------------------------ + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + size_batch = feats_global_host["atom_pad_mask"].shape[0] + + all_noise = [init_noise_global_host] + list(step_noise_list_global_host) + for i_noise, noise_host in enumerate(all_noise): + unflat = noise_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"_noise_{i_noise}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + noise_cp_placements = {} + noise_placements = {} + for i_noise in range(len(all_noise)): + for i_mul in range(multiplicity): + key = f"_noise_{i_noise}_{i_mul}" + noise_cp_placements[key] = _placements_cp_model_io["noise"] + noise_placements[key] = _placements_model_io["noise"] + + feats_and_noise = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features | noise_cp_placements, + _placements_atom_features | noise_placements, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities={f"_noise_{i}": multiplicity for i in range(len(all_noise))}, + ) + + noise_dts = [] + for i_noise in range(len(all_noise)): + noise_dts.append(feats_and_noise.pop(f"_noise_{i_noise}")) + init_noise_dt = noise_dts[0] + step_noise_dts = noise_dts[1:] + + feats_dt = feats_and_noise + feats_dt["token_pad_mask"] = token_pad_mask_dt + + z_trunk_dt = distribute_tensor(z_trunk_device, manager.device_mesh_subgroups, placements_pair) + rel_pos_enc_dt = distribute_tensor(rel_pos_enc_device, manager.device_mesh_subgroups, placements_pair) + + # ------------------------------------------------------------------ + # Make DTensor sample() deterministic via monkeypatching with + # non-zero noise distributed through intersperse-padded layout. + # ------------------------------------------------------------------ + _orig_center_random_augmentation = distributed_diffusion_module.center_random_augmentation + + def _centering_only_augmentation(atom_coords, atom_mask, **kwargs): + kwargs["augmentation"] = False + kwargs["centering"] = True + return _orig_center_random_augmentation(atom_coords, atom_mask, **kwargs) + + _dt_randn_calls = [] + _dt_randn_sequence = [init_noise_dt] + step_noise_dts + + def _fixed_create_distributed_randn(shape, device_mesh, placements, dtype=torch.float32, scale=1.0): + idx = len(_dt_randn_calls) + _dt_randn_calls.append(idx) + noise_dt = _dt_randn_sequence[idx] + if scale != 1.0: + noise_dt = scalar_tensor_op(scale, noise_dt, ElementwiseOp.PROD) + return noise_dt + + monkeypatch.setattr(distributed_diffusion_module, "center_random_augmentation", _centering_only_augmentation) + monkeypatch.setattr(distributed_diffusion_module, "create_distributed_randn", _fixed_create_distributed_randn) + + # ------------------------------------------------------------------ + # Build network_condition_kwargs and run DTensor sample + # ------------------------------------------------------------------ + if internalized_conditioning: + network_condition_kwargs = { + "s_inputs": s_inputs_dt, + "s_trunk": s_trunk_dt, + "feats": feats_dt, + "z_trunk": z_trunk_dt, + "relative_position_encoding": rel_pos_enc_dt, + } + else: + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs) + serial_conditioning = serial_conditioning.to(device=manager.device, dtype=dtype) + serial_conditioning.load_state_dict(conditioning_state_dict) + serial_conditioning = serial_conditioning.eval() + + dtensor_conditioning = DistributedDiffusionConditioning( + layer=serial_conditioning, + device_mesh=manager.device_mesh_subgroups, + ).eval() + + with torch.no_grad(): + q_dt, c_dt, enc_bias_dt, dec_bias_dt, trans_bias_dt = dtensor_conditioning( + s_trunk=s_trunk_dt, + z_trunk=z_trunk_dt, + relative_position_encoding=rel_pos_enc_dt, + feats=feats_dt, + ) + + network_condition_kwargs = { + "s_inputs": s_inputs_dt, + "s_trunk": s_trunk_dt, + "feats": feats_dt, + "diffusion_conditioning": { + "q": q_dt, + "c": c_dt, + "atom_enc_bias": enc_bias_dt, + "atom_dec_bias": dec_bias_dt, + "token_trans_bias": trans_bias_dt, + }, + } + + with torch.no_grad(): + out_dt = module.sample( + atom_mask=feats_dt["atom_pad_mask"], + multiplicity=multiplicity, + max_parallel_samples=max_parallel_samples, + **network_condition_kwargs, + ) + + # ------------------------------------------------------------------ + # Comparison: extract real atoms (where atom_pad_mask=1) from both + # the DTensor output (intersperse-padded layout) and serial expected + # (raw layout). The real atoms maintain the same order in both layouts, + # so boolean indexing extracts matching sequences. + # ------------------------------------------------------------------ + dt_full = out_dt["sample_atom_coords"].full_tensor() + dt_mask = feats_dt["atom_pad_mask"].full_tensor().repeat_interleave(multiplicity, 0).bool() + dt_real = dt_full[dt_mask] + + sample_expected = sample_coords_expected_global_host.to(device=manager.device, dtype=dtype) + serial_mask = feats_global_host["atom_pad_mask"].to(device=manager.device).repeat_interleave(multiplicity, 0).bool() + serial_real = sample_expected[serial_mask] + + torch.testing.assert_close(dt_real, serial_real) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [((2, (2, 2)), True, "cuda", "ENV")], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 4], ids=lambda x: f"mul:{x}") +@pytest.mark.parametrize("internalized_conditioning", [False, True], ids=["extern", "intern"]) +def test_atom_diffusion_sample(setup_env, multiplicity, internalized_conditioning: bool): + """Test AtomDiffusion.sample() (inference) with DTensor CP. + + Determinism is achieved via monkeypatching with pre-generated non-zero noise + tensors. Serial uses mocked torch.randn returning the noise sequence; DTensor + uses the same noise distributed through distribute_atom_features (intersperse- + padded layout) via mocked create_distributed_randn. + Uses num_sampling_steps=2 to limit numerical error accumulation across denoising steps. + Exercises max_parallel_samples chunking (max_parallel_samples=2 when multiplicity=4). + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + dtype = torch.float64 + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + W = 32 + H = 128 + val_init_min_max = (-0.5, 0.5) + num_sampling_steps = 2 + max_parallel_samples = 2 if multiplicity > 2 else multiplicity + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + token_s = 4 + token_z = 4 + atom_z = 8 + atom_encoder_depth = 2 + atom_encoder_heads = 2 + token_transformer_depth = 2 + token_transformer_heads = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + conditioning_transition_layers = 1 + + atom_feature_dim = 3 + 1 + (1 if internalized_conditioning else 0) + boltz_const.num_elements + 4 * 64 + + selected_keys = list(_selected_atom_keys | _selected_token_keys) + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + if internalized_conditioning: + v1_input_dim = 2 * token_s + 2 * boltz_const.num_tokens + 1 + len(boltz_const.pocket_contact_info) + s_inputs_dim = v1_input_dim - token_s + else: + s_inputs_dim = token_s + + s_inputs = torch.empty((B, N_tokens, s_inputs_dim), device=device_type, dtype=dtype) + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype) + z_trunk = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + rel_pos_enc = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + init_tensors_uniform([s_inputs, s_trunk, z_trunk, rel_pos_enc], low=val_init_min_max[0], high=val_init_min_max[1]) + + # ------------------------------------------------------------------ + # Build serial modules + # ------------------------------------------------------------------ + if internalized_conditioning: + score_model_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + conditioning_kwargs = None + conditioning_state_dict = None + serial_atom_diffusion = SerialAtomDiffusionV1( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + alignment_reverse_diff=False, + num_sampling_steps=num_sampling_steps, + ).to(device=device_type, dtype=dtype) + else: + score_model_kwargs = { + "token_s": token_s, + "atom_s": atom_s, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "conditioning_transition_layers": conditioning_transition_layers, + } + conditioning_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs).to(device=device_type, dtype=dtype) + serial_conditioning.train() + init_module_params_uniform(serial_conditioning, low=val_init_min_max[0], high=val_init_min_max[1]) + serial_conditioning.apply(SetModuleInfValues()) + conditioning_state_dict = serial_conditioning.state_dict() + serial_atom_diffusion = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + alignment_reverse_diff=False, + num_sampling_steps=num_sampling_steps, + ).to(device=device_type, dtype=dtype) + + serial_atom_diffusion.train() + init_module_params_uniform(serial_atom_diffusion, low=val_init_min_max[0], high=val_init_min_max[1]) + serial_atom_diffusion.apply(SetModuleInfValues()) + serial_atom_diffusion.eval() + atom_diffusion_state_dict = serial_atom_diffusion.state_dict() + + # V1 serial sample() needs token_index in feats for token_repr_shape + if internalized_conditioning and "token_index" not in feats: + feats["token_index"] = torch.arange(N_tokens, device=device_type).unsqueeze(0).expand(B, -1) + + N_atoms_actual = feats["atom_pad_mask"].shape[1] + _B_M = B * multiplicity + + # Pre-generate non-zero noise tensors for deterministic comparison. + # sample() calls torch.randn once for init_noise and once per step. + init_noise = torch.empty((_B_M, N_atoms_actual, 3), device=device_type, dtype=dtype) + step_noise_list = [ + torch.empty((_B_M, N_atoms_actual, 3), device=device_type, dtype=dtype) for _ in range(num_sampling_steps) + ] + init_tensors_uniform([init_noise, *step_noise_list], low=val_init_min_max[0], high=val_init_min_max[1]) + + # ------------------------------------------------------------------ + # Serial sample (with monkeypatched determinism using non-zero noise) + # ------------------------------------------------------------------ + serial_mod = serial_diffusion_v1_module if internalized_conditioning else serial_diffusion_v2_module + + def _identity_compute_random_augmentation(multiplicity_arg, device=None, dtype=None): + R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(_B_M, -1, -1) + tr = torch.zeros(_B_M, 1, 3, device=device, dtype=dtype) + return R, tr + + _serial_randn_calls = [] + _serial_randn_sequence = [init_noise] + step_noise_list + + def _fixed_randn(*args, **kwargs): + idx = len(_serial_randn_calls) + _serial_randn_calls.append(idx) + return _serial_randn_sequence[idx].clone() + + _monkeypatch = pytest.MonkeyPatch() + _monkeypatch.setattr(serial_mod, "compute_random_augmentation", _identity_compute_random_augmentation) + _monkeypatch.setattr(serial_mod.torch, "randn", _fixed_randn) + + with torch.no_grad(): + if internalized_conditioning: + out_serial = serial_atom_diffusion.sample( + atom_mask=feats["atom_pad_mask"], + multiplicity=multiplicity, + max_parallel_samples=max_parallel_samples, + s_inputs=s_inputs, + s_trunk=s_trunk, + z_trunk=z_trunk, + relative_position_encoding=rel_pos_enc, + feats={k: v.clone() for k, v in feats.items()}, + ) + else: + q_cond, c_cond, to_keys, enc_bias, dec_bias, trans_bias = serial_conditioning( + s_trunk=s_trunk, + z_trunk=z_trunk, + relative_position_encoding=rel_pos_enc, + feats={k: v.detach() for k, v in feats.items()}, + ) + out_serial = serial_atom_diffusion.sample( + atom_mask=feats["atom_pad_mask"], + multiplicity=multiplicity, + max_parallel_samples=max_parallel_samples, + s_inputs=s_inputs, + s_trunk=s_trunk, + feats={k: v.clone() for k, v in feats.items()}, + diffusion_conditioning={ + "q": q_cond, + "c": c_cond, + "to_keys": to_keys, + "atom_enc_bias": enc_bias, + "atom_dec_bias": dec_bias, + "token_trans_bias": trans_bias, + }, + ) + + _monkeypatch.undo() + + spawn_multiprocessing( + parallel_assert_atom_diffusion_sample, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + internalized_conditioning, + dtype, + multiplicity, + num_sampling_steps, + max_parallel_samples, + {k: v.cpu() for k, v in atom_diffusion_state_dict.items()}, + {k: v.cpu() for k, v in conditioning_state_dict.items()} if conditioning_state_dict else None, + score_model_kwargs, + conditioning_kwargs, + W, + H, + {k: v.cpu() for k, v in feats.items()}, + s_inputs.cpu(), + s_trunk.cpu(), + z_trunk.cpu(), + rel_pos_enc.cpu(), + init_noise.cpu(), + [n.cpu() for n in step_noise_list], + out_serial["sample_atom_coords"].cpu(), + ) + + +# ====================================================================== +# Test 4: AtomDiffusion helper functions (c_skip, c_out, c_in, etc.) +# ====================================================================== + + +def parallel_assert_atom_diffusion_helpers( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + dtype: torch.dtype, + atom_diffusion_state_dict, + score_model_kwargs: dict, + # Inputs and expected outputs + sigma_global_host: torch.Tensor, + c_skip_expected_host: torch.Tensor, + c_out_expected_host: torch.Tensor, + c_in_expected_host: torch.Tensor, + c_noise_expected_host: torch.Tensor, + loss_weight_expected_host: torch.Tensor, + noise_dist_expected_host: torch.Tensor, + sample_schedule_expected_host: torch.Tensor, +): + """Parallel assertion for AtomDiffusion helper functions.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Create DTensor module from V2 serial (V1 and V2 have identical helper implementations) + from boltz.model.modules.diffusionv2 import AtomDiffusion as SerialAtomDiffusionV2Local + + serial = SerialAtomDiffusionV2Local( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ) + serial = serial.to(device=manager.device, dtype=dtype) + serial.load_state_dict(atom_diffusion_state_dict) + serial = serial.eval() + + ring_comm = AttentionPairBiasComm( + manager.group["cp"], + manager.layout_subgroups["cp"], + manager.subgroups["cp"][0], + manager.subgroups["cp"][1], + ) + module = DistributedAtomDiffusion( + layer=serial, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + ).eval() + + # Distribute sigma + placements_scalar = (Shard(0), Replicate(), Replicate()) + sigma_dt = distribute_tensor( + sigma_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_scalar, + ) + + # Test c_skip, c_out, c_in, c_noise, loss_weight + for name, fn, expected_host in [ + ("c_skip", module.c_skip, c_skip_expected_host), + ("c_out", module.c_out, c_out_expected_host), + ("c_in", module.c_in, c_in_expected_host), + ("c_noise", module.c_noise, c_noise_expected_host), + ("loss_weight", module.loss_weight, loss_weight_expected_host), + ]: + result = fn(sigma_dt) + expected = expected_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(result.full_tensor(), expected, msg=lambda m: f"{name}: {m}") + + # Test noise_distribution (stochastic — check shape, placement, and dtype) + noise_dist = module.noise_distribution(sigma_dt.shape[0]) + assert ( + noise_dist.shape == sigma_dt.shape + ), f"noise_distribution shape mismatch: {noise_dist.shape} vs {sigma_dt.shape}" + assert noise_dist.placements == placements_scalar + assert noise_dist.dtype == torch.float32, f"noise_distribution default dtype: {noise_dist.dtype} != float32" + + noise_dist_f64 = module.noise_distribution(sigma_dt.shape[0], dtype=torch.float64) + assert noise_dist_f64.dtype == torch.float64, f"noise_distribution float64 dtype: {noise_dist_f64.dtype} != float64" + + # Test sample_schedule (deterministic, returns plain Tensor) + schedule = module.sample_schedule(num_sampling_steps=5) + expected_schedule = sample_schedule_expected_host.to(device=manager.device) + torch.testing.assert_close(schedule, expected_schedule) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [((2, (1, 1)), True, "cuda", "ENV")], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +def test_dtensor_atom_diffusion_helpers(setup_env): + """Test DTensor AtomDiffusion scalar helper functions. + + Tests c_skip, c_out, c_in, c_noise, loss_weight, noise_distribution, sample_schedule. + V1 and V2 serial implementations are identical for these functions, so only one + serial reference is needed (V2 is used). + Uses dp=2, cp=(1,1) since these functions don't involve CP atom sharding. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + dtype = torch.float64 + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + B = 1 * grid_group_sizes["dp"] + + # Build a minimal serial AtomDiffusion (V2) for the helpers + score_model_kwargs = { + "token_s": 4, + "atom_s": 8, + "atoms_per_window_queries": 32, + "atoms_per_window_keys": 128, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": 1, + "atom_encoder_heads": 1, + "token_transformer_depth": 1, + "token_transformer_heads": 1, + "atom_decoder_depth": 1, + "atom_decoder_heads": 1, + "conditioning_transition_layers": 1, + } + serial = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ).to(device=device_type, dtype=dtype) + serial.eval() + atom_diffusion_state_dict = serial.state_dict() + + # Generate test sigma values + sigma = torch.tensor([0.5, 3.14, 16.0, 160.0], device=device_type, dtype=dtype)[:B] + + # Compute serial reference + c_skip_ref = serial.c_skip(sigma) + c_out_ref = serial.c_out(sigma) + c_in_ref = serial.c_in(sigma) + c_noise_ref = serial.c_noise(sigma) + loss_weight_ref = serial.loss_weight(sigma) + noise_dist_ref = serial.noise_distribution(B) + sample_schedule_ref = serial.sample_schedule(num_sampling_steps=5) + + spawn_multiprocessing( + parallel_assert_atom_diffusion_helpers, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + {k: v.cpu() for k, v in atom_diffusion_state_dict.items()}, + score_model_kwargs, + sigma.cpu(), + c_skip_ref.cpu(), + c_out_ref.cpu(), + c_in_ref.cpu(), + c_noise_ref.cpu(), + loss_weight_ref.cpu(), + noise_dist_ref.cpu(), + sample_schedule_ref.cpu(), + ) + + +# ====================================================================== +# Test 5: AtomDiffusion.compute_loss() +# ====================================================================== + + +def parallel_assert_atom_diffusion_compute_loss( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + internalized_conditioning: bool, + dtype: torch.dtype, + multiplicity: int, + add_smooth_lddt_loss: bool, + nucleotide_loss_weight: float, + ligand_loss_weight: float, + filter_by_plddt: float, + use_triton_kernel: bool, + feats_host: dict, + atom_diffusion_state_dict: dict, + score_model_kwargs: dict, + denoised_atom_coords_global_host: torch.Tensor, + aligned_true_atom_coords_global_host: torch.Tensor, + sigma_global_host: torch.Tensor, + expected_total_loss_global_host: torch.Tensor, + expected_mse_loss_global_host: torch.Tensor, + expected_smooth_lddt_loss_global_host: torch.Tensor, + expected_denoised_atom_coords_grad_global_host: torch.Tensor, + env_per_rank=None, +): + """Parallel assertion for AtomDiffusion.compute_loss(). + + Compares distributed compute_loss (V2-style formula) to serial V1 or V2 + depending on internalized_conditioning. + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Build serial module (V1 or V2) to match reference used for expected values + if internalized_conditioning: + serial_atom_diffusion = SerialAtomDiffusionV1( + score_model_args=score_model_kwargs, + ) + else: + serial_atom_diffusion = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + ) + serial_atom_diffusion = serial_atom_diffusion.to(device=manager.device, dtype=dtype) + serial_atom_diffusion.load_state_dict(atom_diffusion_state_dict) + serial_atom_diffusion = serial_atom_diffusion.train() + + ring_comm = AttentionPairBiasComm( + manager.group["cp"], + manager.layout_subgroups["cp"], + manager.subgroups["cp"][0], + manager.subgroups["cp"][1], + ) + transpose_comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + module = DistributedAtomDiffusion( + layer=serial_atom_diffusion, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + transpose_comm=transpose_comm, + ).train() + + # Placements for compute_loss feats + placements_scalar = (Shard(0), Replicate(), Replicate()) + placements_single = (Shard(0), Shard(1), Replicate()) + + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_host.items() + if k in _placements_cp_atom_features + } + + size_batch = feats_host["atom_resolved_mask"].shape[0] + io_tensors = { + "denoised_atom_coords": denoised_atom_coords_global_host, + "aligned_true_atom_coords": aligned_true_atom_coords_global_host, + "d_r_update_expected": expected_denoised_atom_coords_grad_global_host, + } + for base_name, tensor_host in io_tensors.items(): + unflat = tensor_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"{base_name}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + io_keys_used = set(io_tensors.keys()) + placements_cp_io_mul = { + f"{k}_{i_mul}": _placements_cp_model_io[k] for k in io_keys_used for i_mul in range(multiplicity) + } + placements_io_mul = {f"{k}_{i_mul}": _placements_model_io[k] for k in io_keys_used for i_mul in range(multiplicity)} + multiplicities = dict.fromkeys(io_keys_used, multiplicity) + + feats_and_io = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features | placements_cp_io_mul, + _placements_atom_features | placements_io_mul, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities=multiplicities, + ) + + denoised_atom_coords_dtensor = feats_and_io.pop("denoised_atom_coords").requires_grad_(True) + d_r_update_expected_dtensor = feats_and_io.pop("d_r_update_expected") + aligned_true_atom_coords_dtensor = feats_and_io.pop("aligned_true_atom_coords") + sigma_dtensor = distribute_tensor( + sigma_global_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_scalar, + ) + mol_type = distribute_tensor( + feats_host["mol_type"].to(device=manager.device), + manager.device_mesh_subgroups, + placements_single, + ) + feats_dtensor = { + "mol_type": mol_type, + "atom_to_token": feats_and_io.pop("atom_to_token"), + "atom_resolved_mask": feats_and_io.pop("atom_resolved_mask"), + "plddt": feats_and_io.pop("plddt"), + } + input_dict = { + "denoised_atom_coords": denoised_atom_coords_dtensor, + "sigmas": sigma_dtensor, + "aligned_true_atom_coords": aligned_true_atom_coords_dtensor, + } + input_dict_clone = {k: v.detach().clone().requires_grad_(v.requires_grad) for k, v in input_dict.items()} + + output_dict = module.compute_loss( + feats=feats_dtensor, + out_dict=input_dict, + add_smooth_lddt_loss=add_smooth_lddt_loss, + nucleotide_loss_weight=nucleotide_loss_weight, + ligand_loss_weight=ligand_loss_weight, + multiplicity=multiplicity, + filter_by_plddt=filter_by_plddt, + use_triton_kernel=use_triton_kernel, + ) + + # Ensure input dict is not modified by compute_loss + for k in input_dict.keys(): + assert_tensors_identical( + input_dict[k], + input_dict_clone[k], + check_grad=False, + check_grad_fn=False, + check_storage_offset=True, + check_storage_pointer=False, + ) + + total_loss_dtensor = output_dict["loss"] + mse_loss_dtensor = output_dict["loss_breakdown"]["mse_loss"] + smooth_lddt_loss_dtensor = output_dict["loss_breakdown"]["smooth_lddt_loss"] + assert total_loss_dtensor.placements == ( + Replicate(), + Replicate(), + Replicate(), + ), "total_loss_dtensor should be replicated" + assert mse_loss_dtensor.placements == ( + Replicate(), + Replicate(), + Replicate(), + ), "mse_loss_dtensor should be replicated" + assert smooth_lddt_loss_dtensor.placements == ( + Replicate(), + Replicate(), + Replicate(), + ), "smooth_lddt_loss_dtensor should be replicated" + + total_loss = total_loss_dtensor.full_tensor().cpu() + mse_loss = mse_loss_dtensor.full_tensor().cpu() + smooth_lddt_loss = smooth_lddt_loss_dtensor.full_tensor().cpu() + + assert not (mse_loss == 0.0).all(), "mse_loss should not be 0" + if add_smooth_lddt_loss: + assert not (smooth_lddt_loss == 0.0).all(), "smooth_lddt_loss should not be 0" + assert not (total_loss == 0.0).all(), "total_loss should not be 0" + torch.testing.assert_close(mse_loss, expected_mse_loss_global_host) + torch.testing.assert_close(smooth_lddt_loss, expected_smooth_lddt_loss_global_host) + torch.testing.assert_close(total_loss, expected_total_loss_global_host) + + total_loss_dtensor_clone = total_loss_dtensor.detach().clone().requires_grad_(total_loss_dtensor.requires_grad) + mse_loss_dtensor_clone = mse_loss_dtensor.detach().clone().requires_grad_(mse_loss_dtensor.requires_grad) + smooth_lddt_loss_dtensor_clone = ( + smooth_lddt_loss_dtensor.detach().clone().requires_grad_(smooth_lddt_loss_dtensor.requires_grad) + ) + + total_loss_dtensor.backward() + + assert_tensors_identical( + total_loss_dtensor, + total_loss_dtensor_clone, + check_grad=False, + check_grad_fn=False, + check_storage_offset=True, + check_storage_pointer=False, + ) + assert_tensors_identical( + mse_loss_dtensor, + mse_loss_dtensor_clone, + check_grad=False, + check_grad_fn=False, + check_storage_offset=True, + check_storage_pointer=False, + ) + assert_tensors_identical( + smooth_lddt_loss_dtensor, + smooth_lddt_loss_dtensor_clone, + check_grad=False, + check_grad_fn=False, + check_storage_offset=True, + check_storage_pointer=False, + ) + + torch.testing.assert_close( + denoised_atom_coords_dtensor.grad.full_tensor(), d_r_update_expected_dtensor.full_tensor() + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}, method_init={x[3]}", +) +@pytest.mark.parametrize( + "loss_config", + [ + (False, 1, True, False, 0.0), + (False, 4, True, True, 0.5), + (True, 1, True, False, 0.0), + (True, 4, True, True, 0.0), + ], + ids=lambda x: f"v2={not x[0]}, mul={x[1]}, lddt={x[2]}, triton={x[3]}, plddt={x[4]:.1f}", +) +def test_atom_diffusion_compute_loss( + setup_env, + loss_config, + nucleotide_loss_weight: float = 5.0, + ligand_loss_weight: float = 10.0, + dtype: torch.dtype = torch.float32, +): + """Test AtomDiffusion.compute_loss() with distributed context parallelism. + + Compares DTensor AtomDiffusion.compute_loss() against serial V1 or V2 + (internalized_conditioning True/False) for forward and backward. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + internalized_conditioning, multiplicity, add_smooth_lddt_loss, use_triton_kernel, filter_by_plddt = loss_config + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + if filter_by_plddt > 0 and internalized_conditioning: + pytest.skip("filter_by_plddt is not supported for boltz1 internalized conditioning") + + if not add_smooth_lddt_loss and use_triton_kernel: + pytest.skip("use_triton_kernel requires add_smooth_lddt_loss=True") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + W = 32 + H = 128 + val_init_min_max = (-1.0, 1.0) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 10 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + token_s = 4 + token_z = 4 + atom_z = 8 + + atom_encoder_depth = 2 + atom_encoder_heads = 2 + token_transformer_depth = 2 + token_transformer_heads = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + conditioning_transition_layers = 1 + + atom_feature_dim = 3 + 1 + (1 if internalized_conditioning else 0) + boltz_const.num_elements + 4 * 64 + + compute_loss_selected_keys = { + "atom_resolved_mask", + "mol_type", + "atom_to_token", + "atom_counts_per_token", + "plddt", + } + selected_keys = list(_selected_atom_keys | _selected_token_keys | compute_loss_selected_keys) + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + # V1 s_inputs has wider dim: + # input_dim = 2 * token_s + 2 * num_tokens + 1 + len(pocket_contact_info) + # s_inputs_dim = input_dim - token_s + # V2 s_inputs has token_s dim + if internalized_conditioning: + v1_input_dim = 2 * token_s + 2 * boltz_const.num_tokens + 1 + len(boltz_const.pocket_contact_info) + s_inputs_dim = v1_input_dim - token_s + else: + s_inputs_dim = token_s + + s_inputs = torch.empty((B, N_tokens, s_inputs_dim), device=device_type, dtype=dtype, requires_grad=True) + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z_trunk = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + rel_pos_enc = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + init_tensors_uniform([s_inputs, s_trunk, z_trunk, rel_pos_enc], low=val_init_min_max[0], high=val_init_min_max[1]) + + # noised_atom_coords: (B*M, N_atoms_actual, 3) + N_atoms_actual = feats["atom_pad_mask"].shape[1] + noised_atom_coords = torch.empty( + (B * multiplicity, N_atoms_actual, 3), device=device_type, dtype=dtype, requires_grad=True + ) + init_tensors_uniform([noised_atom_coords], low=val_init_min_max[0], high=val_init_min_max[1]) + + if internalized_conditioning: + score_model_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_model = SerialAtomDiffusionV1( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ).to(device=device_type, dtype=dtype) + else: + # V2: uses DiffusionConditioning + score_model_kwargs = { + "token_s": token_s, + "atom_s": atom_s, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_model = SerialAtomDiffusionV2( + score_model_args=score_model_kwargs, + coordinate_augmentation=False, + ).to(device=device_type, dtype=dtype) + init_module_params_uniform(serial_model, low=-0.1, high=0.1) + serial_model.apply(SetModuleInfValues()) + module_state_dict = serial_model.state_dict() + serial_model = serial_model.to(device=device_type, dtype=dtype) + + denoised_atom_coords = torch.empty( + (B * multiplicity, N_atoms, 3), device=device_type, dtype=dtype, requires_grad=True + ) + init_tensors_uniform([denoised_atom_coords], low=val_init_min_max[0], high=val_init_min_max[1]) + aligned_true_atom_coords = torch.empty_like(denoised_atom_coords) + init_tensors_uniform([aligned_true_atom_coords], low=val_init_min_max[0], high=val_init_min_max[1]) + sigma = serial_model.noise_distribution(B * multiplicity).to(device=device_type, dtype=dtype) + denoised_atom_coords.requires_grad = True + + input_dict = { + "denoised_atom_coords": denoised_atom_coords, + "sigmas": sigma, + "aligned_true_atom_coords": aligned_true_atom_coords, + } + feats["coords"] = aligned_true_atom_coords + # V1 compute_loss requires noised_atom_coords in out_dict + if internalized_conditioning: + noised_atom_coords = torch.empty_like(denoised_atom_coords) + init_tensors_uniform([noised_atom_coords], low=val_init_min_max[0], high=val_init_min_max[1]) + input_dict["noised_atom_coords"] = noised_atom_coords + + extra_kwargs = {} + if not internalized_conditioning and filter_by_plddt > 0: + extra_kwargs["filter_by_plddt"] = filter_by_plddt + + output_dict = serial_model.compute_loss( + feats=feats, + out_dict=input_dict, + add_smooth_lddt_loss=add_smooth_lddt_loss, + nucleotide_loss_weight=nucleotide_loss_weight, + ligand_loss_weight=ligand_loss_weight, + multiplicity=multiplicity, + **extra_kwargs, + ) + output_dict["loss"].backward() + + feats_host = {k: v.detach().to(device="cpu", copy=True) if torch.is_tensor(v) else v for k, v in feats.items()} + + spawn_multiprocessing( + parallel_assert_atom_diffusion_compute_loss, + world_size, + grid_group_sizes, + device_type, + backend, + internalized_conditioning, + dtype, + multiplicity, + add_smooth_lddt_loss, + nucleotide_loss_weight, + ligand_loss_weight, + filter_by_plddt, + use_triton_kernel, + feats_host, + {k: v.cpu() for k, v in module_state_dict.items()}, + score_model_kwargs, + denoised_atom_coords.detach().clone().cpu(), + aligned_true_atom_coords.detach().clone().cpu(), + sigma.detach().clone().cpu(), + output_dict["loss"].detach().clone().cpu(), + output_dict["loss_breakdown"]["mse_loss"].detach().clone().cpu(), + output_dict["loss_breakdown"]["smooth_lddt_loss"].detach().clone().cpu(), + denoised_atom_coords.grad.detach().clone().cpu(), + env_per_rank, + ) diff --git a/tests/distributed/model/modules/test_dtensor_diffusion_conditioning.py b/tests/distributed/model/modules/test_dtensor_diffusion_conditioning.py new file mode 100644 index 000000000..047c135dc --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_diffusion_conditioning.py @@ -0,0 +1,551 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor DiffusionConditioning module (V2 only). + +Tests the DTensor DiffusionConditioning against the V2 serial reference, +verifying forward and backward numerical equivalence. + +Uses float64 with default tolerance for exact comparison. +""" + +from functools import partial + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.data import const as boltz_const +from boltz.distributed.data.feature.featurizer import pack_atom_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.flatten_and_unflatten import shardwise_unflatten_sharded +from boltz.distributed.model.layers.utils import convert_single_repr_to_window_batched_key +from boltz.distributed.model.modules.diffusion_conditioning import ( + DiffusionConditioning as DistributedDiffusionConditioning, +) +from boltz.model.modules.diffusion_conditioning import DiffusionConditioning as SerialDiffusionConditioning +from boltz.model.modules.encodersv2 import get_indexing_matrix, single_to_keys +from boltz.testing.utils import ( + SetModuleInfValues, + assert_all_identical, + assert_tensors_close_with_pad, + distribute_atom_features, + get_feature_placements, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + pad_or_shrink_to_length, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +# Subset of keys needed for DiffusionConditioning +_selected_atom_keys = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + "atom_counts_per_token", +} + +_placements = get_feature_placements( + token_keys=set(), + msa_keys=set(), + atom_keys=_selected_atom_keys, + model_io_keys=set(), + model_io_fp32_keys=set(), +) +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_atom_features = _placements["atom_features"] + + +def parallel_assert_diffusion_conditioning( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + dtype: torch.dtype, + # Module dimensions + atom_s: int, + atom_z: int, + token_s: int, + token_z: int, + atom_feature_dim: int, + W: int, + H: int, + atom_encoder_depth: int, + atom_encoder_heads: int, + token_transformer_depth: int, + token_transformer_heads: int, + atom_decoder_depth: int, + atom_decoder_heads: int, + layer_state_dict, + # Inputs + feats_global_host: dict[str, torch.Tensor], + s_trunk_global_host: torch.Tensor, + z_trunk_global_host: torch.Tensor, + rel_pos_enc_global_host: torch.Tensor, + # Expected outputs + q_expected_global_host: torch.Tensor, + c_expected_global_host: torch.Tensor, + atom_enc_bias_expected_global_host: torch.Tensor, + atom_dec_bias_expected_global_host: torch.Tensor, + token_trans_bias_expected_global_host: torch.Tensor, + # Upstream grads + d_q_global_host: torch.Tensor, + d_c_global_host: torch.Tensor, + d_atom_enc_bias_global_host: torch.Tensor, + d_atom_dec_bias_global_host: torch.Tensor, + d_token_trans_bias_global_host: torch.Tensor, + # Expected input grads + d_s_trunk_expected_global_host: torch.Tensor, + d_z_trunk_expected_global_host: torch.Tensor, + d_rel_pos_enc_expected_global_host: torch.Tensor, + # Expected param grads + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Re-create serial module (on this rank's device) + module_serial = SerialDiffusionConditioning( + token_s=token_s, + token_z=token_z, + atom_s=atom_s, + atom_z=atom_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + token_transformer_depth=token_transformer_depth, + token_transformer_heads=token_transformer_heads, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + atom_feature_dim=atom_feature_dim, + ) + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.train() + + # Create DTensor module from serial + module = DistributedDiffusionConditioning( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # ------------------------------------------------------------------ + # Distribute atom features (unpacked — module packs internally) + # ------------------------------------------------------------------ + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + feats_dt = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features, + _placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + + # Compute N_atoms_packed from an explicit pack call (for comparison sizing only) + feats_dt_packed_for_sizing = pack_atom_features(feats_dt, set(feats_dt.keys()), W) + N_atoms_packed = feats_dt_packed_for_sizing["atom_pad_mask"].shape[1] + + # Global masks for comparison + atom_pad_mask_global = feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=dtype) + atom_pad_mask_expanded_global = atom_pad_mask_global.unsqueeze(-1) + + # ------------------------------------------------------------------ + # Distribute token-level tensors + # ------------------------------------------------------------------ + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + z_trunk_dt = distribute_tensor( + z_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_pair, + ).requires_grad_(True) + rel_pos_enc_dt = distribute_tensor( + rel_pos_enc_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_pair, + ).requires_grad_(True) + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + q_dt, c_dt, atom_enc_bias_dt, atom_dec_bias_dt, token_trans_bias_dt = module( + s_trunk=s_trunk_dt, + z_trunk=z_trunk_dt, + relative_position_encoding=rel_pos_enc_dt, + feats=feats_dt, + ) + + # ------------------------------------------------------------------ + # Forward comparison: q and c (atom-level, padded) + # ------------------------------------------------------------------ + mask_dt_full = feats_dt_packed_for_sizing["atom_pad_mask"].full_tensor() + mask_dt_full_expanded = mask_dt_full.unsqueeze(-1) + + q_expected_device = q_expected_global_host.to(device=manager.device, dtype=dtype) + c_expected_device = c_expected_global_host.to(device=manager.device, dtype=dtype) + + assert_tensors_close_with_pad( + q_dt.full_tensor() * mask_dt_full_expanded, + q_expected_device * atom_pad_mask_expanded_global, + axis=1, + pad_val=0, + ) + assert_tensors_close_with_pad( + c_dt.full_tensor() * mask_dt_full_expanded, + c_expected_device * atom_pad_mask_expanded_global, + axis=1, + pad_val=0, + ) + + # ------------------------------------------------------------------ + # Forward comparison: atom_enc_bias, atom_dec_bias (window-batched atom pair, S(0) S(1) R) + # ------------------------------------------------------------------ + K_packed = N_atoms_packed // W + N_atoms_serial = feats_global_host["atom_pad_mask"].shape[1] + K_serial = N_atoms_serial // W + + mask_dt_query = shardwise_unflatten_sharded( + feats_dt_packed_for_sizing["atom_pad_mask"], axis=1, sizes=(K_packed, W) + ) + mask_dt_query_full = mask_dt_query.full_tensor() + mask_dt_query_full_expanded = mask_dt_query_full[:, :, :, None, None] + mask_dt_key = convert_single_repr_to_window_batched_key(feats_dt_packed_for_sizing["atom_pad_mask"], W, H) + mask_dt_key_full = mask_dt_key.full_tensor() + mask_dt_key_full_expanded = mask_dt_key_full[:, :, None, :, None] + mask_dt_pair_full_expanded = mask_dt_query_full_expanded * mask_dt_key_full_expanded + + compute_dtype = torch.promote_types(dtype, torch.float32) + index_matrix = get_indexing_matrix(K_serial, W, H, manager.device).to(dtype=compute_dtype) + to_keys_fn = partial(single_to_keys, indexing_matrix=index_matrix, W=W, H=H) + + mask_key_expected = to_keys_fn( + feats_global_host["atom_pad_mask"].to(device=manager.device, dtype=compute_dtype).unsqueeze(-1) + ) + mask_key_expected_expanded = mask_key_expected[:, :, None, :, :] + mask_query_expected_expanded = atom_pad_mask_expanded_global.unflatten( + 1, (atom_pad_mask_expanded_global.shape[1] // W, W) + )[:, :, :, None, :] + mask_pair_expected_expanded = mask_query_expected_expanded * mask_key_expected_expanded + + for name, bias_dt, bias_expected_host in [ + ("atom_enc_bias", atom_enc_bias_dt, atom_enc_bias_expected_global_host), + ("atom_dec_bias", atom_dec_bias_dt, atom_dec_bias_expected_global_host), + ]: + bias_expected_device = bias_expected_host.to(device=manager.device, dtype=dtype) + assert_tensors_close_with_pad( + bias_dt.full_tensor() * mask_dt_pair_full_expanded, + bias_expected_device * mask_pair_expected_expanded, + axis=1, + pad_val=0, + ) + + # ------------------------------------------------------------------ + # Forward comparison: token_trans_bias (token pair level) + # ------------------------------------------------------------------ + token_trans_bias_expected_device = token_trans_bias_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close(token_trans_bias_dt.full_tensor(), token_trans_bias_expected_device) + + # ------------------------------------------------------------------ + # Backward pass + # ------------------------------------------------------------------ + d_q_padded = pad_or_shrink_to_length( + d_q_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + d_c_padded = pad_or_shrink_to_length( + d_c_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=N_atoms_packed + ) + d_atom_enc_bias_padded = pad_or_shrink_to_length( + d_atom_enc_bias_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=K_packed + ) + d_atom_dec_bias_padded = pad_or_shrink_to_length( + d_atom_dec_bias_global_host.to(device=manager.device, dtype=dtype), axis=1, target_length=K_packed + ) + + d_q_dt = distribute_tensor(d_q_padded, manager.device_mesh_subgroups, q_dt.placements) + d_c_dt = distribute_tensor(d_c_padded, manager.device_mesh_subgroups, c_dt.placements) + d_atom_enc_bias_dt = distribute_tensor( + d_atom_enc_bias_padded, manager.device_mesh_subgroups, atom_enc_bias_dt.placements + ) + d_atom_dec_bias_dt = distribute_tensor( + d_atom_dec_bias_padded, manager.device_mesh_subgroups, atom_dec_bias_dt.placements + ) + d_token_trans_bias_dt = distribute_tensor( + d_token_trans_bias_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + token_trans_bias_dt.placements, + ) + + torch.autograd.backward( + [q_dt, c_dt, atom_enc_bias_dt, atom_dec_bias_dt, token_trans_bias_dt], + [d_q_dt, d_c_dt, d_atom_enc_bias_dt, d_atom_dec_bias_dt, d_token_trans_bias_dt], + ) + + # Check token-level input gradients + torch.testing.assert_close( + s_trunk_dt.grad.full_tensor(), + d_s_trunk_expected_global_host.to(device=manager.device, dtype=dtype), + ) + torch.testing.assert_close( + z_trunk_dt.grad.full_tensor(), + d_z_trunk_expected_global_host.to(device=manager.device, dtype=dtype), + ) + torch.testing.assert_close( + rel_pos_enc_dt.grad.full_tensor(), + d_rel_pos_enc_expected_global_host.to(device=manager.device, dtype=dtype), + ) + + # Parameter grads + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + + if hasattr(grad_param, "full_tensor"): + grad_global_host = grad_param.full_tensor().cpu() + grad_to_check = grad_param.full_tensor() + else: + grad_global_host = grad_param.detach().cpu() + grad_to_check = grad_param + + torch.testing.assert_close(grad_global_host, grad_expected_global.to(dtype=dtype)) + assert_all_identical(grad_to_check, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype", + ( + params_test := [ + (((1, (2, 2)), True, "cuda", "ENV"), torch.float64), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64), + ] + ), + indirect=["setup_env"], + ids=[f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, device_type:{x[0][2]}, dtype:{x[1]}" for x in params_test], +) +def test_diffusion_conditioning(setup_env, dtype): + """Test DTensor DiffusionConditioning (V2) vs serial equivalence.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + W = 32 + H = 128 + val_init_min_max = (-0.08, 0.08) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 50 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + atom_z = 8 + token_s = 4 + token_z = 4 + + atom_encoder_depth = 2 + atom_encoder_heads = 2 + token_transformer_depth = 3 + token_transformer_heads = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + + atom_feature_dim = 3 + 1 + boltz_const.num_elements + 4 * 64 + + selected_keys = list(_selected_atom_keys) + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + N_atoms_actual = feats["atom_pad_mask"].shape[1] + K = N_atoms_actual // W + + # Token-level inputs + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z_trunk = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype, requires_grad=True) + rel_pos_enc = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype, requires_grad=True) + init_tensors_uniform([s_trunk, z_trunk, rel_pos_enc], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Build serial reference module + reference_module = SerialDiffusionConditioning( + token_s=token_s, + token_z=token_z, + atom_s=atom_s, + atom_z=atom_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + token_transformer_depth=token_transformer_depth, + token_transformer_heads=token_transformer_heads, + atom_decoder_depth=atom_decoder_depth, + atom_decoder_heads=atom_decoder_heads, + atom_feature_dim=atom_feature_dim, + ).to(device=device_type, dtype=dtype) + reference_module.train() + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = reference_module.state_dict() + + # Serial forward pass + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) + z_trunk_serial = z_trunk.detach().clone().requires_grad_(True) + rel_pos_enc_serial = rel_pos_enc.detach().clone().requires_grad_(True) + + q_expected, c_expected, _to_keys, atom_enc_bias_expected, atom_dec_bias_expected, token_trans_bias_expected = ( + reference_module( + s_trunk=s_trunk_serial, + z_trunk=z_trunk_serial, + relative_position_encoding=rel_pos_enc_serial, + feats=feats_serial, + ) + ) + + # Upstream gradients + d_q = torch.empty_like(q_expected) + d_c = torch.empty_like(c_expected) + d_atom_enc_bias = torch.empty_like(atom_enc_bias_expected) + d_atom_dec_bias = torch.empty_like(atom_dec_bias_expected) + d_token_trans_bias = torch.empty_like(token_trans_bias_expected) + init_tensors_uniform( + [d_q, d_c, d_atom_enc_bias, d_atom_dec_bias, d_token_trans_bias], + low=val_init_min_max[0], + high=val_init_min_max[1], + ) + + # Apply masks to upstream gradients + mask_expanded = feats_serial["atom_pad_mask"].unsqueeze(-1) + d_q = d_q * mask_expanded + d_c = d_c * mask_expanded + + compute_dtype = torch.promote_types(dtype, torch.float32) + index_matrix = get_indexing_matrix(K, W, H, device_type).to(dtype=compute_dtype) + to_keys_fn_serial = partial(single_to_keys, indexing_matrix=index_matrix, W=W, H=H) + mask_key_serial = to_keys_fn_serial( + feats_serial["atom_pad_mask"].to(dtype=compute_dtype, device=d_atom_enc_bias.device).unsqueeze(-1) + ) + # Pair mask: (B, K, W, H, 1) + pair_mask = mask_key_serial[:, :, None, :, :] * mask_expanded.unflatten(1, (K, W))[:, :, :, None, :] + d_atom_enc_bias = d_atom_enc_bias * pair_mask + d_atom_dec_bias = d_atom_dec_bias * pair_mask + + # Serial backward + torch.autograd.backward( + [q_expected, c_expected, atom_enc_bias_expected, atom_dec_bias_expected, token_trans_bias_expected], + [d_q, d_c, d_atom_enc_bias, d_atom_dec_bias, d_token_trans_bias], + ) + + expected_param_grads = { + name: param.grad.detach().cpu() for name, param in reference_module.named_parameters() if param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_diffusion_conditioning, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + atom_s, + atom_z, + token_s, + token_z, + atom_feature_dim, + W, + H, + atom_encoder_depth, + atom_encoder_heads, + token_transformer_depth, + token_transformer_heads, + atom_decoder_depth, + atom_decoder_heads, + {k: v.detach().cpu() for k, v in layer_state_dict.items()}, + {k: v.detach().cpu() for k, v in feats.items()}, + s_trunk.detach().cpu(), + z_trunk.detach().cpu(), + rel_pos_enc.detach().cpu(), + q_expected.detach().cpu(), + c_expected.detach().cpu(), + atom_enc_bias_expected.detach().cpu(), + atom_dec_bias_expected.detach().cpu(), + token_trans_bias_expected.detach().cpu(), + d_q.detach().cpu(), + d_c.detach().cpu(), + d_atom_enc_bias.detach().cpu(), + d_atom_dec_bias.detach().cpu(), + d_token_trans_bias.detach().cpu(), + s_trunk_serial.grad.detach().cpu(), + z_trunk_serial.grad.detach().cpu(), + rel_pos_enc_serial.grad.detach().cpu(), + expected_param_grads, + ) diff --git a/tests/distributed/model/modules/test_dtensor_diffusion_module.py b/tests/distributed/model/modules/test_dtensor_diffusion_module.py new file mode 100644 index 000000000..e76d42326 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_diffusion_module.py @@ -0,0 +1,611 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor DiffusionModule with window batching. + +Tests the DTensor DiffusionModule against both V1 and V2 serial references, +verifying forward and backward numerical equivalence. + +Parametrized on ``internalized_conditioning``: +- True (V1): module owns pairwise_conditioner, forward takes z_trunk + relative_position_encoding +- False (V2): forward takes pre-computed diffusion_conditioning dict + +Uses float64 with default tolerance for exact comparison. +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.data import const as boltz_const +from boltz.distributed.comm import AttentionPairBiasComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.diffusion import DiffusionModule as DistributedDiffusionModule +from boltz.distributed.model.modules.diffusion_conditioning import ( + DiffusionConditioning as DistributedDiffusionConditioning, +) +from boltz.model.modules.diffusion import DiffusionModule as SerialDiffusionModuleV1 +from boltz.model.modules.diffusion_conditioning import DiffusionConditioning as SerialDiffusionConditioning +from boltz.model.modules.diffusionv2 import DiffusionModule as SerialDiffusionModuleV2 +from boltz.testing.utils import ( + SetModuleInfValues, + distribute_atom_features, + get_feature_placements, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +# Atom features needed +_selected_atom_keys = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + "atom_counts_per_token", +} +# Token features needed +_selected_token_keys = {"token_pad_mask"} + +_selected_model_io_keys = { + "r_noisy_expected", + "r_update_expected", + "d_r_update_expected", + "d_r_noisy_expected", +} + +_placements = get_feature_placements( + token_keys=set(), + msa_keys=set(), + atom_keys=_selected_atom_keys, + model_io_keys=_selected_model_io_keys, + model_io_fp32_keys=set(), +) +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_atom_features = _placements["atom_features"] +_placements_model_io = _placements["model_io"] +_placements_cp_model_io = _placements["cp_model_io"] + + +def parallel_assert_diffusion_module( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + internalized_conditioning: bool, + dtype: torch.dtype, + multiplicity: int, + # Module params + diffusion_module_state_dict, + diffusion_conditioning_state_dict, # None for internalized + module_kwargs: dict, + conditioning_kwargs: dict | None, # None for internalized + W: int, + H: int, + # Inputs + feats_global_host: dict[str, torch.Tensor], + s_inputs_global_host: torch.Tensor, + s_trunk_global_host: torch.Tensor, + z_trunk_global_host: torch.Tensor, + rel_pos_enc_global_host: torch.Tensor, + r_noisy_global_host: torch.Tensor, + times_global_host: torch.Tensor, + # Expected outputs + r_update_expected_global_host: torch.Tensor, + # Upstream grad + d_r_update_global_host: torch.Tensor, + # Expected input grads + d_s_inputs_expected_global_host: torch.Tensor, + d_s_trunk_expected_global_host: torch.Tensor, + d_r_noisy_expected_global_host: torch.Tensor, + # Expected param grads + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Re-create serial DiffusionModule + if internalized_conditioning: + serial_diffusion_module = SerialDiffusionModuleV1(**module_kwargs) + else: + serial_diffusion_module = SerialDiffusionModuleV2(**module_kwargs) + serial_diffusion_module = serial_diffusion_module.to(device=manager.device, dtype=dtype) + serial_diffusion_module.load_state_dict(diffusion_module_state_dict) + serial_diffusion_module = serial_diffusion_module.train() + + # Create ring_comm for the token-level transformer + ring_comm = AttentionPairBiasComm( + manager.group["cp"], + manager.layout_subgroups["cp"], + manager.subgroups["cp"][0], + manager.subgroups["cp"][1], + ) + + # Create DTensor module + module = DistributedDiffusionModule( + layer=serial_diffusion_module, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + ).train() + + # ------------------------------------------------------------------ + # Distribute token-level tensors (common to both paths) + # ------------------------------------------------------------------ + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + placements_times = (Shard(0), Replicate(), Replicate()) + + token_pad_mask_dt = distribute_tensor( + feats_global_host["token_pad_mask"].to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ) + + s_inputs_dt = distribute_tensor( + s_inputs_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_single, + ).requires_grad_(True) + + times_dt = distribute_tensor( + times_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_times, + ).requires_grad_(False) + + z_trunk_device = z_trunk_global_host.to(device=manager.device, dtype=dtype) + rel_pos_enc_device = rel_pos_enc_global_host.to(device=manager.device, dtype=dtype) + + # ------------------------------------------------------------------ + # Distribute atom features and r_noisy (shared by V1 and V2) + # ------------------------------------------------------------------ + # Both V1 and V2 pass unpacked feats — DiffusionModule packs internally. + # All atom-level I/O must share the same intersperse-padded atom ordering. + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + size_batch = feats_global_host["atom_pad_mask"].shape[0] + io_tensors = { + "r_noisy_expected": r_noisy_global_host, + "r_update_expected": r_update_expected_global_host, + "d_r_update_expected": d_r_update_global_host, + "d_r_noisy_expected": d_r_noisy_expected_global_host, + } + for base_name, tensor_host in io_tensors.items(): + unflat = tensor_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"{base_name}_{i_mul}"] = unflat[:, i_mul].to(dtype=dtype) + + placements_cp_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements_cp_model_io.items() for i_mul in range(multiplicity) + } + placements_io_mul = {f"{k}_{i_mul}": v for k, v in _placements_model_io.items() for i_mul in range(multiplicity)} + multiplicities = dict.fromkeys(io_tensors, multiplicity) + + feats_and_io = distribute_atom_features( + inputs_atom, + _placements_cp_atom_features | placements_cp_io_mul, + _placements_atom_features | placements_io_mul, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities=multiplicities, + ) + r_noisy_dt = feats_and_io.pop("r_noisy_expected").requires_grad_(True) + r_update_expected_dt = feats_and_io.pop("r_update_expected") + d_r_update_dt_expected = feats_and_io.pop("d_r_update_expected") + d_r_noisy_expected_dt = feats_and_io.pop("d_r_noisy_expected") + feats_dt = feats_and_io + feats_dt["token_pad_mask"] = token_pad_mask_dt + + z_trunk_dt = distribute_tensor(z_trunk_device, manager.device_mesh_subgroups, placements_pair) + rel_pos_enc_dt = distribute_tensor(rel_pos_enc_device, manager.device_mesh_subgroups, placements_pair) + + # ------------------------------------------------------------------ + # Forward pass (depends on internalized_conditioning) + # ------------------------------------------------------------------ + if internalized_conditioning: + # V1: pass z_trunk and rel_pos_enc directly + r_update_result = module( + s_inputs=s_inputs_dt, + s_trunk=s_trunk_dt, + r_noisy=r_noisy_dt, + times=times_dt, + feats=feats_dt, + z_trunk=z_trunk_dt, + relative_position_encoding=rel_pos_enc_dt, + multiplicity=multiplicity, + ) + r_update_dt = r_update_result["r_update"] + else: + # V2: DTensor DiffusionConditioning takes unpacked feats (packs internally) + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs) + serial_conditioning = serial_conditioning.to(device=manager.device, dtype=dtype) + serial_conditioning.load_state_dict(diffusion_conditioning_state_dict) + serial_conditioning = serial_conditioning.eval() + + dtensor_conditioning = DistributedDiffusionConditioning( + layer=serial_conditioning, + device_mesh=manager.device_mesh_subgroups, + ).eval() + + with torch.no_grad(): + q_cond_dt, c_cond_dt, atom_enc_bias_dt, atom_dec_bias_dt, token_trans_bias_dt = dtensor_conditioning( + s_trunk=s_trunk_dt.detach(), + z_trunk=z_trunk_dt, + relative_position_encoding=rel_pos_enc_dt, + feats=feats_dt, + ) + + diff_cond_dt = { + "q": q_cond_dt, + "c": c_cond_dt, + "atom_enc_bias": atom_enc_bias_dt, + "atom_dec_bias": atom_dec_bias_dt, + "token_trans_bias": token_trans_bias_dt, + } + + r_update_dt = module( + s_inputs=s_inputs_dt, + s_trunk=s_trunk_dt, + r_noisy=r_noisy_dt, + times=times_dt, + feats=feats_dt, + diffusion_conditioning=diff_cond_dt, + multiplicity=multiplicity, + ) + + # ------------------------------------------------------------------ + # Forward comparison (both V1 and V2 outputs are in intersperse-padded layout) + # ------------------------------------------------------------------ + torch.testing.assert_close(r_update_dt.full_tensor(), r_update_expected_dt.full_tensor()) + + # ------------------------------------------------------------------ + # Backward pass (upstream grad in intersperse-padded layout via distribute_atom_features) + # ------------------------------------------------------------------ + r_update_dt.backward(d_r_update_dt_expected) + + # Check input gradients (token-level) + torch.testing.assert_close( + s_inputs_dt.grad.full_tensor(), + d_s_inputs_expected_global_host.to(device=manager.device, dtype=dtype), + ) + torch.testing.assert_close( + s_trunk_dt.grad.full_tensor(), + d_s_trunk_expected_global_host.to(device=manager.device, dtype=dtype), + ) + + # r_noisy grad comparison (atom-level, intersperse-padded layout) + torch.testing.assert_close(r_noisy_dt.grad.full_tensor(), d_r_noisy_expected_dt.full_tensor()) + + # Parameter grads + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + if grad_param is None: + continue + if hasattr(grad_param, "full_tensor"): + grad_global_host = grad_param.full_tensor().cpu() + else: + grad_global_host = grad_param.detach().cpu() + torch.testing.assert_close( + grad_global_host, + grad_expected_global.to(dtype=dtype), + msg=lambda m: f"Parameter gradient mismatch for {name}: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, device_type:{x[2]}", +) +@pytest.mark.parametrize("multiplicity", [1, 4], ids=lambda x: f"mul:{x}") +@pytest.mark.parametrize("internalized_conditioning", [False, True], ids=["extern", "intern"]) +def test_diffusion_module(setup_env, multiplicity, internalized_conditioning: bool): + """Test DTensor DiffusionModule with window batching. + + Parametrized on ``internalized_conditioning``: + - False (V2 / externalized): uses DiffusionConditioning to pre-compute q/c/bias, float64 + - True (V1 / internalized): passes z_trunk + relative_position_encoding directly, float32 + (V1 serial code uses hardcoded .float() casts incompatible with float64) + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + # V1 serial uses .float() casts internally → must test with float32 + # V2 serial uses promote_types → can test with float64 for exact comparison + dtype = torch.float32 if internalized_conditioning else torch.float64 + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + W = 32 + H = 128 + val_init_min_max = (-0.5, 0.5) + + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 * size_cp + N_atoms_raw = N_tokens * n_atoms_per_token_max + N_atoms = ((N_atoms_raw + W - 1) // W) * W + N_msa = 1 + + atom_s = 8 + token_s = 4 + token_z = 4 + atom_z = 8 + + atom_encoder_depth = 2 + atom_encoder_heads = 2 + token_transformer_depth = 2 + token_transformer_heads = 2 + atom_decoder_depth = 2 + atom_decoder_heads = 2 + conditioning_transition_layers = 1 + + # V1: ref_pos(3) + ref_charge(1) + atom_pad_mask(1) + ref_element + ref_atom_name_chars(4*64) + # V2: ref_pos(3) + ref_charge(1) + ref_element + ref_atom_name_chars(4*64) (no atom_pad_mask) + atom_feature_dim = 3 + 1 + (1 if internalized_conditioning else 0) + boltz_const.num_elements + 4 * 64 + + selected_keys = list(_selected_atom_keys | _selected_token_keys) + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + feats = {k: v.to(dtype=dtype) if v.dtype == torch.float64 else v for k, v in feats.items()} + + # Token-level inputs + # V1 s_inputs has wider dim: input_dim - token_s, where + # input_dim = 2 * token_s + 2 * num_tokens + 1 + len(pocket_contact_info) + # V2 s_inputs has token_s dim (s_trunk and s_inputs are same width) + if internalized_conditioning: + v1_input_dim = 2 * token_s + 2 * boltz_const.num_tokens + 1 + len(boltz_const.pocket_contact_info) + s_inputs_dim = v1_input_dim - token_s + else: + s_inputs_dim = token_s + s_inputs = torch.empty((B, N_tokens, s_inputs_dim), device=device_type, dtype=dtype, requires_grad=True) + s_trunk = torch.empty((B, N_tokens, token_s), device=device_type, dtype=dtype, requires_grad=True) + z_trunk = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + rel_pos_enc = torch.empty((B, N_tokens, N_tokens, token_z), device=device_type, dtype=dtype) + init_tensors_uniform([s_inputs, s_trunk, z_trunk, rel_pos_enc], low=val_init_min_max[0], high=val_init_min_max[1]) + + # r_noisy: (B*M, N_atoms, 3) + N_atoms_actual = feats["atom_pad_mask"].shape[1] + r_noisy = torch.empty((B * multiplicity, N_atoms_actual, 3), device=device_type, dtype=dtype, requires_grad=True) + times = torch.empty((B * multiplicity,), device=device_type, dtype=dtype) + init_tensors_uniform([r_noisy, times], low=val_init_min_max[0], high=val_init_min_max[1]) + + # ------------------------------------------------------------------ + # Build serial modules and compute reference (depends on internalized_conditioning) + # ------------------------------------------------------------------ + if internalized_conditioning: + # V1: module owns pairwise_conditioner, encoder computes q/c/p internally + module_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_module = SerialDiffusionModuleV1(**module_kwargs).to(device=device_type, dtype=dtype) + serial_module.train() + init_module_params_uniform(serial_module, low=val_init_min_max[0], high=val_init_min_max[1]) + serial_module.apply(SetModuleInfValues()) + module_state_dict = serial_module.state_dict() + + conditioning_kwargs = None + conditioning_state_dict = None + + # Serial forward: V1 takes z_trunk + relative_position_encoding directly + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + s_inputs_serial = s_inputs.detach().clone().requires_grad_(True) + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) + r_noisy_serial = r_noisy.detach().clone().requires_grad_(True) + + result_serial = serial_module( + s_inputs=s_inputs_serial, + s_trunk=s_trunk_serial, + z_trunk=z_trunk.detach(), + r_noisy=r_noisy_serial, + times=times.detach(), + relative_position_encoding=rel_pos_enc.detach(), + feats=feats_serial, + multiplicity=multiplicity, + ) + r_update_serial = result_serial["r_update"] + + else: + # V2: uses DiffusionConditioning to pre-compute conditioning + module_kwargs = { + "token_s": token_s, + "atom_s": atom_s, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "sigma_data": 16, + "dim_fourier": 32, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_module = SerialDiffusionModuleV2(**module_kwargs).to(device=device_type, dtype=dtype) + serial_module.train() + init_module_params_uniform(serial_module, low=val_init_min_max[0], high=val_init_min_max[1]) + serial_module.apply(SetModuleInfValues()) + module_state_dict = serial_module.state_dict() + + conditioning_kwargs = { + "token_s": token_s, + "token_z": token_z, + "atom_s": atom_s, + "atom_z": atom_z, + "atoms_per_window_queries": W, + "atoms_per_window_keys": H, + "atom_encoder_depth": atom_encoder_depth, + "atom_encoder_heads": atom_encoder_heads, + "token_transformer_depth": token_transformer_depth, + "token_transformer_heads": token_transformer_heads, + "atom_decoder_depth": atom_decoder_depth, + "atom_decoder_heads": atom_decoder_heads, + "atom_feature_dim": atom_feature_dim, + "conditioning_transition_layers": conditioning_transition_layers, + } + serial_conditioning = SerialDiffusionConditioning(**conditioning_kwargs).to(device=device_type, dtype=dtype) + serial_conditioning.train() + init_module_params_uniform(serial_conditioning, low=val_init_min_max[0], high=val_init_min_max[1]) + serial_conditioning.apply(SetModuleInfValues()) + conditioning_state_dict = serial_conditioning.state_dict() + + # Serial forward: first conditioning, then diffusion module + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + s_inputs_serial = s_inputs.detach().clone().requires_grad_(True) + s_trunk_serial = s_trunk.detach().clone().requires_grad_(True) + r_noisy_serial = r_noisy.detach().clone().requires_grad_(True) + + with torch.no_grad(): + q_cond, c_cond, to_keys, atom_enc_bias_cond, atom_dec_bias_cond, token_trans_bias_cond = ( + serial_conditioning( + s_trunk=s_trunk.detach(), + z_trunk=z_trunk.detach(), + relative_position_encoding=rel_pos_enc.detach(), + feats={k: v.detach() for k, v in feats.items()}, + ) + ) + + diff_cond_serial = { + "q": q_cond.detach(), + "c": c_cond.detach(), + "to_keys": to_keys, + "atom_enc_bias": atom_enc_bias_cond.detach(), + "atom_dec_bias": atom_dec_bias_cond.detach(), + "token_trans_bias": token_trans_bias_cond.detach(), + } + + r_update_serial = serial_module( + s_inputs=s_inputs_serial, + s_trunk=s_trunk_serial, + r_noisy=r_noisy_serial, + times=times.detach(), + feats=feats_serial, + diffusion_conditioning=diff_cond_serial, + multiplicity=multiplicity, + ) + + # Upstream gradient + d_r_update = torch.empty_like(r_update_serial) + init_tensors_uniform([d_r_update], low=val_init_min_max[0], high=val_init_min_max[1]) + atom_mask_mul = feats_serial["atom_pad_mask"].repeat_interleave(multiplicity, 0).unsqueeze(-1) + d_r_update = d_r_update * atom_mask_mul + + r_update_serial.backward(d_r_update) + + expected_param_grads = { + name: param.grad.detach().cpu() + for name, param in serial_module.named_parameters() + if param.requires_grad and param.grad is not None + } + + spawn_multiprocessing( + parallel_assert_diffusion_module, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + internalized_conditioning, + dtype, + multiplicity, + {k: v.detach().cpu() for k, v in module_state_dict.items()}, + {k: v.detach().cpu() for k, v in conditioning_state_dict.items()} if conditioning_state_dict else None, + module_kwargs, + conditioning_kwargs, + W, + H, + {k: v.detach().cpu() for k, v in feats.items()}, + s_inputs.detach().cpu(), + s_trunk.detach().cpu(), + z_trunk.detach().cpu(), + rel_pos_enc.detach().cpu(), + r_noisy.detach().cpu(), + times.detach().cpu(), + r_update_serial.detach().cpu(), + d_r_update.detach().cpu(), + s_inputs_serial.grad.detach().cpu(), + s_trunk_serial.grad.detach().cpu(), + r_noisy_serial.grad.detach().cpu(), + expected_param_grads, + ) diff --git a/tests/distributed/model/modules/test_dtensor_diffusion_transformer_layer.py b/tests/distributed/model/modules/test_dtensor_diffusion_transformer_layer.py new file mode 100644 index 000000000..55b0f28a1 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_diffusion_transformer_layer.py @@ -0,0 +1,576 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor DiffusionTransformerLayer module. + +Tests both Boltz-1x and Boltz-2 serial DiffusionTransformerLayer modules against +the unified DTensor implementation, verifying forward and backward equivalence. + +Supports two attention modes: +- **Window-batched** (``use_ring_comm=False``): Uses ``AttentionPairBiasShardwise``. + Inputs are 4D/5D window-batched tensors. Tests both V1 and V2 serial modules. +- **Ring attention** (``use_ring_comm=True``): Uses ``AttentionPairBias``. + Inputs are 3D/4D token-level tensors. V2 only (token-level transformer use case). +""" + +from functools import partial + +import pytest +import torch +from torch.distributed.tensor import ( + DTensor, + Replicate, + Shard, + distribute_tensor, +) + +from boltz.distributed.comm import AttentionPairBiasComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.utils import convert_single_repr_window_batched_query_to_key +from boltz.distributed.model.modules.transformers import ( + DiffusionTransformerLayer as DistributedDiffusionTransformerLayer, +) +from boltz.model.modules.encoders import get_indexing_matrix as get_indexing_matrix_v1 +from boltz.model.modules.encoders import single_to_keys as single_to_keys_v1 +from boltz.model.modules.encodersv2 import get_indexing_matrix as get_indexing_matrix_v2 +from boltz.model.modules.encodersv2 import single_to_keys as single_to_keys_v2 +from boltz.model.modules.transformers import DiffusionTransformerLayer as SerialDTLBoltz1 +from boltz.model.modules.transformersv2 import DiffusionTransformerLayer as SerialDTLBoltz2 +from boltz.testing.utils import ( + SetModuleInfValues, + assert_all_identical, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_diffusion_transformer_layer( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_module_version: str, + use_ring_comm: bool, + dtype: torch.dtype, + multiplicity: int, + heads: int, + dim: int, + dim_single_cond: int, + dim_pairwise: int, + W: int, + H: int, + post_layer_norm: bool, + layer_state_dict, + a_global_host: torch.Tensor, + s_global_host: torch.Tensor, + z_global_host: torch.Tensor, + mask_global_host: torch.Tensor, + d_out_global_host: torch.Tensor, + out_expected_global_host: torch.Tensor, + d_a_expected_global_host: torch.Tensor, + d_s_expected_global_host: torch.Tensor, + d_z_expected_global_host: torch.Tensor, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Recreate serial module from state dict + if serial_module_version == "boltz1": + module_serial = SerialDTLBoltz1( + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + dim_pairwise=dim_pairwise, + ) + else: + module_serial = SerialDTLBoltz2( + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + post_layer_norm=post_layer_norm, + ) + + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.to(device=manager.device, dtype=dtype).train() + module_serial.apply(SetModuleInfValues()) + + # Create ring_comm if testing ring attention path + ring_comm = None + if use_ring_comm: + ring_comm = AttentionPairBiasComm( + manager.group["cp"], + manager.layout_subgroups["cp"], + manager.subgroups["cp"][0], + manager.subgroups["cp"][1], + ) + + module = DistributedDiffusionTransformerLayer( + diff_transformer_layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ring_comm=ring_comm, + ).train() + + # Placements depend on attention mode + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + + if use_ring_comm: + # Ring attention: a, s are 3D (B*M, N, D); z is 4D (B, N, N, heads); mask is 2D (B, N) + a_dt = distribute_tensor( + a_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_single + ).requires_grad_(True) + s_dt = distribute_tensor( + s_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_single + ).requires_grad_(True) + z_dt = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_pair + ).requires_grad_(True) + mask_dt = distribute_tensor( + mask_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements_single + ).requires_grad_(False) + else: + # Window-batched: a, s are 4D (B*M, K, W, D); z is 5D (B, K, W, H, D); mask is 3D (B, K, W) + placements = placements_single + a_dt = distribute_tensor( + a_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements + ).requires_grad_(True) + s_dt = distribute_tensor( + s_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements + ).requires_grad_(True) + z_dt = distribute_tensor( + z_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements + ).requires_grad_(True) + mask_dt = distribute_tensor( + mask_global_host.to(device=manager.device, dtype=dtype), manager.device_mesh_subgroups, placements + ).requires_grad_(False) + + # Copies to ensure inputs aren't modified in-place + a_dt_copy = a_dt.detach().clone().requires_grad_(True) + s_dt_copy = s_dt.detach().clone().requires_grad_(True) + z_dt_copy = z_dt.detach().clone().requires_grad_(True) + mask_dt_copy = mask_dt.detach().clone() + + # Forward pass + if use_ring_comm: + # Ring attention: no to_keys, use multiplicity + out_dt: DTensor = module( + a_dt, + s_dt, + z_dt, + mask=mask_dt, + multiplicity=multiplicity, + layer_cache=None, + pair_mask=None, + ) + else: + # Window-batched: to_keys converts query→key space + to_keys = partial(convert_single_repr_window_batched_query_to_key, W=W, H=H) + out_dt: DTensor = module( + a_dt, + s_dt, + z_dt, + mask=mask_dt, + to_keys=to_keys, + layer_cache=None, + pair_mask=None, + ) + + # Ensure no input mutation + assert_tensors_identical(a_dt_copy.to_local(), a_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(s_dt_copy.to_local(), s_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(z_dt_copy.to_local(), z_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(mask_dt_copy.to_local(), mask_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Forward compare + out_expected_dt = distribute_tensor( + out_expected_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + out_dt.placements, + ) + # Mask the forward comparison: fully-masked attention windows produce numerically + # undefined softmax output (softmax of all -inf), so we only compare unmasked positions. + _mask_local = mask_dt.to_local().unsqueeze(-1) + if multiplicity > 1: + _mask_local = _mask_local.repeat_interleave(multiplicity, 0) + torch.testing.assert_close(out_dt.to_local() * _mask_local, out_expected_dt.to_local() * _mask_local) + _mask_full = mask_dt.full_tensor().unsqueeze(-1) + if multiplicity > 1: + _mask_full = _mask_full.repeat_interleave(multiplicity, 0) + torch.testing.assert_close( + out_dt.full_tensor() * _mask_full, + out_expected_global_host.to(device=manager.device, dtype=dtype) * _mask_full, + ) + + # Backward pass + d_out_dt = distribute_tensor( + d_out_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + out_dt.placements, + ) + d_out_dt_copy = d_out_dt.detach().clone() + out_dt.backward(d_out_dt) + assert_tensors_identical(d_out_dt_copy.to_local(), d_out_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Input grad checks + torch.testing.assert_close(a_dt.grad.full_tensor().cpu(), d_a_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(s_dt.grad.full_tensor().cpu(), d_s_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(z_dt.grad.full_tensor().cpu(), d_z_expected_global_host.to(dtype=dtype)) + + # Parameter grads (gather full tensors) + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + + if isinstance(grad_param, DTensor): + grad_global_host = grad_param.full_tensor().cpu() + grad_to_check = grad_param.full_tensor() + else: + grad_global_host = grad_param.detach().cpu() + grad_to_check = grad_param + + torch.testing.assert_close(grad_global_host, grad_expected_global.to(dtype=dtype)) + assert_all_identical(grad_to_check, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +def _create_serial_reference( + serial_module_version: str, + use_ring_comm: bool, + heads: int, + dim: int, + dim_single_cond: int, + dim_pairwise: int, + post_layer_norm: bool, + device_type: str, + dtype: torch.dtype, + val_init_min_max: tuple[float, float], + B: int, + N: int, + K: int, + W: int, + H: int, + multiplicity: int = 1, +): + """Create serial module and compute reference forward/backward outputs. + + Supports two modes: + - Window-batched (use_ring_comm=False): 4D/5D inputs, to_keys-based attention + - Ring attention (use_ring_comm=True): 3D/4D token-level inputs, V2 only + + Returns tuple of (layer_state_dict, a, s, z, mask, d_out_global, out_expected_global, + d_a_expected_global, d_s_expected_global, d_z_expected_global, + expected_param_grads_global) + """ + if serial_module_version == "boltz1": + get_indexing_matrix = get_indexing_matrix_v1 + single_to_keys = single_to_keys_v1 + reference_module = SerialDTLBoltz1( + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + dim_pairwise=dim_pairwise, + ).to(device=device_type, dtype=dtype) + else: + get_indexing_matrix = get_indexing_matrix_v2 + single_to_keys = single_to_keys_v2 + reference_module = SerialDTLBoltz2( + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + post_layer_norm=post_layer_norm, + ).to(device=device_type, dtype=dtype) + + reference_module.train() + + # Keep values small to reduce numerical noise + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + layer_state_dict = reference_module.state_dict() + + if use_ring_comm: + # ------------------------------------------------------------------ + # Ring attention path: 3D/4D token-level inputs + # a, s: (B*M, N, D); z: (B, N, N, z_dim); mask: (B, N) + # V1: z_dim = dim_pairwise (raw pair repr, proj_z computes bias) + # V2: z_dim = heads (pre-computed bias) + # ------------------------------------------------------------------ + z_last_dim = dim_pairwise if serial_module_version == "boltz1" else heads + a = torch.empty((B * multiplicity, N, dim), device=device_type, dtype=dtype, requires_grad=True) + s = torch.empty((B * multiplicity, N, dim_single_cond), device=device_type, dtype=dtype, requires_grad=True) + z = torch.empty((B, N, N, z_last_dim), device=device_type, dtype=dtype, requires_grad=True) + mask = torch.ones((B, N), device=device_type, dtype=dtype) + mask[0, N // 2 :] = 0 + init_tensors_uniform([a, s, z], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Serial forward: no to_keys. + # Pre-expand z and mask by multiplicity, then pass multiplicity=1 to avoid + # double-expansion (serial AttentionPairBias internally repeats z by multiplicity). + a_serial = a.detach().clone().requires_grad_(True) + s_serial = s.detach().clone().requires_grad_(True) + z_serial = z.detach().clone().requires_grad_(True) + z_serial_mul = z_serial.repeat_interleave(multiplicity, 0) + mask_mul = mask.repeat_interleave(multiplicity, 0) + + if serial_module_version == "boltz1": + out_expected = reference_module( + a_serial, + s_serial, + z_serial_mul, + mask=mask_mul.detach(), + to_keys=None, + multiplicity=1, # pre-applied above + layer_cache=None, + ) + else: + out_expected = reference_module( + a_serial, + s_serial, + bias=z_serial_mul, + mask=mask_mul.detach(), + to_keys=None, + multiplicity=1, # pre-applied above + ) + + d_out = torch.empty_like(out_expected) + init_tensors_uniform([d_out], low=val_init_min_max[0], high=val_init_min_max[1]) + d_out = d_out * mask_mul.unsqueeze(-1) + + out_expected.backward(d_out) + + out_expected_global_host = out_expected.detach().cpu() + d_a_expected_global_host = a_serial.grad.detach().cpu() + d_s_expected_global_host = s_serial.grad.detach().cpu() + d_z_expected_global_host = z_serial.grad.detach().cpu() + d_out_global_host = d_out.detach().cpu() + else: + # ------------------------------------------------------------------ + # Window-batched path: 4D/5D inputs (V1 and V2) + # a, s: (B*M, K, W, D); z: (B, K, W, H, z_dim); mask: (B, K, W) + # ------------------------------------------------------------------ + z_last_dim = dim_pairwise if serial_module_version == "boltz1" else heads + a = torch.empty((B * multiplicity, K, W, dim), device=device_type, dtype=dtype, requires_grad=True) + s = torch.empty((B * multiplicity, K, W, dim_single_cond), device=device_type, dtype=dtype, requires_grad=True) + z = torch.empty((B, K, W, H, z_last_dim), device=device_type, dtype=dtype, requires_grad=True) + mask = torch.ones((B, K, W), device=device_type, dtype=dtype) + mask[0, K // 2 :] = 0 + init_tensors_uniform([a, s, z], low=val_init_min_max[0], high=val_init_min_max[1]) + + # Serial forward: flatten (B*M, K, ...) -> (B*M*K, ...) + a_serial = a.detach().clone().requires_grad_(True) + a_serial_flattened = a_serial.flatten(0, 1) # (B*M*K, W, dim) + s_serial = s.detach().clone().requires_grad_(True) + s_serial_flattened = s_serial.flatten(0, 1) # (B*M*K, W, dim_single_cond) + z_serial = z.detach().clone().requires_grad_(True) + + BM = B * multiplicity + indexing_matrix = get_indexing_matrix(K=K, W=W, H=H, device=device_type).to(dtype=dtype) + + def to_keys_serial(x: torch.Tensor) -> torch.Tensor: + BMK, W_in, Dflat = x.shape + assert BMK == BM * K + assert W_in == W + x_full = x.view(BM, K * W, Dflat) + x_key = single_to_keys(x_full, indexing_matrix, W=W, H=H) # (B*M, K, H, Dflat) + return x_key.flatten(0, 1) # (B*M*K, H, Dflat) + + # Pre-apply multiplicity to z and mask, then pass multiplicity=1 + z_serial_multiplex = z_serial.repeat_interleave(multiplicity, 0) + z_serial_flattened = z_serial_multiplex.flatten(0, 1) + mask_serial_multiplex = mask.detach().repeat_interleave(multiplicity, 0) + mask_serial_flattened = mask_serial_multiplex.flatten(0, 1) + + if serial_module_version == "boltz1": + mask_query = mask_serial_multiplex.view(BM * K, W) + out_expected = reference_module( + a_serial_flattened, + s_serial_flattened, + z_serial_flattened, + mask=mask_query, + to_keys=to_keys_serial, + multiplicity=1, + layer_cache=None, + ) + else: + out_expected = reference_module( + a_serial_flattened, + s_serial_flattened, + bias=z_serial_flattened, + mask=mask_serial_flattened, + to_keys=to_keys_serial, + multiplicity=1, + ) + + d_out = torch.empty_like(out_expected) + init_tensors_uniform([d_out], low=val_init_min_max[0], high=val_init_min_max[1]) + d_out = d_out * mask_serial_flattened.unsqueeze(-1) + + out_expected.backward(d_out) + + # Unflatten (B*M*K, W, ...) -> (B*M, K, W, ...) + out_expected_global_host = out_expected.detach().unflatten(0, a.shape[:2]).cpu() + d_a_expected_global_host = a_serial.grad.detach().cpu() + d_s_expected_global_host = s_serial.grad.detach().cpu() + d_z_expected_global_host = z_serial.grad.detach().cpu() + d_out_global_host = d_out.detach().unflatten(0, a.shape[:2]).cpu() + + expected_param_grads_global_host_dict = { + name: param.grad.detach().cpu() for name, param in reference_module.named_parameters() if param.grad is not None + } + + return ( + layer_state_dict, + a.detach().cpu(), + s.detach().cpu(), + z.detach().cpu(), + mask.detach().cpu(), + d_out_global_host, + out_expected_global_host, + d_a_expected_global_host, + d_s_expected_global_host, + d_z_expected_global_host, + expected_param_grads_global_host_dict, + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("multiplicity", [1, 4], ids=lambda m: f"multiplicity:{m}") +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize("use_ring_comm", [False, True], ids=["wb", "ring"]) +@pytest.mark.parametrize("serial_module_version", ["boltz1", "boltz2"]) +def test_diffusion_transformer_layer(setup_env, multiplicity: int, use_ring_comm: bool, serial_module_version: str): + """Test DiffusionTransformerLayer DTensor vs serial equivalence. + + Parametrized on: + - ``use_ring_comm``: False for window-batched (AttentionPairBiasShardwise), + True for ring attention (AttentionPairBias). + - ``serial_module_version``: "boltz1" or "boltz2". + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + dtype = torch.float32 + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = 10 * size_cp * 32 # token count for ring attention (= K * W) + K = 10 * size_cp # windows, divisible by cp size + W = 32 # queries per window (must be even) + H = 128 # keys per window + val_init_min_max = (-0.2, 0.2) + + dim = 32 + dim_single_cond = dim + dim_pairwise = 32 + heads = 2 + post_layer_norm = False + + ( + layer_state_dict, + a_host, + s_host, + z_host, + mask_host, + d_out_global_host, + out_expected_global_host, + d_a_expected_global_host, + d_s_expected_global_host, + d_z_expected_global_host, + expected_param_grads_global_host_dict, + ) = _create_serial_reference( + serial_module_version=serial_module_version, + use_ring_comm=use_ring_comm, + heads=heads, + dim=dim, + dim_single_cond=dim_single_cond, + dim_pairwise=dim_pairwise, + post_layer_norm=post_layer_norm, + device_type=device_type, + dtype=dtype, + val_init_min_max=val_init_min_max, + B=B, + N=N, + K=K, + W=W, + H=H, + multiplicity=multiplicity, + ) + + spawn_multiprocessing( + parallel_assert_diffusion_transformer_layer, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_module_version, + use_ring_comm, + dtype, + multiplicity, + heads, + dim, + dim_single_cond, + dim_pairwise, + W, + H, + post_layer_norm, + layer_state_dict, + a_host, + s_host, + z_host, + mask_host, + d_out_global_host, + out_expected_global_host, + d_a_expected_global_host, + d_s_expected_global_host, + d_z_expected_global_host, + expected_param_grads_global_host_dict, + ) diff --git a/tests/distributed/model/modules/test_dtensor_encoders.py b/tests/distributed/model/modules/test_dtensor_encoders.py new file mode 100644 index 000000000..86557068f --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_encoders.py @@ -0,0 +1,425 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor-based CP RelativePositionEncoder. + +This module tests the DTensor context-parallel RelativePositionEncoder against +both serial v1 and v2 implementations. + +Verification checks: + V1: single-proc FW input tensor values unchanged by FW and BW + V2: single-proc BW input tensor values unchanged by BW + V4a: multi-proc FW input tensor values unchanged by FW + V4b: multi-proc FW input tensor values unchanged after BW + V5: multi-proc BW input tensor (output_grad) values unchanged by BW + V8: multi-proc FW output tensor values close-to single-proc + V9: (N/A — integer inputs have no gradient) + V10: multi-proc parameter gradient values close-to single-proc + V10b: replicated parameter gradients identical across all CP ranks + +bf16 coverage (CUDA-only): + Both serial and distributed forward passes are wrapped in + torch.autocast("cuda", dtype=torch.bfloat16). The module weights + remain fp32; autocast handles the downcast for matmul ops. This + verifies that outputs and parameter gradients match under mixed + precision. +""" + +from collections import OrderedDict +from contextlib import nullcontext + +import pytest +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.testing import assert_close + +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.encoders import ( + RelativePositionEncoder as DistributedRelativePositionEncoder, +) +from boltz.model.modules.encoders import ( + RelativePositionEncoder as SerialRelativePositionEncoderV1, +) +from boltz.model.modules.encodersv2 import ( + RelativePositionEncoder as SerialRelativePositionEncoderV2, +) +from boltz.testing.utils import ( + assert_all_identical, + assert_tensors_identical, + init_module_params_uniform, + random_features, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + +SEED = 42 + + +def _assert_unchanged(actual, expected, *, serial=False): + """Shorthand for assert_tensors_identical with standard immutability kwargs.""" + assert_tensors_identical( + actual, + expected, + check_stride=True, + check_grad=False, + check_grad_fn=False, + check_storage_pointer=False, + check_storage_offset=serial, + ) + + +_RELPOS_KEYS = ["asym_id", "entity_id", "residue_index", "token_index", "sym_id", "cyclic_period"] + + +def _make_serial_module( + serial_version: str, + token_z: int, + r_max: int, + s_max: int, + fix_sym_check: bool, + cyclic_pos_enc: bool, +): + """Construct the appropriate serial RelativePositionEncoder.""" + if serial_version == "v1": + return SerialRelativePositionEncoderV1(token_z=token_z, r_max=r_max, s_max=s_max) + return SerialRelativePositionEncoderV2( + token_z=token_z, + r_max=r_max, + s_max=s_max, + fix_sym_check=fix_sym_check, + cyclic_pos_enc=cyclic_pos_enc, + ) + + +def _get_serial_reference( + B: int, + N: int, + token_z: int, + r_max: int, + s_max: int, + fix_sym_check: bool, + cyclic_pos_enc: bool, + serial_version: str, + device: str = "cpu", + use_autocast: bool = False, + seed: int = SEED, +): + """Run serial RelativePositionEncoder and collect all reference data. + + Returns inputs, outputs, gradients, and state_dict for distributed comparison. + Also performs serial immutability checks. + + When use_autocast=True, the forward pass is wrapped in + torch.autocast("cuda", dtype=torch.bfloat16) and the module weights + stay fp32 (autocast handles downcast). + """ + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(seed) + + serial = _make_serial_module(serial_version, token_z, r_max, s_max, fix_sym_check, cyclic_pos_enc) + init_module_params_uniform(serial, low=-0.5, high=0.5) + serial = serial.to(device=device) + state_dict = {k: v.cpu().clone() for k, v in serial.state_dict().items()} + + feats = random_features( + size_batch=B, + n_tokens=N, + n_atoms=N, + n_msa=1, + atom_counts_per_token_range=(1, 1), + device=torch.device(device), + float_value_range=(-1.0, 1.0), + selected_keys=_RELPOS_KEYS, + ) + + # Clone inputs for V1 immutability check + feats_copy = {k: v.detach().clone() for k, v in feats.items()} + + ac_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_autocast else nullcontext() + + with ac_ctx: + out_ref = serial(feats) + + # V1a: serial FW input tensor values unchanged + for k in feats: + _assert_unchanged(feats[k], feats_copy[k], serial=True) + + grad_out = torch.randn_like(out_ref) + grad_out_copy = grad_out.detach().clone() + + out_ref.backward(grad_out) + + # V1b: serial FW input tensor values unchanged after backward + for k in feats: + _assert_unchanged(feats[k], feats_copy[k], serial=True) + + # V2: serial BW input tensor (grad_out) values unchanged + _assert_unchanged(grad_out, grad_out_copy, serial=True) + + param_grads = OrderedDict() + for n, p in serial.named_parameters(): + if p.grad is not None: + param_grads[n] = p.grad.detach().cpu().clone() + + feats_host = {k: v.detach().cpu().clone() for k, v in feats.items()} + + return ( + feats_host, + out_ref.detach().cpu(), + grad_out.detach().cpu(), + param_grads, + state_dict, + ) + + +def _worker_relpos_parity( + rank: int, + feats_on_host: dict[str, Tensor], + output_ref_on_host: Tensor, + grad_output_on_host: Tensor, + param_grads_ref: OrderedDict[str, Tensor | None], + state_dict: dict, + token_z: int, + r_max: int, + s_max: int, + fix_sym_check: bool, + cyclic_pos_enc: bool, + serial_version: str, + use_autocast: bool, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: dict[str, str] | None = None, +): + """Worker: compare distributed RelativePositionEncoder against serial reference. + + Performs V4a, V4b, V5, V8, V10, V10b checks. + When use_autocast=True, wraps the distributed forward in + torch.autocast("cuda", dtype=torch.bfloat16). + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + dm = DistributedManager() + + serial = _make_serial_module(serial_version, token_z, r_max, s_max, fix_sym_check, cyclic_pos_enc) + serial.load_state_dict(state_dict) + serial = serial.to(device=dm.device) + + transpose_comm = TransposeComm( + process_group=dm.group["cp"], + group_layout=dm.layout_subgroups["cp"], + ) + dist_mod = DistributedRelativePositionEncoder(serial, dm.device_mesh_subgroups, transpose_comm) + dist_mod.train() + + # Distribute features: single representation placements (Shard(0), Shard(1), Replicate()) + single_placements = (Shard(0), Shard(1)) + (Replicate(),) * (dm.device_mesh_subgroups.ndim - 2) + feats_dt = {} + for key, val in sorted(feats_on_host.items()): + feats_dt[key] = distribute_tensor( + val.to(dm.device), + dm.device_mesh_subgroups, + single_placements, + ) + + # V4a setup: clone inputs for immutability check + feats_dt_copy = {k: v.detach().clone().requires_grad_(v.requires_grad) for k, v in feats_dt.items()} + + ac_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_autocast else nullcontext() + + # Forward + with ac_ctx: + out = dist_mod(feats_dt) + + # V4a: FW input tensor values unchanged by FW (binary identity) + for key in feats_dt: + _assert_unchanged(feats_dt[key].to_local(), feats_dt_copy[key].to_local()) + + # V8: forward parity + out_full = out.full_tensor() + ref_on_device = output_ref_on_host.to(device=dm.device) + # bf16 autocast: output dtype is bf16; ref was also produced under autocast + fw_atol = 1e-2 if use_autocast else 1e-5 + fw_rtol = 1e-2 if use_autocast else 1.3e-6 + assert_close( + out_full, + ref_on_device, + atol=fw_atol, + rtol=fw_rtol, + msg=lambda m: f"Rank {rank} forward output mismatch\n{m}", + ) + + # Backward setup + pair_placements = (Shard(0), Shard(1), Shard(2)) + grad_out_dt = distribute_tensor( + grad_output_on_host.to(device=dm.device), + dm.device_mesh_subgroups, + pair_placements, + ) + + # V5 setup: clone output and grad for immutability check + out_clone = out.detach().clone().requires_grad_(out.requires_grad) + grad_out_dt_clone = grad_out_dt.detach().clone().requires_grad_(grad_out_dt.requires_grad) + + out.backward(grad_out_dt) + + # V4b: FW input tensor values unchanged after backward + for key in feats_dt: + _assert_unchanged(feats_dt[key].to_local(), feats_dt_copy[key].to_local()) + + # V5: BW input tensor values unchanged by BW + _assert_unchanged(out.to_local(), out_clone.to_local()) + _assert_unchanged(grad_out_dt.to_local(), grad_out_dt_clone.to_local()) + + # V10: parameter gradient parity + # Tolerance rationale: the linear weight gradient is accumulated across + # world_size ranks via reduce-sum, changing float32 accumulation order. + # For N_accum~1376 elements and up to 8 ranks, the expected absolute + # error is O(sqrt(N_accum) * N_ranks * eps_f32) ≈ 3.5e-05. + # bf16 autocast: forward used bf16 intermediates so small values may + # be flushed to zero. Backward runs outside autocast (grads are fp32) + # but inherits the bf16 rounding from saved tensors. With N_ranks=4 + # (cp=2×2), partial-sum vs full-sum accumulation order differences on + # bf16-rounded inputs can reach ~8×eps_bf16 ≈ 0.063 per element. + grad_atol = 8e-2 if use_autocast else 5e-5 + grad_rtol = 8e-2 if use_autocast else 5e-5 + checked = 0 + for name, param in dist_mod.named_parameters(): + ref = param_grads_ref.get(name) + if ref is None: + continue + assert param.grad is not None, f"Param {name} grad is None but serial had gradient" + actual = param.grad.full_tensor() if isinstance(param.grad, DTensor) else param.grad + assert_close( + actual, + ref.to(device=dm.device), + atol=grad_atol, + rtol=grad_rtol, + msg=lambda m, n=name: f"Rank {rank} {n} grad mismatch\n{m}", + ) + + # V10b: replicated parameter gradients identical across all CP ranks + grad_for_ident = actual.detach() if not isinstance(param.grad, DTensor) else param.grad.full_tensor().detach() + assert_all_identical(grad_for_ident, dm.group["cp"]) + + checked += 1 + + assert checked >= 1, f"Expected at least 1 parameter gradient check, got {checked}" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, fix_sym_check, cyclic_pos_enc, serial_version, use_autocast", + ( + _params := [ + (((2, (2, 2)), True, "cuda", "ENV"), False, False, "v2", False), + (((2, (2, 2)), True, "cuda", "ENV"), True, True, "v2", True), + (((2, (2, 2)), True, "cuda", "ENV"), False, False, "v1", False), + (((2, (3, 3)), True, "cpu", "ENV"), False, False, "v2", False), + ] + ), + indirect=("setup_env",), + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, device:{x[0][2]}, " + f"fix_sym:{x[1]}, cyclic:{x[2]}, ver:{x[3]}, autocast:{x[4]}" + for x in _params + ], +) +def test_dtensor_relpos_forward_backward( + setup_env, + fix_sym_check: bool, + cyclic_pos_enc: bool, + serial_version: str, + use_autocast: bool, +): + """RelativePositionEncoder: distributed output and gradients match serial reference. + + Covers fp32 parity (CPU+CUDA) and bf16 autocast precision (CUDA-only) + across multiple mesh topologies, both v1 and v2 serial modules. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + B = 2 + token_z = 32 + r_max = 8 + s_max = 2 + N = 8 * grid_group_sizes["cp"][0] + + ( + feats_host, + out_ref, + grad_out, + param_grads, + state_dict, + ) = _get_serial_reference( + B=B, + N=N, + token_z=token_z, + r_max=r_max, + s_max=s_max, + fix_sym_check=fix_sym_check, + cyclic_pos_enc=cyclic_pos_enc, + serial_version=serial_version, + device=device_type, + use_autocast=use_autocast, + seed=SEED, + ) + + spawn_multiprocessing( + _worker_relpos_parity, + world_size, + feats_host, + out_ref, + grad_out, + param_grads, + state_dict, + token_z, + r_max, + s_max, + fix_sym_check, + cyclic_pos_enc, + serial_version, + use_autocast, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/model/modules/test_dtensor_fourier_embedding.py b/tests/distributed/model/modules/test_dtensor_fourier_embedding.py new file mode 100644 index 000000000..a03fced4a --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_fourier_embedding.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor FourierEmbedding module. + +Tests both Boltz-1x and Boltz-2 serial FourierEmbedding modules against the unified +DTensor FourierEmbedding implementation, verifying forward-only equivalence. +FourierEmbedding has frozen (non-trainable) parameters, so no backward test is needed. + +The V1 and V2 serial FourierEmbedding implementations are identical in structure +and math; the test parametrizes over both to verify the DTensor wrapper accepts +either serial class. + +Adapted from Boltz-1x CP test (tests_v1/distributed/model/modules/test_dtensor_encoders.py). +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.encoders import FourierEmbedding as DTensorFourierEmbedding +from boltz.model.modules.encoders import FourierEmbedding as FourierEmbeddingBoltz1 +from boltz.model.modules.encodersv2 import FourierEmbedding as FourierEmbeddingBoltz2 +from boltz.testing.utils import ( + assert_all_identical, + assert_tensors_identical, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_fourier_embedding( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_class_tag: str, + dim: int, + layer_state_dict, + input_global_host: torch.Tensor, + output_expected_global_host: torch.Tensor, +): + """Parallel assertion for FourierEmbedding forward pass.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Create serial module from state dict + serial_class = FourierEmbeddingBoltz1 if serial_class_tag == "boltz1" else FourierEmbeddingBoltz2 + module_serial = serial_class(dim) + module_serial = module_serial.to(device=manager.device) + module_serial.load_state_dict(layer_state_dict) + + # Create DTensor module from serial + module_dt = DTensorFourierEmbedding(module_serial, manager.device_mesh_subgroups) + module_dt.train() + + # Placements: times is (B,) sharded along DP axis + placements_times = (Shard(0), Replicate(), Replicate()) + + # Distribute input + input_dt = distribute_tensor( + input_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_times, + ).requires_grad_(False) + + input_dt_copy = input_dt.detach().clone() + + # Forward pass + output_dt = module_dt(input_dt) + + # Verify input wasn't modified + assert_tensors_identical(input_dt_copy.to_local(), input_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Forward compare: local shard + output_expected_dt = distribute_tensor( + output_expected_global_host.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_times, + ) + torch.testing.assert_close(output_dt.to_local(), output_expected_dt.to_local()) + + # Verify all CP ranks produce identical output (frozen params, replicated computation) + assert_all_identical(output_dt.to_local().detach(), manager.group["cp"]) + + # Verify full tensor matches serial reference + torch.testing.assert_close(output_dt.full_tensor().cpu(), output_expected_global_host) + + # Verify all parameters are frozen + for name, param in module_dt.named_parameters(): + assert not param.requires_grad, f"Parameter {name} should be frozen but requires_grad=True" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "serial_class_tag", + ["boltz1", "boltz2"], + ids=["v1", "v2"], +) +def test_fourier_embedding(setup_env, serial_class_tag: str): + """Test FourierEmbedding DTensor vs serial equivalence for both V1 and V2. + + FourierEmbedding has frozen parameters (non-trainable), so this is a + forward-only test. V1 and V2 serial implementations are identical; both + are tested to verify the DTensor wrapper's isinstance check accepts either. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + B = 2 * grid_group_sizes["dp"] + dim = 256 + + seed_by_rank(0, seed=42) + + serial_class = FourierEmbeddingBoltz1 if serial_class_tag == "boltz1" else FourierEmbeddingBoltz2 + + # Create serial module — cast to device before forward + module_serial = serial_class(dim) + module_serial = module_serial.to(device=device_type) + module_serial.train() + layer_state_dict = module_serial.state_dict() + + # Create input + input_global = torch.rand((B,), device=device_type, requires_grad=False) + + # Serial forward (no backward — frozen params) + output_expected_global = module_serial(input_global) + + # Move to host for multiprocessing + input_global_host = input_global.detach().cpu() + output_expected_global_host = output_expected_global.detach().cpu() + + spawn_multiprocessing( + parallel_assert_fourier_embedding, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_class_tag, + dim, + layer_state_dict, + input_global_host, + output_expected_global_host, + ) diff --git a/tests/distributed/model/modules/test_dtensor_input_embedder_wb.py b/tests/distributed/model/modules/test_dtensor_input_embedder_wb.py new file mode 100644 index 000000000..25605b922 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_input_embedder_wb.py @@ -0,0 +1,749 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor InputEmbedder for Boltz-2. + +Tests the distributed InputEmbedder forward and backward passes using window- +batching atom attention, with the serial Boltz-2 InputEmbedder as reference. + +Verification: + V8: multi-proc FW output tensor values close-to single-proc + V10: multi-proc parameter gradient values close-to single-proc +""" + +import pytest +import torch + +from boltz.data import const +from boltz.distributed.data.utils import distribute_features +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.trunkv2 import InputEmbedder as DistInputEmbedder +from boltz.model.modules.trunkv2 import InputEmbedder as SerialInputEmbedder +from boltz.testing.utils import ( + SetModuleInfValues, + assert_all_identical, + distribute_atom_features, + get_feature_placements, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + random_features, + seed_by_rank, + spawn_multiprocessing, +) + +_selected_token_keys = { + "token_pad_mask", + "res_type", +} +_selected_msa_keys = { + "profile", + "deletion_mean", +} +_selected_atom_keys = { + "atom_pad_mask", + "ref_pos", + "ref_space_uid", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "atom_to_token", + "atom_counts_per_token", +} + +# Additional token-level keys required when add_extra_feats=True +_extra_token_keys = { + "method_feature", + "modified", + "cyclic_period", + "mol_type", +} + + +def _get_placements(add_extra_feats: bool): + """Build placement dicts, optionally including extra conditioning keys.""" + # mol_type and cyclic_period are known to get_feature_placements; + # method_feature and modified are not, so we add them manually. + known_extra = {"mol_type", "cyclic_period"} + token_keys = _selected_token_keys | (known_extra if add_extra_feats else set()) + placements = get_feature_placements( + token_keys=token_keys, + msa_keys=_selected_msa_keys, + atom_keys=_selected_atom_keys, + model_io_keys=set(), + model_io_fp32_keys=set(), + ) + if add_extra_feats: + placements_single = placements["single"] + for key in ("method_feature", "modified"): + placements["token_features"][key] = placements_single + return placements + + +# Boltz-2 atom_feature_dim: 3 (ref_pos) + 1 (ref_charge) + 128 (ref_element) + 256 (ref_atom_name_chars) +_ATOM_FEATURE_DIM = 388 + + +def parallel_assert_input_embedder( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + dtype: torch.dtype, + atom_s: int, + atom_z: int, + token_s: int, + token_z: int, + W: int, + H: int, + atom_encoder_depth: int, + atom_encoder_heads: int, + add_extra_feats: bool, + layer_state_dict, + feats_global_host: dict[str, torch.Tensor], + d_s_global_host: torch.Tensor, + s_expected_global_host: torch.Tensor, + expected_param_grads_global_host_dict: dict[str, torch.Tensor], +): + """Parallel worker for testing DTensor InputEmbedder forward and backward.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + try: + placements = _get_placements(add_extra_feats) + placements_token_features = placements["token_features"] + placements_msa_features = placements["msa_features"] + placements_cp_atom_features = placements["cp_atom_features"] + placements_atom_features = placements["atom_features"] + placements_single = placements["single"] + + selected_token_msa_keys = ( + _selected_token_keys | _selected_msa_keys | (_extra_token_keys if add_extra_feats else set()) + ) + + # Recreate serial module from state dict + module_serial = SerialInputEmbedder( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=_ATOM_FEATURE_DIM, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + add_method_conditioning=add_extra_feats, + add_modified_flag=add_extra_feats, + add_cyclic_flag=add_extra_feats, + add_mol_type_feat=add_extra_feats, + ) + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.train() + module_serial.apply(SetModuleInfValues()) + + # Create distributed module + module = DistInputEmbedder( + module=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Get token_pad_mask for valid-region comparison + token_pad_mask_global = feats_global_host["token_pad_mask"].to(device=manager.device, dtype=torch.bool) + token_pad_mask_expanded_global = token_pad_mask_global.unsqueeze(-1) + + # ==================================================================== + # Distribute token and MSA features + # ==================================================================== + if manager.group_rank["world"] == 0: + input_feats_token_msa_global = { + k: v.to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in selected_token_msa_keys + } + else: + input_feats_token_msa_global = None + + feats_token_msa = distribute_features( + input_feats_token_msa_global, + placements_token_features | placements_msa_features, + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + ) + + # ==================================================================== + # Distribute atom features via distribute_atom_features + # ==================================================================== + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in placements_cp_atom_features + } + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp_atom_features, + placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + + # ==================================================================== + # Merge all features (pack_atom_features is internalized in the module) + # ==================================================================== + feats_dt = {**feats_token_msa, **feats_atom} + + # The serial AtomAttentionEncoder computes atom_to_token_mean with an + # epsilon bias (+ 1e-6) while the distributed scatter_reduce("mean") + # uses exact division. This causes ~1e-5 absolute error in FP64. + tol = {} + if dtype == torch.float64: + tol = {"atol": 5e-5, "rtol": 1e-4} + + # ==================================================================== + # Forward pass + # ==================================================================== + s_dt = module(feats_dt) + + s_dt_full = s_dt.full_tensor() + s_expected_device = s_expected_global_host.to(device=manager.device, dtype=dtype) + torch.testing.assert_close( + s_dt_full * token_pad_mask_expanded_global, + s_expected_device * token_pad_mask_expanded_global, + **tol, + ) + + # ==================================================================== + # Backward pass + # ==================================================================== + d_s_expected_dtensor = distribute_features( + {"d_s": d_s_global_host.to(device=manager.device, dtype=dtype)} + if manager.group_rank["world"] == 0 + else None, + {"d_s": placements_single}, + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + )["d_s"] + + s_dt.backward(d_s_expected_dtensor) + + # Parameter grads comparison + for name, grad_expected_global in expected_param_grads_global_host_dict.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + + if hasattr(grad_param, "full_tensor"): + grad_global_host = grad_param.full_tensor().cpu() + grad_to_check = grad_param.full_tensor() + else: + grad_global_host = grad_param.detach().cpu() + grad_to_check = grad_param + + torch.testing.assert_close(grad_global_host, grad_expected_global.to(dtype=dtype), **tol) + assert_all_identical(grad_to_check, manager.group["cp"]) + + finally: + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, add_extra_feats", + ( + params_test := [ + # 2 GPU Test + (((2, (1, 1)), True, "cuda", "ENV"), torch.float32, False), + (((2, (1, 1)), True, "cuda", "ENV"), torch.float32, True), + # # 8 GPU Test + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True), + ] + ), + indirect=["setup_env"], + ids=[f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, dtype:{x[1]}, extra_feats:{x[2]}" for x in params_test], +) +def test_input_embedder_window_batching(setup_env, dtype, add_extra_feats): + """Test DTensor InputEmbedder forward and backward for Boltz-2.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + + val_init_min_max_dtype = {torch.float64: (-0.08, 0.08), torch.float32: (-0.02, 0.02)} + val_init_min_max = val_init_min_max_dtype[dtype] + + selected_keys = list(_selected_token_keys | _selected_msa_keys | _selected_atom_keys) + if add_extra_feats: + selected_keys.extend(["mol_type", "cyclic_period"]) + + W = 32 + H = 128 + + # n_atoms_per_token range chosen so that N_atoms > W, ensuring there is + # interspersed padding going into the parallel data path to actually test + # the pack_atom_features code path inside the distributed InputEmbedder. + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 * size_cp + N_atoms = (N_tokens * n_atoms_per_token_max + W - 1) // W * W + # MSA features are not used in the InputEmbedder + # so we set N_msa to 1 (independent of size_cp) + N_msa = 1 + + assert N_tokens % size_cp == 0, f"N_tokens ({N_tokens}) must be divisible by size_cp ({size_cp})" + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + assert feats["token_pad_mask"].shape == (B, N_tokens) + + if add_extra_feats: + feats["method_feature"] = torch.randint( + 0, const.num_method_types, (B, N_tokens), device=feats["res_type"].device + ) + feats["modified"] = torch.randint(0, 2, (B, N_tokens), device=feats["res_type"].device) + + atom_s = 8 + atom_z = 8 + token_s = 2 + token_z = 2 + atom_encoder_depth = 2 + atom_encoder_heads = 2 + + feats = {k: v.to(dtype=dtype) if v.dtype.is_floating_point else v for k, v in feats.items()} + + reference_module = SerialInputEmbedder( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=_ATOM_FEATURE_DIM, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + add_method_conditioning=add_extra_feats, + add_modified_flag=add_extra_feats, + add_cyclic_flag=add_extra_feats, + add_mol_type_feat=add_extra_feats, + ).to(device=device_type, dtype=dtype) + reference_module.train() + + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + + layer_state_dict = reference_module.state_dict() + + # Serial forward pass + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + s_expected = reference_module(feats_serial) + s_expected_global_host = s_expected.detach().cpu() + + # Serial backward pass + d_s = torch.empty_like(s_expected) + init_tensors_uniform([d_s], low=val_init_min_max[0], high=val_init_min_max[1]) + d_s = d_s * feats_serial["token_pad_mask"].unsqueeze(-1) + s_expected.backward(d_s) + d_s_global_host = d_s.detach().cpu() + expected_param_grads_global_host_dict = { + name: param.grad.detach().cpu() for name, param in reference_module.named_parameters() if param.grad is not None + } + + feats_global_host = {k: v.detach().cpu() for k, v in feats.items()} + + spawn_multiprocessing( + parallel_assert_input_embedder, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + atom_s, + atom_z, + token_s, + token_z, + W, + H, + atom_encoder_depth, + atom_encoder_heads, + add_extra_feats, + layer_state_dict, + feats_global_host, + d_s_global_host, + s_expected_global_host, + expected_param_grads_global_host_dict, + ) + + +# ======================================================================== +# Standalone BF16 mixed-precision test (via torch.autocast) +# ======================================================================== + + +def parallel_assert_input_embedder_bf16( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + atom_s: int, + atom_z: int, + token_s: int, + token_z: int, + W: int, + H: int, + atom_encoder_depth: int, + atom_encoder_heads: int, + add_extra_feats: bool, + layer_state_dict, + feats_global_host: dict[str, torch.Tensor], + d_s_global_host: torch.Tensor, + s_expected_global_host: torch.Tensor, + serial_output_dtype: torch.dtype, + serial_param_grad_dtypes: dict[str, torch.dtype], +): + """Parallel worker for BF16 mixed-precision test of DTensor InputEmbedder. + + Simulates production BF16 mixed-precision training by wrapping the + forward pass in ``torch.autocast("cuda", dtype=torch.bfloat16)``. + Module weights and input features are FP32; autocast handles precision + inside eligible operations. The distributed AtomEncoder internally + disables autocast (matching serial behavior) for numerical stability. + + Checks: + - Output dtype matches the serial reference output dtype. + - Output values are close to serial reference (with mixed-precision tolerance). + - Parameter gradient dtypes match serial gradient dtypes. + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + try: + placements = _get_placements(add_extra_feats) + placements_token_features = placements["token_features"] + placements_msa_features = placements["msa_features"] + placements_cp_atom_features = placements["cp_atom_features"] + placements_atom_features = placements["atom_features"] + placements_single = placements["single"] + + selected_token_msa_keys = ( + _selected_token_keys | _selected_msa_keys | (_extra_token_keys if add_extra_feats else set()) + ) + + # Recreate serial module from FP32 state dict (production mixed-precision setup) + module_serial = SerialInputEmbedder( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=_ATOM_FEATURE_DIM, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + add_method_conditioning=add_extra_feats, + add_modified_flag=add_extra_feats, + add_cyclic_flag=add_extra_feats, + add_mol_type_feat=add_extra_feats, + ) + module_serial = module_serial.to(device=manager.device, dtype=torch.float32) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.train() + module_serial.apply(SetModuleInfValues()) + + # AtomEncoder internally disables autocast, matching serial behavior + module = DistInputEmbedder( + module=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Get token_pad_mask for valid-region comparison + token_pad_mask_global = feats_global_host["token_pad_mask"].to(device=manager.device, dtype=torch.bool) + token_pad_mask_expanded_global = token_pad_mask_global.unsqueeze(-1) + + # ==================================================================== + # Distribute token and MSA features (FP32) + # ==================================================================== + if manager.group_rank["world"] == 0: + input_feats_token_msa_global = { + k: v.to(device=manager.device, dtype=torch.float32 if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in selected_token_msa_keys + } + else: + input_feats_token_msa_global = None + + feats_token_msa = distribute_features( + input_feats_token_msa_global, + placements_token_features | placements_msa_features, + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + ) + + # ==================================================================== + # Distribute atom features (FP32) + # ==================================================================== + inputs_atom = { + k: v.to(dtype=torch.float32 if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in placements_cp_atom_features + } + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp_atom_features, + placements_atom_features, + manager.device_mesh_subgroups, + manager.group["cp"], + ) + + feats_dt = {**feats_token_msa, **feats_atom} + + # ==================================================================== + # Forward pass under autocast (mirrors production training) + # ==================================================================== + with torch.autocast("cuda", dtype=torch.bfloat16): + s_dt = module(feats_dt) + + s_dt_full = s_dt.full_tensor() + + # Dtype check: distributed output dtype must match serial + assert ( + s_dt_full.dtype == serial_output_dtype + ), f"Distributed output dtype ({s_dt_full.dtype}) != serial output dtype ({serial_output_dtype})" + + # Value check with mixed-precision tolerance + s_expected_device = s_expected_global_host.to(device=manager.device) + compare_dtype = torch.promote_types(s_dt_full.dtype, s_expected_device.dtype) + torch.testing.assert_close( + (s_dt_full * token_pad_mask_expanded_global).to(compare_dtype), + (s_expected_device * token_pad_mask_expanded_global).to(compare_dtype), + atol=0.05, + rtol=0.05, + ) + + # ==================================================================== + # Backward pass (outside autocast, matching production training) + # ==================================================================== + d_s_expected_dtensor = distribute_features( + {"d_s": d_s_global_host.to(device=manager.device)} if manager.group_rank["world"] == 0 else None, + {"d_s": placements_single}, + manager.group["world"], + manager.group_ranks["world"][0], + manager.device_mesh_subgroups, + )["d_s"] + + s_dt.backward(d_s_expected_dtensor) + + # Gradient dtype check: each parameter's grad dtype must match serial + for name, expected_dtype in serial_param_grad_dtypes.items(): + grad_param = get_param_by_key(module, name).grad + assert grad_param is not None, f"Missing grad for param {name}" + actual_dtype = grad_param.full_tensor().dtype if hasattr(grad_param, "full_tensor") else grad_param.dtype + assert ( + actual_dtype == expected_dtype + ), f"Grad dtype mismatch for {name}: distributed={actual_dtype}, serial={expected_dtype}" + + finally: + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, add_extra_feats", + ( + params_test_bf16 := [ + (((2, (2, 2)), True, "cuda", "ENV"), False), + (((2, (2, 2)), True, "cuda", "ENV"), True), + ] + ), + indirect=["setup_env"], + ids=[f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, extra_feats:{x[1]}" for x in params_test_bf16], +) +def test_input_embedder_bf16(setup_env, add_extra_feats): + """BF16 mixed-precision test for DTensor InputEmbedder. + + Simulates production training with ``torch.autocast("cuda", dtype=torch.bfloat16)`` + wrapping the forward pass. Module weights and input features are FP32; + autocast handles precision internally. Verifies that the distributed and + serial modules produce matching output dtypes and gradient dtypes. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + seed = 42 + seed_by_rank(0, seed=seed) + + size_cp = grid_group_sizes["cp"][0] + B = 1 * grid_group_sizes["dp"] + val_init_min_max = (-0.02, 0.02) + + selected_keys = list(_selected_token_keys | _selected_msa_keys | _selected_atom_keys) + if add_extra_feats: + selected_keys.extend(["mol_type", "cyclic_period"]) + + W = 32 + H = 128 + n_atoms_per_token_min = 8 + n_atoms_per_token_max = 20 + N_tokens = 30 * size_cp + N_atoms = (N_tokens * n_atoms_per_token_max + W - 1) // W * W + N_msa = 1 + + assert N_tokens % size_cp == 0 + + feats = random_features( + size_batch=B, + n_tokens=N_tokens, + n_atoms=N_atoms, + n_msa=N_msa, + atom_counts_per_token_range=(n_atoms_per_token_min, n_atoms_per_token_max), + device=torch.device(device_type), + float_value_range=val_init_min_max, + selected_keys=selected_keys, + ) + # random_features returns float64; cast to FP32 for production-like mixed precision + feats = {k: v.to(torch.float32) if v.is_floating_point() else v for k, v in feats.items()} + assert feats["token_pad_mask"].shape == (B, N_tokens) + + if add_extra_feats: + feats["method_feature"] = torch.randint( + 0, const.num_method_types, (B, N_tokens), device=feats["res_type"].device + ) + feats["modified"] = torch.randint(0, 2, (B, N_tokens), device=feats["res_type"].device) + + atom_s = 8 + atom_z = 8 + token_s = 2 + token_z = 2 + atom_encoder_depth = 2 + atom_encoder_heads = 2 + + reference_module = SerialInputEmbedder( + atom_s=atom_s, + atom_z=atom_z, + token_s=token_s, + token_z=token_z, + atoms_per_window_queries=W, + atoms_per_window_keys=H, + atom_feature_dim=_ATOM_FEATURE_DIM, + atom_encoder_depth=atom_encoder_depth, + atom_encoder_heads=atom_encoder_heads, + add_method_conditioning=add_extra_feats, + add_modified_flag=add_extra_feats, + add_cyclic_flag=add_extra_feats, + add_mol_type_feat=add_extra_feats, + ).to(device=device_type, dtype=torch.float32) + reference_module.train() + + init_module_params_uniform(reference_module, low=val_init_min_max[0], high=val_init_min_max[1]) + reference_module.apply(SetModuleInfValues()) + + layer_state_dict = reference_module.state_dict() + + # Serial forward with autocast (mirrors production training) + feats_serial = {k: v.detach().clone() for k, v in feats.items()} + with torch.autocast("cuda", dtype=torch.bfloat16): + s_expected = reference_module(feats_serial) + serial_output_dtype = s_expected.dtype + + # Serial backward (outside autocast) + d_s = torch.empty_like(s_expected) + init_tensors_uniform([d_s], low=val_init_min_max[0], high=val_init_min_max[1]) + d_s = d_s * feats_serial["token_pad_mask"].unsqueeze(-1) + s_expected.backward(d_s) + d_s_global_host = d_s.detach().cpu() + + serial_param_grad_dtypes = { + name: param.grad.dtype for name, param in reference_module.named_parameters() if param.grad is not None + } + + s_expected_global_host = s_expected.detach().cpu() + feats_global_host = {k: v.detach().cpu() for k, v in feats.items()} + + spawn_multiprocessing( + parallel_assert_input_embedder_bf16, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + atom_s, + atom_z, + token_s, + token_z, + W, + H, + atom_encoder_depth, + atom_encoder_heads, + add_extra_feats, + layer_state_dict, + feats_global_host, + d_s_global_host, + s_expected_global_host, + serial_output_dtype, + serial_param_grad_dtypes, + ) diff --git a/tests/distributed/model/modules/test_dtensor_msa_layer.py b/tests/distributed/model/modules/test_dtensor_msa_layer.py new file mode 100644 index 000000000..782769503 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_msa_layer.py @@ -0,0 +1,602 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for Boltz-2 CP MSALayer (distributed.model.modules.trunkv2.MSALayer).""" + +import pytest +import torch +from torch.distributed.tensor import Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.trunkv2 import MSALayer as DistributedMSALayer +from boltz.model.modules.trunkv2 import MSALayer as SerialMSALayer +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + set_dtype_specific_inf_values, + spawn_multiprocessing, +) + + +def parallel_assert_msa_layer( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + msa_s, + token_z, + msa_dropout, + z_dropout, + pairwise_head_width, + pairwise_num_heads, + layer_state_dict, + input_z_global_host, + input_m_global_host, + token_mask_global_host, + msa_mask_global_host, + output_z_expected_global_host, + output_m_expected_global_host, + d_output_z_expected_global_host, + d_output_m_expected_global_host, + d_input_z_expected_global_host, + d_input_m_expected_global_host, + expected_param_grads_global_host_dict, + output_z_global_fp32_host: torch.Tensor | None = None, + output_m_global_fp32_host: torch.Tensor | None = None, + d_input_z_global_fp32_host: torch.Tensor | None = None, + d_input_m_global_fp32_host: torch.Tensor | None = None, + grad_params_fp32_global_host: dict[str, torch.Tensor] | None = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_z_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_z_expected_global_host.dtype}" + ) + + if ( + (output_z_global_fp32_host is None) != (output_m_global_fp32_host is None) + or (output_z_global_fp32_host is None) != (d_input_z_global_fp32_host is None) + or (output_z_global_fp32_host is None) != (d_input_m_global_fp32_host is None) + or (output_z_global_fp32_host is None) != (grad_params_fp32_global_host is None) + ): + raise ValueError( + "output_z_global_fp32_host, output_m_global_fp32_host, d_input_z_global_fp32_host, d_input_m_global_fp32_host, and grad_params_fp32_global_host must be either all None or all not None" + ) + + check_error_hist = output_z_global_fp32_host is not None + + # Create serial reference module + module_serial = SerialMSALayer( + msa_s=msa_s, + token_z=token_z, + msa_dropout=msa_dropout, + z_dropout=z_dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + ) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + module_serial.load_state_dict(layer_state_dict) + set_dtype_specific_inf_values(module_serial, dtype) + + # Create distributed module + module = DistributedMSALayer(module_serial, manager) + module.train() + + # Input tensors have sharding patterns: + # z: (B, N, N, token_z) - sharded on dims 0, 1, 2 + # m: (B, S, N, msa_s) - sharded on dims 0, 1, 2 + # token_mask: (B, N, N) - sharded on dims 0, 1, 2 + # msa_mask: (B, S, N) - sharded on dims 0, 1, 2 + placements_z_token_mask = (Shard(0), Shard(1), Shard(2)) # For z and token_mask tensors + placements_m_msa_mask = (Shard(0), Shard(1), Shard(2)) # For m and msa_mask tensors + + # Distribute input tensors + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_token_mask, + ).requires_grad_(True) + + input_m_dtensor = distribute_tensor( + input_m_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m_msa_mask, + ).requires_grad_(True) + + token_mask_dtensor = distribute_tensor( + token_mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_token_mask, + ) + + msa_mask_dtensor = distribute_tensor( + msa_mask_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m_msa_mask, + ) + + # Distribute expected outputs + d_output_z_expected_dtensor = distribute_tensor( + d_output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_token_mask, + ) + d_output_m_expected_dtensor = distribute_tensor( + d_output_m_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m_msa_mask, + ) + output_z_expected_dtensor = distribute_tensor( + output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_token_mask, + src_data_rank=None, + ) + output_m_expected_dtensor = distribute_tensor( + output_m_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m_msa_mask, + src_data_rank=None, + ) + d_input_z_expected_dtensor = distribute_tensor( + d_input_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_token_mask, + src_data_rank=None, + ) + d_input_m_expected_dtensor = distribute_tensor( + d_input_m_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m_msa_mask, + src_data_rank=None, + ) + + # Create copies to verify inputs aren't modified + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + input_m_dtensor_copy = input_m_dtensor.detach().clone().requires_grad_(True) + token_mask_dtensor_copy = token_mask_dtensor.detach().clone() + msa_mask_dtensor_copy = msa_mask_dtensor.detach().clone() + + if check_error_hist: + # Forward and backward pass for error histogram checking + output_z_dtensor_result, output_m_dtensor_result = module( + input_z_dtensor, input_m_dtensor, token_mask_dtensor, msa_mask_dtensor + ) + torch.autograd.backward( + [output_z_dtensor_result, output_m_dtensor_result], + [d_output_z_expected_dtensor, d_output_m_expected_dtensor], + ) + + # Distribute FP32 reference results for comparison + output_z_fp32_dtensor = distribute_tensor( + output_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_token_mask, + src_data_rank=None, + ) + output_m_fp32_dtensor = distribute_tensor( + output_m_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m_msa_mask, + src_data_rank=None, + ) + d_input_z_fp32_dtensor = distribute_tensor( + d_input_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z_token_mask, + src_data_rank=None, + ) + d_input_m_fp32_dtensor = distribute_tensor( + d_input_m_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_m_msa_mask, + src_data_rank=None, + ) + + # Check error histograms for outputs + assert_no_percentile_upshift( + output_z_dtensor_result.to_local(), + output_z_expected_dtensor.to_local(), + output_z_fp32_dtensor.to_local(), + names_input=("output_z_cp_fp32", "output_z_serial_fp64", "output_z_serial_fp32"), + ) + + assert_no_percentile_upshift( + output_m_dtensor_result.to_local(), + output_m_expected_dtensor.to_local(), + output_m_fp32_dtensor.to_local(), + names_input=("output_m_cp_fp32", "output_m_serial_fp64", "output_m_serial_fp32"), + ) + + # Check error histograms for input gradients + assert_no_percentile_upshift( + input_z_dtensor.grad.to_local(), + d_input_z_expected_dtensor.to_local(), + d_input_z_fp32_dtensor.to_local(), + names_input=("d_input_z_cp_fp32", "d_input_z_serial_fp64", "d_input_z_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_m_dtensor.grad.to_local(), + d_input_m_expected_dtensor.to_local(), + d_input_m_fp32_dtensor.to_local(), + names_input=("d_input_m_cp_fp32", "d_input_m_serial_fp64", "d_input_m_serial_fp32"), + ) + + # Check parameter gradients error histograms + for name, grad_param_expected_global in expected_param_grads_global_host_dict.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + # Forward pass + output_z_dtensor_result, output_m_dtensor_result = module( + input_z_dtensor, input_m_dtensor, token_mask_dtensor, msa_mask_dtensor + ) + + # Verify inputs weren't modified + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_m_dtensor_copy.to_local(), input_m_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical(token_mask_dtensor_copy.to_local(), token_mask_dtensor.to_local()) + assert_tensors_identical(msa_mask_dtensor_copy.to_local(), msa_mask_dtensor.to_local()) + + # Test forward pass results + torch.testing.assert_close(output_z_dtensor_result.to_local(), output_z_expected_dtensor.to_local()) + torch.testing.assert_close(output_m_dtensor_result.to_local(), output_m_expected_dtensor.to_local()) + + # Backward pass + d_output_z_expected_dtensor_copy = d_output_z_expected_dtensor.detach().clone() + d_output_m_expected_dtensor_copy = d_output_m_expected_dtensor.detach().clone() + torch.autograd.backward( + [output_z_dtensor_result, output_m_dtensor_result], + [d_output_z_expected_dtensor, d_output_m_expected_dtensor], + ) + + # Verify upstream gradients weren't modified + assert_tensors_identical(d_output_z_expected_dtensor_copy.to_local(), d_output_z_expected_dtensor.to_local()) + assert_tensors_identical(d_output_m_expected_dtensor_copy.to_local(), d_output_m_expected_dtensor.to_local()) + + # Test input gradients + torch.testing.assert_close(input_z_dtensor.grad.to_local(), d_input_z_expected_dtensor.to_local()) + torch.testing.assert_close(input_m_dtensor.grad.to_local(), d_input_m_expected_dtensor.to_local()) + + # Test full tensor gathering - verify distributed results match serial results + output_z_global_result_host = output_z_dtensor_result.full_tensor().cpu() + output_m_global_result_host = output_m_dtensor_result.full_tensor().cpu() + d_input_z_global_result_host = input_z_dtensor.grad.full_tensor().cpu() + d_input_m_global_result_host = input_m_dtensor.grad.full_tensor().cpu() + + # Verify full tensors match expected results + torch.testing.assert_close(output_z_global_result_host, output_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(output_m_global_result_host, output_m_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_z_global_result_host, d_input_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_m_global_result_host, d_input_m_expected_global_host.to(dtype=dtype)) + + # Gather weight gradients using named_parameters + # NOTE: the layer weights are all replicated and their gradients are in Partial(Sum) state + # of their dtensor form so testing the full_tensor() results is equivalent to testing the + # DTensor versions + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in expected_param_grads_global_host_dict: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + result_param_grads_dict[name] = param.grad + + # Compare parameter gradients + for name, expected_grad_global_host in expected_param_grads_global_host_dict.items(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + result_grad_global = result_grad.full_tensor() + torch.testing.assert_close(result_grad_global.cpu(), expected_grad_global_host.to(dtype=dtype)) + assert_all_identical(result_grad_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist", + ( + params_test := [ + ## CUDA tests (2 GPUs) + (((2, (1, 1)), True, "cuda", "ENV"), torch.float32, True), + (((2, (1, 1)), True, "cuda", "ENV"), torch.float64, True), + ## CUDA tests (8 GPUs) + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False), + ## CPU tests + (((1, (3, 3)), True, "cuda", "ENV"), torch.float32, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}" + for x in params_test + ], +) +def test_msa_layer_parallel(setup_env, dtype, check_error_hist): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + # dtype is the dtype used by the parallel computation + # check_error_hist determine whether to compare the error histograms between + # (CP_in_FP32, serial_in_FP64) and (serial_in_FP32, serial_in_FP64) + # Typically, check_error_hist will use large input dimensions to emulate + # the real-world use cases. Same with dtype==torch.float64. + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + if check_error_hist: + if grid_group_sizes["dp"] > 2: + pytest.skip("skip error histogram check for dp > 1 to save test time") + + # For float64 and error histogram check, we use a realistic model and input size + # with heavier computation to test the numerical stability. On the other hand, + # a smaller model and input size incur less numerical error accumulation to allow + # a larger range of input values to detect logical bugs inexpensively by using + # smaller dimensions. + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + if test_large_model: + N = size_ring * 128 # Number of tokens + S = size_ring * 128 # Number of sequences + msa_s = 64 # MSA embedding dimension + token_z = 128 # Pairwise embedding dimension + pairwise_head_width = 32 + pairwise_num_heads = 4 + min_val_init = -5e-2 if dtype == torch.float64 else -1e-3 + max_val_init = -min_val_init + else: + N = size_ring * 2 # Number of tokens + S = size_ring * 3 # Number of sequences + msa_s = 8 # MSA embedding dimension + token_z = 12 # Pairwise embedding dimension + pairwise_head_width = 4 + pairwise_num_heads = 2 + min_val_init = -0.5 + max_val_init = 0.5 + msa_dropout = 0.0 # disable dropout as we have not way to match the random sequences between serial and CP + z_dropout = 0.0 + + seed = 42 + seed_by_rank(0, seed=seed) + + # Compute reference results with FP64 + input_z_global_fp64 = torch.empty((B, N, N, token_z), dtype=torch.float64, requires_grad=True, device=device_type) + input_m_global_fp64 = torch.empty((B, S, N, msa_s), dtype=torch.float64, requires_grad=True, device=device_type) + + token_mask_global_fp64 = torch.randint( + 0, 2, (B, N, N), dtype=torch.float64, requires_grad=False, device=device_type + ) + token_mask_global_fp64[0, N // size_ring :, :] = 0 + token_mask_global_fp64[0, :, N // size_ring :] = 0 + + msa_mask_global_fp64 = torch.ones((B, S, N), dtype=torch.float64, requires_grad=False, device=device_type) + msa_mask_global_fp64[0, (S // size_ring) :, :] = 0 + msa_mask_global_fp64[0, :, (N // size_ring) :] = 0 + + # Create reference serial module + reference_module = SerialMSALayer( + msa_s=msa_s, + token_z=token_z, + msa_dropout=msa_dropout, + z_dropout=z_dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + ) + + # Initialize parameters to ensure reproducible behavior + # The output activation and gradient of the layer weights typically increase by 3 to 4 orders of magnitude, + # where the ULP would be too large and numerical error distribution becomes very wide, i.e., we would have + # very unpredictable numerical errors. That would make the test results very noisy and not very useful to + # detect logical bugs in the code. To avoid this, we use a smaller range for the input and layer weights. + init_tensors_uniform([input_z_global_fp64, input_m_global_fp64], low=min_val_init, high=max_val_init) + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + + set_dtype_specific_inf_values(reference_module, torch.float64) + + layer_state_dict_fp64 = reference_module.state_dict() + + # Run forward pass + output_z_expected_global_fp64, output_m_expected_global_fp64 = reference_module( + input_z_global_fp64, input_m_global_fp64, token_mask_global_fp64, msa_mask_global_fp64 + ) + d_output_z_expected_global_fp64 = torch.rand_like(output_z_expected_global_fp64) + d_output_m_expected_global_fp64 = torch.rand_like(output_m_expected_global_fp64) + torch.autograd.backward( + [output_z_expected_global_fp64, output_m_expected_global_fp64], + [d_output_z_expected_global_fp64, d_output_m_expected_global_fp64], + ) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().to(dtype=dtype, device="cpu", copy=True) + for name, param in reference_module.named_parameters() + } + + if check_error_hist: + # Run serial FP32 reference for three-way error histogram comparison + input_z_global_fp32 = input_z_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + input_m_global_fp32 = input_m_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + token_mask_global_fp32 = ( + token_mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + ) + msa_mask_global_fp32 = msa_mask_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(False) + + reference_module_fp32 = SerialMSALayer( + msa_s=msa_s, + token_z=token_z, + msa_dropout=msa_dropout, + z_dropout=z_dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + ) + + reference_module_fp32.load_state_dict(layer_state_dict_fp64) + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + set_dtype_specific_inf_values(reference_module_fp32, torch.float32) + + output_z_global_fp32, output_m_global_fp32 = reference_module_fp32( + input_z_global_fp32, input_m_global_fp32, token_mask_global_fp32, msa_mask_global_fp32 + ) + d_output_z_expected_global_fp32 = d_output_z_expected_global_fp64.to(dtype=torch.float32) + d_output_m_expected_global_fp32 = d_output_m_expected_global_fp64.to(dtype=torch.float32) + torch.autograd.backward( + [output_z_global_fp32, output_m_global_fp32], + [d_output_z_expected_global_fp32, d_output_m_expected_global_fp32], + ) + + output_z_global_fp32_host = output_z_global_fp32.detach().to(device="cpu", copy=True) + output_m_global_fp32_host = output_m_global_fp32.detach().to(device="cpu", copy=True) + d_input_z_global_fp32_host = input_z_global_fp32.grad.detach().to(device="cpu", copy=True) + d_input_m_global_fp32_host = input_m_global_fp32.grad.detach().to(device="cpu", copy=True) + grad_params_fp32_global_host = { + name: param.grad.detach().to(device="cpu", copy=True) + for name, param in reference_module_fp32.named_parameters() + } + + output_z_for_worker = output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True) + output_m_for_worker = output_m_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_z_for_worker = input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_m_for_worker = input_m_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + grad_params_for_worker = grad_params_fp64_expected_global_host + elif dtype == torch.float32: + # check_error_hist=False with FP32: the spawned worker compares the CP output + # directly against the expected output via assert_close. Because the FP64 + # reference uses genuine FP64 parameters while the CP module's parameters are + # truncated to FP32 on load, numerical discrepancies from both the parameter + # truncation and the lower-precision arithmetic accumulate through the composed + # MSA layer and exceed assert_close tolerances. To avoid this, we run a serial + # FP32 reference so that both sides start from identical parameters. + ref_fp32 = SerialMSALayer( + msa_s=msa_s, + token_z=token_z, + msa_dropout=msa_dropout, + z_dropout=z_dropout, + pairwise_head_width=pairwise_head_width, + pairwise_num_heads=pairwise_num_heads, + ) + ref_fp32.load_state_dict(layer_state_dict_fp64) + ref_fp32 = ref_fp32.to(dtype=torch.float32, device=device_type).train() + set_dtype_specific_inf_values(ref_fp32, torch.float32) + + inp_z = input_z_global_fp64.detach().to(dtype=torch.float32, device=device_type).requires_grad_(True) + inp_m = input_m_global_fp64.detach().to(dtype=torch.float32, device=device_type).requires_grad_(True) + tok_mask = token_mask_global_fp64.detach().to(dtype=torch.float32, device=device_type) + msa_msk = msa_mask_global_fp64.detach().to(dtype=torch.float32, device=device_type) + + out_z, out_m = ref_fp32(inp_z, inp_m, tok_mask, msa_msk) + d_out_z = d_output_z_expected_global_fp64.to(dtype=torch.float32) + d_out_m = d_output_m_expected_global_fp64.to(dtype=torch.float32) + torch.autograd.backward([out_z, out_m], [d_out_z, d_out_m]) + + output_z_for_worker = out_z.detach().cpu() + output_m_for_worker = out_m.detach().cpu() + d_input_z_for_worker = inp_z.grad.detach().cpu() + d_input_m_for_worker = inp_m.grad.detach().cpu() + grad_params_for_worker = {name: param.grad.detach().cpu() for name, param in ref_fp32.named_parameters()} + + output_z_global_fp32_host = None + output_m_global_fp32_host = None + d_input_z_global_fp32_host = None + d_input_m_global_fp32_host = None + grad_params_fp32_global_host = None + else: + # check_error_hist=False with FP64: use FP64 reference directly + output_z_for_worker = output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True) + output_m_for_worker = output_m_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_z_for_worker = input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_m_for_worker = input_m_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + grad_params_for_worker = grad_params_fp64_expected_global_host + + output_z_global_fp32_host = None + output_m_global_fp32_host = None + d_input_z_global_fp32_host = None + d_input_m_global_fp32_host = None + grad_params_fp32_global_host = None + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_msa_layer, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + msa_s, + token_z, + msa_dropout, + z_dropout, + pairwise_head_width, + pairwise_num_heads, + layer_state_dict_fp64, + input_z_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_m_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + token_mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + msa_mask_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + output_z_for_worker, + output_m_for_worker, + d_output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_output_m_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_input_z_for_worker, + d_input_m_for_worker, + grad_params_for_worker, + output_z_global_fp32_host, + output_m_global_fp32_host, + d_input_z_global_fp32_host, + d_input_m_global_fp32_host, + grad_params_fp32_global_host, + ) diff --git a/tests/distributed/model/modules/test_dtensor_msa_module.py b/tests/distributed/model/modules/test_dtensor_msa_module.py new file mode 100644 index 000000000..04d4e5f19 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_msa_module.py @@ -0,0 +1,729 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for Boltz-2 CP MSAModule (distributed.model.modules.trunkv2.MSAModule).""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.data import const +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.trunkv2 import MSAModule as DistributedMSAModule +from boltz.model.modules.trunkv2 import MSAModule as SerialMSAModule +from boltz.testing.utils import ( + assert_all_identical, + assert_no_percentile_upshift, + assert_tensors_identical, + create_msa_module_init_params_v2, + get_param_by_key, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + set_dtype_specific_inf_values, + spawn_multiprocessing, +) + + +def _feats_for_distributed(feats_global, dtype, device="cpu"): + """Convert feats to the target dtype/device for the distributed module. + + Integer features (e.g. ``msa``) are kept as-is because the distributed + MSAModule now applies ``shardwise_one_hot`` internally, matching the + serial MSAModule which calls ``F.one_hot`` in its forward pass. + """ + out = {} + for key, value in feats_global.items(): + if value.dtype.is_floating_point: + out[key] = value.to(dtype=dtype, device=device) + else: + out[key] = value.to(device=device) + return out + + +def parallel_assert_msa_module( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + msa_module_params, + module_state_dict, + input_z_global_host, + input_emb_global_host, + input_feats_global_host, + output_z_expected_global_host, + d_output_z_expected_global_host, + d_input_z_expected_global_host, + d_input_emb_expected_global_host, + expected_param_grads_global_host_dict, + output_z_global_fp32_host=None, + d_input_z_global_fp32_host=None, + d_input_emb_global_fp32_host=None, + grad_params_fp32_global_host=None, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + if torch.finfo(dtype).resolution < torch.finfo(output_z_expected_global_host.dtype).resolution: + raise ValueError( + f"Target dtype {dtype} has higher precision than reference output's dtype {output_z_expected_global_host.dtype}" + ) + + check_error_hist = output_z_global_fp32_host is not None + + module_serial = SerialMSAModule(**msa_module_params) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + module_serial.load_state_dict(module_state_dict) + + set_dtype_specific_inf_values(module_serial, dtype) + + module = DistributedMSAModule(module_serial, manager) + assert module.activation_checkpointing == msa_module_params["activation_checkpointing"] + module.train() + + placements_z = (Shard(0), Shard(1), Shard(2)) + placements_emb = (Shard(0), Replicate(), Shard(1)) + placements_msa = (Shard(0), Shard(1), Shard(2)) + placements_token_mask = (Shard(0), Shard(1), Shard(2)) + + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z, + ).requires_grad_(True) + + input_emb_dtensor = distribute_tensor( + input_emb_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_emb, + ).requires_grad_(True) + + input_feats_dtensor = {} + for key, value in input_feats_global_host.items(): + if key in ["msa", "has_deletion", "deletion_value", "msa_paired", "msa_mask"]: + input_feats_dtensor[key] = distribute_tensor( + value.to(dtype=dtype if value.dtype.is_floating_point else value.dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_msa, + ) + elif key == "token_pair_pad_mask": + input_feats_dtensor[key] = distribute_tensor( + value.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_token_mask, + ) + + d_output_z_expected_dtensor = distribute_tensor( + d_output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z, + ) + output_z_expected_dtensor = distribute_tensor( + output_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z, + src_data_rank=None, + ) + d_input_z_expected_dtensor = distribute_tensor( + d_input_z_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z, + src_data_rank=None, + ) + d_input_emb_expected_dtensor = distribute_tensor( + d_input_emb_expected_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_emb, + src_data_rank=None, + ) + + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + input_emb_dtensor_copy = input_emb_dtensor.detach().clone().requires_grad_(True) + input_feats_dtensor_copy = {k: v.detach().clone() for k, v in input_feats_dtensor.items()} + + if check_error_hist: + output_z_dtensor_result = module(input_z_dtensor, input_emb_dtensor, input_feats_dtensor) + torch.autograd.backward([output_z_dtensor_result], [d_output_z_expected_dtensor]) + + output_z_fp32_dtensor = distribute_tensor( + output_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z, + src_data_rank=None, + ) + d_input_z_fp32_dtensor = distribute_tensor( + d_input_z_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z, + src_data_rank=None, + ) + d_input_emb_fp32_dtensor = distribute_tensor( + d_input_emb_global_fp32_host.to(device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_emb, + src_data_rank=None, + ) + + assert_no_percentile_upshift( + output_z_dtensor_result.to_local(), + output_z_expected_dtensor.to_local(), + output_z_fp32_dtensor.to_local(), + names_input=("output_z_cp_fp32", "output_z_serial_fp64", "output_z_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_z_dtensor.grad.to_local(), + d_input_z_expected_dtensor.to_local(), + d_input_z_fp32_dtensor.to_local(), + names_input=("d_input_z_cp_fp32", "d_input_z_serial_fp64", "d_input_z_serial_fp32"), + ) + + assert_no_percentile_upshift( + input_emb_dtensor.grad.to_local(), + d_input_emb_expected_dtensor.to_local(), + d_input_emb_fp32_dtensor.to_local(), + names_input=("d_input_emb_cp_fp32", "d_input_emb_serial_fp64", "d_input_emb_serial_fp32"), + ) + + for name, grad_param_expected_global in expected_param_grads_global_host_dict.items(): + grad_param_result_global = get_param_by_key(module, name).grad.full_tensor().cpu() + assert_no_percentile_upshift( + grad_param_result_global, + grad_param_expected_global.to(dtype=grad_param_result_global.dtype), + grad_params_fp32_global_host[name], + names_input=(f"d_{name}_cp_fp32", f"d_{name}_serial_fp64", f"d_{name}_serial_fp32"), + ) + else: + output_z_dtensor_result = module(input_z_dtensor, input_emb_dtensor, input_feats_dtensor) + + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_emb_dtensor_copy.to_local(), input_emb_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + for key in input_feats_dtensor_copy: + assert_tensors_identical(input_feats_dtensor_copy[key].to_local(), input_feats_dtensor[key].to_local()) + + torch.testing.assert_close(output_z_dtensor_result.to_local(), output_z_expected_dtensor.to_local()) + + d_output_z_expected_dtensor_copy = d_output_z_expected_dtensor.detach().clone() + torch.autograd.backward([output_z_dtensor_result], [d_output_z_expected_dtensor]) + + assert_tensors_identical(d_output_z_expected_dtensor_copy.to_local(), d_output_z_expected_dtensor.to_local()) + + torch.testing.assert_close(input_z_dtensor.grad.to_local(), d_input_z_expected_dtensor.to_local()) + torch.testing.assert_close(input_emb_dtensor.grad.to_local(), d_input_emb_expected_dtensor.to_local()) + + output_z_global_result_host = output_z_dtensor_result.full_tensor().cpu() + d_input_z_global_result_host = input_z_dtensor.grad.full_tensor().cpu() + d_input_emb_global_result_host = input_emb_dtensor.grad.full_tensor().cpu() + + torch.testing.assert_close(output_z_global_result_host, output_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_z_global_result_host, d_input_z_expected_global_host.to(dtype=dtype)) + torch.testing.assert_close(d_input_emb_global_result_host, d_input_emb_expected_global_host.to(dtype=dtype)) + + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + if name not in expected_param_grads_global_host_dict: + raise ValueError(f"Parameter {name} has a resulting gradient but it is not in the reference module") + result_param_grads_dict[name] = param.grad + + for name, expected_grad_global_host in expected_param_grads_global_host_dict.items(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + result_grad_global = result_grad.full_tensor() + torch.testing.assert_close(result_grad_global.cpu(), expected_grad_global_host.to(dtype=dtype)) + assert_all_identical(result_grad_global, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env, dtype, check_error_hist, activation_checkpointing", + ( + params_test := [ + ## CUDA tests (2 GPUs) + (((2, (1, 1)), True, "cuda", "ENV"), torch.float32, True, False), + (((2, (1, 1)), True, "cuda", "ENV"), torch.float64, True, True), + ## CUDA tests (8 GPUs) + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, True, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32, False, False), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True, True), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float64, True, False), + ## CPU tests + (((1, (3, 3)), True, "cpu", "ENV"), torch.float32, False, False), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}, check_error_hist:{x[2]}, checkpoint:{x[3]}" + for x in params_test + ], +) +def test_msa_module_parallel(setup_env, dtype, check_error_hist, activation_checkpointing): + """Test Boltz-2 CP MSAModule against serial reference (forward, backward, param grads).""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cpu" and grid_group_sizes["cp"] == (3, 3): + pytest.skip("CPU with 3x3 CP ring not yet validated for numerical parity (distributed vs serial)") + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + test_large_model = check_error_hist or dtype == torch.float64 + + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + + if test_large_model: + N = size_ring * 64 + S = size_ring * 64 + min_val_init = -0.01 + max_val_init = 0.01 + else: + N = size_ring * 2 + S = size_ring * 3 + min_val_init = -0.5 + max_val_init = 0.5 + + msa_module_params = create_msa_module_init_params_v2(test_large_model) + msa_module_params["activation_checkpointing"] = activation_checkpointing + + seed = 42 + seed_by_rank(0, seed=seed) + + input_z_global_fp64 = torch.empty( + (B, N, N, msa_module_params["token_z"]), dtype=torch.float64, requires_grad=True, device=device_type + ) + input_emb_global_fp64 = torch.empty( + (B, N, msa_module_params["token_s"]), dtype=torch.float64, requires_grad=True, device=device_type + ) + + dim_input_msa = const.num_tokens + input_feats_global_fp64 = { + "msa": torch.randint(0, dim_input_msa, (B, S, N), dtype=torch.int64, device=device_type), + "has_deletion": torch.empty((B, S, N), dtype=torch.float64, device=device_type), + "deletion_value": torch.empty((B, S, N), dtype=torch.float64, device=device_type), + "msa_paired": torch.randint(0, 2, (B, S, N), dtype=torch.float64, device=device_type), + "msa_mask": torch.ones((B, S, N), dtype=torch.float64, device=device_type), + "token_pad_mask": torch.randint(0, 2, (B, N), dtype=torch.float64, device=device_type), + } + input_feats_global_fp64["token_pad_mask"][0, N // size_ring :] = 0 + input_feats_global_fp64["token_pair_pad_mask"] = ( + input_feats_global_fp64["token_pad_mask"][:, :, None] * input_feats_global_fp64["token_pad_mask"][:, None, :] + ) + + input_feats_global_fp64["msa_mask"][0, (S // size_ring) :, :] = 0 + input_feats_global_fp64["msa_mask"][0, :, (N // size_ring) :] = 0 + + reference_module = SerialMSAModule(**msa_module_params) + + init_tensors_uniform([input_z_global_fp64, input_emb_global_fp64], low=min_val_init, high=max_val_init) + for key, tensor in input_feats_global_fp64.items(): + if tensor.dtype.is_floating_point and "mask" not in key and "msa_paired" not in key: + init_tensors_uniform([tensor], low=min_val_init, high=max_val_init) + + reference_module = reference_module.to(dtype=torch.float64, device=device_type).train() + init_module_params_uniform(reference_module, low=min_val_init, high=max_val_init) + + set_dtype_specific_inf_values(reference_module, torch.float64) + + module_state_dict_fp64 = reference_module.state_dict() + + output_z_expected_global_fp64 = reference_module( + input_z_global_fp64, input_emb_global_fp64, input_feats_global_fp64 + ) + d_output_z_expected_global_fp64 = torch.rand_like(output_z_expected_global_fp64) + torch.autograd.backward([output_z_expected_global_fp64], [d_output_z_expected_global_fp64]) + + grad_params_fp64_expected_global_host = { + name: param.grad.detach().to(dtype=dtype, device="cpu", copy=True) + for name, param in reference_module.named_parameters() + if param.grad is not None + } + + if check_error_hist: + # Run serial FP32 reference for three-way error histogram comparison + input_z_global_fp32 = input_z_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + input_emb_global_fp32 = input_emb_global_fp64.detach().to(dtype=torch.float32, copy=True).requires_grad_(True) + input_feats_global_fp32 = {} + for key, tensor in input_feats_global_fp64.items(): + if key == "msa": + input_feats_global_fp32[key] = tensor.detach().clone() + elif tensor.dtype.is_floating_point: + input_feats_global_fp32[key] = tensor.detach().to(dtype=torch.float32, copy=True) + else: + input_feats_global_fp32[key] = tensor.detach().clone() + + reference_module_fp32 = SerialMSAModule(**msa_module_params) + + reference_module_fp32.load_state_dict(module_state_dict_fp64) + set_dtype_specific_inf_values(reference_module_fp32, torch.float32) + + reference_module_fp32 = reference_module_fp32.to(dtype=torch.float32, device=device_type).train() + + output_z_global_fp32 = reference_module_fp32( + input_z_global_fp32, input_emb_global_fp32, input_feats_global_fp32 + ) + d_output_z_expected_global_fp32 = d_output_z_expected_global_fp64.to(dtype=torch.float32) + torch.autograd.backward([output_z_global_fp32], [d_output_z_expected_global_fp32]) + + output_z_global_fp32_host = output_z_global_fp32.detach().to(device="cpu", copy=True) + d_input_z_global_fp32_host = input_z_global_fp32.grad.detach().to(device="cpu", copy=True) + d_input_emb_global_fp32_host = input_emb_global_fp32.grad.detach().to(device="cpu", copy=True) + grad_params_fp32_global_host = { + name: param.grad.detach().to(device="cpu", copy=True) + for name, param in reference_module_fp32.named_parameters() + if param.grad is not None + } + + output_z_for_worker = output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_z_for_worker = input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_emb_for_worker = input_emb_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + grad_params_for_worker = grad_params_fp64_expected_global_host + elif dtype == torch.float32: + # check_error_hist=False with FP32: the spawned worker compares the CP output + # directly against the expected output via assert_close. Because the FP64 + # reference uses genuine FP64 parameters while the CP module's parameters are + # truncated to FP32 on load, numerical discrepancies from both the parameter + # truncation and the lower-precision arithmetic accumulate through the composed + # MSA module and exceed assert_close tolerances. To avoid this, we run a serial + # FP32 reference so that both sides start from identical parameters. + ref_fp32 = SerialMSAModule(**msa_module_params) + ref_fp32.load_state_dict(module_state_dict_fp64) + set_dtype_specific_inf_values(ref_fp32, torch.float32) + ref_fp32 = ref_fp32.to(dtype=torch.float32, device=device_type).train() + + inp_z = input_z_global_fp64.detach().to(dtype=torch.float32, device=device_type).requires_grad_(True) + inp_emb = input_emb_global_fp64.detach().to(dtype=torch.float32, device=device_type).requires_grad_(True) + inp_feats = {} + for key, tensor in input_feats_global_fp64.items(): + if key == "msa": + inp_feats[key] = tensor.detach().clone() + elif tensor.dtype.is_floating_point: + inp_feats[key] = tensor.detach().to(dtype=torch.float32, device=device_type) + else: + inp_feats[key] = tensor.detach().clone().to(device=device_type) + + out_z = ref_fp32(inp_z, inp_emb, inp_feats) + d_out_z = d_output_z_expected_global_fp64.to(dtype=torch.float32) + torch.autograd.backward([out_z], [d_out_z]) + + output_z_for_worker = out_z.detach().cpu() + d_input_z_for_worker = inp_z.grad.detach().cpu() + d_input_emb_for_worker = inp_emb.grad.detach().cpu() + grad_params_for_worker = { + name: param.grad.detach().cpu() for name, param in ref_fp32.named_parameters() if param.grad is not None + } + + output_z_global_fp32_host = None + d_input_z_global_fp32_host = None + d_input_emb_global_fp32_host = None + grad_params_fp32_global_host = None + else: + # check_error_hist=False with FP64: use FP64 reference directly + output_z_for_worker = output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_z_for_worker = input_z_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + d_input_emb_for_worker = input_emb_global_fp64.grad.detach().to(dtype=dtype, device="cpu", copy=True) + grad_params_for_worker = grad_params_fp64_expected_global_host + + output_z_global_fp32_host = None + d_input_z_global_fp32_host = None + d_input_emb_global_fp32_host = None + grad_params_fp32_global_host = None + + input_feats_for_distributed = _feats_for_distributed(input_feats_global_fp64, dtype, device="cpu") + + spawn_multiprocessing( + parallel_assert_msa_module, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + msa_module_params, + module_state_dict_fp64, + input_z_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_emb_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + input_feats_for_distributed, + output_z_for_worker, + d_output_z_expected_global_fp64.detach().to(dtype=dtype, device="cpu", copy=True), + d_input_z_for_worker, + d_input_emb_for_worker, + grad_params_for_worker, + output_z_global_fp32_host, + d_input_z_global_fp32_host, + d_input_emb_global_fp32_host, + grad_params_fp32_global_host, + ) + + +def parallel_assert_msa_module_activation_checkpointing( + rank, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + msa_module_params, + min_val_init, + max_val_init, + input_z_global_host, + input_emb_global_host, + input_feats_global_host, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + seed_by_rank(0, seed=42) + + # The first module runs WITHOUT activation checkpointing (regular forward/backward). + # The second module runs WITH activation checkpointing (checkpoint-wrapped forward, + # recomputed during backward). Comparing the two verifies that the checkpointing + # mechanism preserves numerical correctness for both outputs and gradients. + msa_module_params = dict(msa_module_params) + msa_module_params["activation_checkpointing"] = False + module_serial = SerialMSAModule(**msa_module_params) + module_serial = module_serial.to(dtype=dtype, device=manager.device) + init_module_params_uniform(module_serial, low=min_val_init, high=max_val_init) + set_dtype_specific_inf_values(module_serial, dtype) + + module_state_dict_ref = module_serial.state_dict() + + module = DistributedMSAModule(module_serial, manager) + module.train() + + placements_z = (Shard(0), Shard(1), Shard(2)) + placements_emb = (Shard(0), Replicate(), Shard(1)) + placements_msa = (Shard(0), Shard(1), Shard(2)) + placements_token_mask = (Shard(0), Shard(1), Shard(2)) + + input_z_dtensor = distribute_tensor( + input_z_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_z, + ).requires_grad_(True) + + input_emb_dtensor = distribute_tensor( + input_emb_global_host.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_emb, + ).requires_grad_(True) + + input_feats_dtensor = {} + for key, value in input_feats_global_host.items(): + if key in ["msa", "has_deletion", "deletion_value", "msa_paired", "msa_mask"]: + input_feats_dtensor[key] = distribute_tensor( + value.to(dtype=dtype if value.dtype.is_floating_point else value.dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_msa, + ) + elif key == "token_pair_pad_mask": + input_feats_dtensor[key] = distribute_tensor( + value.to(dtype=dtype, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=placements_token_mask, + ) + + input_z_dtensor_copy = input_z_dtensor.detach().clone().requires_grad_(True) + input_emb_dtensor_copy = input_emb_dtensor.detach().clone().requires_grad_(True) + input_feats_dtensor_copy = {k: v.detach().clone() for k, v in input_feats_dtensor.items()} + + # Save RNG state so the second forward pass (with activation checkpointing) + # sees the same dropout masks as the first forward pass. + cpu_rng_state = torch.random.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state(device=manager.device) if device_type == "cuda" else None + + output_z_dtensor_result = module(input_z_dtensor, input_emb_dtensor, input_feats_dtensor) + + assert_tensors_identical( + input_z_dtensor_copy.to_local(), input_z_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + assert_tensors_identical( + input_emb_dtensor_copy.to_local(), input_emb_dtensor.to_local(), check_grad=False, check_grad_fn=False + ) + for key in input_feats_dtensor_copy: + assert_tensors_identical(input_feats_dtensor_copy[key].to_local(), input_feats_dtensor[key].to_local()) + + d_output_z_dtensor = torch.distributed.tensor.rand( + output_z_dtensor_result.shape, + requires_grad=False, + dtype=dtype, + device_mesh=manager.device_mesh_subgroups, + placements=output_z_dtensor_result.placements, + ) + d_output_z_dtensor_copy = d_output_z_dtensor.detach().clone() + + torch.autograd.backward([output_z_dtensor_result], [d_output_z_dtensor]) + + assert_tensors_identical(d_output_z_dtensor_copy.to_local(), d_output_z_dtensor.to_local()) + + # Create second module with the same weights but activation checkpointing enabled + msa_module_params["activation_checkpointing"] = True + module_serial_act_ckpt = SerialMSAModule(**msa_module_params) + module_serial_act_ckpt.load_state_dict(module_state_dict_ref) + set_dtype_specific_inf_values(module_serial_act_ckpt, dtype) + + module_serial_act_ckpt = module_serial_act_ckpt.to(dtype=dtype, device=manager.device) + module_act_ckpt = DistributedMSAModule(module_serial_act_ckpt, manager) + module_act_ckpt.train() + + # Restore RNG state so dropout masks match the first forward pass + torch.random.set_rng_state(cpu_rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state, device=manager.device) + + output_z_dtensor_result_act_ckpt = module_act_ckpt( + input_z_dtensor_copy, input_emb_dtensor_copy, input_feats_dtensor_copy + ) + + assert_tensors_identical( + output_z_dtensor_result_act_ckpt.to_local(), + output_z_dtensor_result.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + torch.autograd.backward([output_z_dtensor_result_act_ckpt], [d_output_z_dtensor]) + + assert_tensors_identical(input_z_dtensor.grad.to_local(), input_z_dtensor_copy.grad.to_local()) + assert_tensors_identical(input_emb_dtensor.grad.to_local(), input_emb_dtensor_copy.grad.to_local()) + + result_param_grads_dict = {} + for name, param in module.named_parameters(): + if param.grad is not None: + result_param_grads_dict[name] = param.grad + + for name, param_act_ckpt_grad in module_act_ckpt.named_parameters(): + assert name in result_param_grads_dict, f"Parameter {name}'s gradient is not found in result gradients" + result_grad = result_param_grads_dict[name] + assert_tensors_identical(result_grad.to_local(), param_act_ckpt_grad.grad.to_local()) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype", + ( + params_test := [ + (((2, (1, 1)), True, "cuda", "ENV"), torch.float32), + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32), + ] + ), + indirect=["setup_env"], + ids=[ + f"dp:{x[0][0][0]}, cp:{x[0][0][1]}, specify_method:{x[0][1]}, device_type:{x[0][2]}, method_init:{x[0][3]}, " + f"dtype:{x[1]}" + for x in params_test + ], +) +def test_msa_module_parallel_activation_checkpointing(setup_env, dtype): + """MSAModule with activation checkpointing vs CP without; results should match.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + msa_module_params = create_msa_module_init_params_v2(use_large_model=False) + msa_module_params["msa_dropout"] = 0.5 + msa_module_params["z_dropout"] = 0.5 + + B = 2 * grid_group_sizes["dp"] + size_ring = grid_group_sizes["cp"][0] + N = size_ring * 2 + S = size_ring * 3 + min_val_init = -1 + max_val_init = 1 + dim_input_msa = const.num_tokens + + input_z_global = torch.empty((B, N, N, msa_module_params["token_z"]), dtype=dtype, requires_grad=True, device="cpu") + input_emb_global = torch.empty((B, N, msa_module_params["token_s"]), dtype=dtype, requires_grad=True, device="cpu") + + input_feats_global_host = { + "msa": torch.randint(0, dim_input_msa, (B, S, N), dtype=torch.int64, device="cpu"), + "has_deletion": torch.empty((B, S, N), dtype=dtype, device="cpu"), + "deletion_value": torch.empty((B, S, N), dtype=dtype, device="cpu"), + "msa_paired": torch.randint(0, 2, (B, S, N), dtype=dtype, device="cpu"), + "msa_mask": torch.ones((B, S, N), dtype=dtype, device="cpu"), + "token_pad_mask": torch.randint(0, 2, (B, N), dtype=dtype, device="cpu"), + } + input_feats_global_host["token_pad_mask"][0, N // size_ring :] = 0 + input_feats_global_host["token_pair_pad_mask"] = ( + input_feats_global_host["token_pad_mask"][:, :, None] * input_feats_global_host["token_pad_mask"][:, None, :] + ) + + input_feats_global_host["msa_mask"][0, (S // size_ring) :, :] = 0 + input_feats_global_host["msa_mask"][0, :, (N // size_ring) :] = 0 + + init_tensors_uniform([input_z_global, input_emb_global], low=min_val_init, high=max_val_init) + for key, tensor in input_feats_global_host.items(): + if tensor.dtype.is_floating_point and "mask" not in key: + init_tensors_uniform([tensor], low=min_val_init, high=max_val_init) + + input_feats_for_distributed = _feats_for_distributed(input_feats_global_host, dtype, device="cpu") + + spawn_multiprocessing( + parallel_assert_msa_module_activation_checkpointing, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + msa_module_params, + min_val_init, + max_val_init, + input_z_global.detach().to(dtype=dtype, device="cpu", copy=True), + input_emb_global.detach().to(dtype=dtype, device="cpu", copy=True), + input_feats_for_distributed, + ) diff --git a/tests/distributed/model/modules/test_dtensor_pairwise_conditioning.py b/tests/distributed/model/modules/test_dtensor_pairwise_conditioning.py new file mode 100644 index 000000000..11678746d --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_pairwise_conditioning.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor PairwiseConditioning module. + +Tests both Boltz-1x and Boltz-2 serial PairwiseConditioning modules against the unified +DTensor PairwiseConditioning implementation, verifying forward and backward equivalence. + +Uses float64 with default tolerance for exact comparison. +""" + +import pytest +import torch +from torch.distributed.tensor import Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.encoders import PairwiseConditioning as DTensorPairwiseConditioning +from boltz.model.modules.encoders import PairwiseConditioning as PairwiseConditioningBoltz1 +from boltz.model.modules.encodersv2 import PairwiseConditioning as PairwiseConditioningBoltz2 +from boltz.testing.utils import ( + assert_tensors_identical, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_pairwise_conditioning( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_module_version: str, + layer_state_dict, + dtype: torch.dtype, + # Input tensors (global, on host) + z_trunk_global_host: torch.Tensor, + token_rel_pos_feats_global_host: torch.Tensor, + # Expected outputs + z_expected_global_host: torch.Tensor, + # Upstream gradients + d_z_global_host: torch.Tensor, + # Expected input grads + d_z_trunk_expected_global_host: torch.Tensor, + d_token_rel_pos_feats_expected_global_host: torch.Tensor, + # Expected parameter grads + expected_param_grads_global_host_dict: dict[str, torch.Tensor], + # Module constructor kwargs + module_kwargs: dict, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Create serial module from state dict + if serial_module_version == "boltz1": + module_serial = PairwiseConditioningBoltz1(**module_kwargs) + else: + module_serial = PairwiseConditioningBoltz2(**module_kwargs) + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.train() + + # Create DTensor module from serial + module_dt = DTensorPairwiseConditioning( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Pair placements: shard along both token dimensions + placements_pair = (Shard(0), Shard(1), Shard(2)) + + # Distribute inputs + z_trunk_dt = distribute_tensor( + z_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_pair, + ).requires_grad_(True) + token_rel_pos_feats_dt = distribute_tensor( + token_rel_pos_feats_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_pair, + ).requires_grad_(True) + + # Copies to verify inputs aren't modified + z_trunk_dt_copy = z_trunk_dt.detach().clone().requires_grad_(True) + token_rel_pos_feats_dt_copy = token_rel_pos_feats_dt.detach().clone().requires_grad_(True) + + # Forward pass + z_dt = module_dt(z_trunk_dt, token_rel_pos_feats_dt) + + # Ensure no input mutation + assert_tensors_identical(z_trunk_dt_copy.to_local(), z_trunk_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical( + token_rel_pos_feats_dt_copy.to_local(), + token_rel_pos_feats_dt.to_local(), + check_grad=False, + check_grad_fn=False, + ) + + # Forward compare + torch.testing.assert_close(z_dt.full_tensor().cpu(), z_expected_global_host) + + # Backward pass + d_z_dt = distribute_tensor( + d_z_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_pair, + ) + z_dt.backward(d_z_dt) + + # Compare input gradients + torch.testing.assert_close(z_trunk_dt.grad.full_tensor().cpu(), d_z_trunk_expected_global_host) + torch.testing.assert_close( + token_rel_pos_feats_dt.grad.full_tensor().cpu(), d_token_rel_pos_feats_expected_global_host + ) + + # Compare parameter gradients + for name, param in module_dt.named_parameters(): + assert param.grad is not None, f"Parameter {name} has no gradient" + expected_grad = expected_param_grads_global_host_dict[name] + torch.testing.assert_close( + param.grad.full_tensor().cpu(), + expected_grad, + msg=lambda m: f"Parameter gradient mismatch for {name}: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize("serial_module_version", ["boltz1", "boltz2"]) +def test_pairwise_conditioning(setup_env, serial_module_version: str): + """Test PairwiseConditioning DTensor vs serial equivalence for both Boltz-1x and Boltz-2.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + dtype = torch.float64 + + # Module dimensions + token_z = 32 + dim_token_rel_pos_feats = 8 + num_transitions = 2 + + # Data dimensions — N must be divisible by dp * cp + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 4 # tokens (keep small for pair O(N^2)) + + val_init_min, val_init_max = -0.08, 0.08 + + seed_by_rank(0, seed=42) + + module_kwargs = { + "token_z": token_z, + "dim_token_rel_pos_feats": dim_token_rel_pos_feats, + "num_transitions": num_transitions, + } + + if serial_module_version == "boltz1": + module_serial = PairwiseConditioningBoltz1(**module_kwargs) + else: + module_serial = PairwiseConditioningBoltz2(**module_kwargs) + + # Cast to target dtype BEFORE param init to prevent precision loss + module_serial = module_serial.to(dtype=dtype).train() + init_module_params_uniform(module_serial, low=val_init_min, high=val_init_max) + layer_state_dict = module_serial.state_dict() + + # Create input tensors + z_trunk_global = torch.empty(B, N, N, token_z, dtype=dtype, requires_grad=True) + token_rel_pos_feats_global = torch.empty(B, N, N, dim_token_rel_pos_feats, dtype=dtype, requires_grad=True) + init_tensors_uniform([z_trunk_global, token_rel_pos_feats_global], low=val_init_min, high=val_init_max) + + # Serial forward pass + z_serial = module_serial(z_trunk_global, token_rel_pos_feats_global) + + # Create upstream gradient + d_z = torch.empty_like(z_serial) + init_tensors_uniform([d_z], low=val_init_min, high=val_init_max) + + # Serial backward pass + z_serial.backward(d_z) + + # Collect expected results + z_expected = z_serial.detach().clone().cpu() + d_z_trunk_expected = z_trunk_global.grad.detach().clone().cpu() + d_token_rel_pos_feats_expected = token_rel_pos_feats_global.grad.detach().clone().cpu() + + expected_param_grads = {} + for name, param in module_serial.named_parameters(): + if param.grad is not None: + expected_param_grads[name] = param.grad.detach().clone().cpu() + + # Launch parallel test + spawn_multiprocessing( + parallel_assert_pairwise_conditioning, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_module_version, + layer_state_dict, + dtype, + z_trunk_global.detach().clone().cpu(), + token_rel_pos_feats_global.detach().clone().cpu(), + z_expected, + d_z.detach().clone().cpu(), + d_z_trunk_expected, + d_token_rel_pos_feats_expected, + expected_param_grads, + module_kwargs, + ) diff --git a/tests/distributed/model/modules/test_dtensor_single_conditioning.py b/tests/distributed/model/modules/test_dtensor_single_conditioning.py new file mode 100644 index 000000000..1515c6c52 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_single_conditioning.py @@ -0,0 +1,347 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor SingleConditioning module. + +Tests both Boltz-1x and Boltz-2 serial SingleConditioning modules against the unified +DTensor SingleConditioning implementation, verifying forward and backward equivalence. + +The functional differences between V1 and V2 SingleConditioning are: + - ``disable_times``: V2-only flag. When True, fourier time embedding is absent. + - ``v1_input_layout``: V1 uses a wider input_dim for norm_single/single_embed + (``2*token_s + 2*num_tokens + 1 + len(pocket_contact_info)``), while V2 uses + ``2*token_s``. When True, the V1 serial class is used. + +The serial module version is inferred from these flags: + - ``v1_input_layout=True`` → ``SingleConditioningBoltz1`` (V1 does not support disable_times) + - ``v1_input_layout=False`` → ``SingleConditioningBoltz2`` + +Uses float64 with default tolerance for exact comparison. +""" + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.data import const +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.encoders import SingleConditioning as DTensorSingleConditioning +from boltz.model.modules.encoders import SingleConditioning as SingleConditioningBoltz1 +from boltz.model.modules.encodersv2 import SingleConditioning as SingleConditioningBoltz2 +from boltz.testing.utils import ( + assert_tensors_identical, + init_module_params_uniform, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +def parallel_assert_single_conditioning( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + env_per_rank, + serial_class_tag: str, # "boltz1" or "boltz2" — inferred from flags + layer_state_dict, + dtype: torch.dtype, + # Input tensors (global, on host) + times_global_host: torch.Tensor, + s_trunk_global_host: torch.Tensor, + s_inputs_global_host: torch.Tensor, + # Expected outputs + s_expected_global_host: torch.Tensor, + normed_fourier_expected_global_host: torch.Tensor | None, + # Upstream gradients + d_s_global_host: torch.Tensor, + d_normed_fourier_global_host: torch.Tensor | None, + # Expected input grads + d_s_trunk_expected_global_host: torch.Tensor, + d_s_inputs_expected_global_host: torch.Tensor, + # Expected parameter grads + expected_param_grads_global_host_dict: dict[str, torch.Tensor], + # Module constructor kwargs + module_kwargs: dict, +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Create serial module from state dict + if serial_class_tag == "boltz1": + module_serial = SingleConditioningBoltz1(**module_kwargs) + else: + module_serial = SingleConditioningBoltz2(**module_kwargs) + module_serial = module_serial.to(device=manager.device, dtype=dtype) + module_serial.load_state_dict(layer_state_dict) + module_serial = module_serial.train() + + # Create DTensor module from serial + module_dt = DTensorSingleConditioning( + layer=module_serial, + device_mesh=manager.device_mesh_subgroups, + ).train() + + # Placements + placements_times = (Shard(0), Replicate(), Replicate()) + placements_s = (Shard(0), Shard(1), Replicate()) + + # Distribute inputs + times_dt = distribute_tensor( + times_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_times, + ).requires_grad_(False) # times has no gradient in serial code + s_trunk_dt = distribute_tensor( + s_trunk_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_s, + ).requires_grad_(True) + s_inputs_dt = distribute_tensor( + s_inputs_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_s, + ).requires_grad_(True) + + # Copies to verify inputs aren't modified + times_dt_copy = times_dt.detach().clone() + s_trunk_dt_copy = s_trunk_dt.detach().clone().requires_grad_(True) + s_inputs_dt_copy = s_inputs_dt.detach().clone().requires_grad_(True) + + # Forward pass + s_dt, normed_fourier_dt = module_dt(times_dt, s_trunk_dt, s_inputs_dt) + + # Ensure no input mutation + assert_tensors_identical(times_dt_copy.to_local(), times_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(s_trunk_dt_copy.to_local(), s_trunk_dt.to_local(), check_grad=False, check_grad_fn=False) + assert_tensors_identical(s_inputs_dt_copy.to_local(), s_inputs_dt.to_local(), check_grad=False, check_grad_fn=False) + + # Forward compare + torch.testing.assert_close(s_dt.full_tensor().cpu(), s_expected_global_host) + if normed_fourier_expected_global_host is not None: + assert normed_fourier_dt is not None + torch.testing.assert_close(normed_fourier_dt.full_tensor().cpu(), normed_fourier_expected_global_host) + else: + assert normed_fourier_dt is None + + # Backward pass + outputs = [s_dt] + grad_outputs = [ + distribute_tensor( + d_s_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_s, + ) + ] + if normed_fourier_dt is not None and d_normed_fourier_global_host is not None: + outputs.append(normed_fourier_dt) + grad_outputs.append( + distribute_tensor( + d_normed_fourier_global_host.to(device=manager.device, dtype=dtype), + manager.device_mesh_subgroups, + placements_times, + ) + ) + torch.autograd.backward(outputs, grad_outputs) + + # Compare input gradients + torch.testing.assert_close(s_trunk_dt.grad.full_tensor().cpu(), d_s_trunk_expected_global_host) + torch.testing.assert_close(s_inputs_dt.grad.full_tensor().cpu(), d_s_inputs_expected_global_host) + + # Compare parameter gradients (skip frozen params like FourierEmbedding.proj) + for name, param in module_dt.named_parameters(): + if not param.requires_grad: + assert param.grad is None, f"Frozen parameter {name} should have no gradient" + continue + assert param.grad is not None, f"Parameter {name} has no gradient" + expected_grad = expected_param_grads_global_host_dict[name] + torch.testing.assert_close( + param.grad.full_tensor().cpu(), + expected_grad, + msg=lambda m: f"Parameter gradient mismatch for {name}: {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "disable_times, v1_input_layout", + [ + (False, False), # V2 serial, times enabled, standard 2*token_s input_dim + (False, True), # V1 serial, times enabled, wider input_dim + (True, False), # V2 serial, times disabled (V2-only feature) + # (True, True) is invalid — V1 does not support disable_times + ], + ids=["times:on-input:v2", "times:on-input:v1", "times:off"], +) +def test_single_conditioning(setup_env, disable_times: bool, v1_input_layout: bool): + """Test SingleConditioning DTensor vs serial equivalence. + + Parametrized on the functional flags that differ between V1 and V2: + - ``disable_times``: whether fourier time embedding is skipped (V2-only) + - ``v1_input_layout``: whether to use V1's wider input_dim (implies V1 serial class) + + The serial module class is inferred from these flags. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + # Infer serial class from functional flags + if v1_input_layout: + assert not disable_times, "V1 does not support disable_times" + serial_class, serial_class_tag = SingleConditioningBoltz1, "boltz1" + else: + serial_class, serial_class_tag = SingleConditioningBoltz2, "boltz2" + + dtype = torch.float64 + + # Module dimensions + token_s = 64 + dim_fourier = 32 + num_transitions = 2 + sigma_data = 1.0 + + # Data dimensions + size_ring = grid_group_sizes["cp"][0] + B = 2 * grid_group_sizes["dp"] + N = size_ring * 8 + + val_init_min, val_init_max = -0.08, 0.08 + + seed_by_rank(0, seed=42) + + # Build serial module kwargs — the input_dim and available kwargs differ by version + if v1_input_layout: + # V1: input_dim = 2 * token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info) + input_dim = 2 * token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info) + s_inputs_dim = input_dim - token_s # s_trunk is token_s, s_inputs is the rest + module_kwargs = { + "sigma_data": sigma_data, + "token_s": token_s, + "dim_fourier": dim_fourier, + "num_transitions": num_transitions, + } + else: + # V2: input_dim = 2 * token_s + s_inputs_dim = token_s # s_trunk and s_inputs both have token_s dim + module_kwargs = { + "sigma_data": sigma_data, + "token_s": token_s, + "dim_fourier": dim_fourier, + "num_transitions": num_transitions, + "disable_times": disable_times, + } + + # Create serial module — cast to target dtype BEFORE param init to prevent precision loss + module_serial = serial_class(**module_kwargs) + module_serial = module_serial.to(dtype=dtype).train() + init_module_params_uniform(module_serial, low=val_init_min, high=val_init_max) + layer_state_dict = module_serial.state_dict() + + # Create input tensors + times_global = torch.empty(B, dtype=dtype) + s_trunk_global = torch.empty(B, N, token_s, dtype=dtype, requires_grad=True) + s_inputs_global = torch.empty(B, N, s_inputs_dim, dtype=dtype, requires_grad=True) + init_tensors_uniform([times_global, s_trunk_global, s_inputs_global], low=val_init_min, high=val_init_max) + + # Serial forward pass — V1 uses keyword-only args, V2 uses positional + if v1_input_layout: + s_serial, normed_fourier_serial = module_serial( + times=times_global, s_trunk=s_trunk_global, s_inputs=s_inputs_global + ) + else: + s_serial, normed_fourier_serial = module_serial(times_global, s_trunk_global, s_inputs_global) + + # Create upstream gradients + d_s = torch.empty_like(s_serial) + init_tensors_uniform([d_s], low=val_init_min, high=val_init_max) + + d_normed_fourier = None + if normed_fourier_serial is not None: + d_normed_fourier = torch.empty_like(normed_fourier_serial) + init_tensors_uniform([d_normed_fourier], low=val_init_min, high=val_init_max) + + # Serial backward pass + outputs = [s_serial] + grad_outputs = [d_s] + if normed_fourier_serial is not None and d_normed_fourier is not None: + outputs.append(normed_fourier_serial) + grad_outputs.append(d_normed_fourier) + torch.autograd.backward(outputs, grad_outputs) + + # Collect expected results + s_expected = s_serial.detach().clone().cpu() + normed_fourier_expected = ( + normed_fourier_serial.detach().clone().cpu() if normed_fourier_serial is not None else None + ) + d_s_trunk_expected = s_trunk_global.grad.detach().clone().cpu() + d_s_inputs_expected = s_inputs_global.grad.detach().clone().cpu() + + expected_param_grads = {} + for name, param in module_serial.named_parameters(): + if param.requires_grad and param.grad is not None: + expected_param_grads[name] = param.grad.detach().clone().cpu() + + # Launch parallel test + spawn_multiprocessing( + parallel_assert_single_conditioning, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_class_tag, + layer_state_dict, + dtype, + times_global.detach().clone().cpu(), + s_trunk_global.detach().clone().cpu(), + s_inputs_global.detach().clone().cpu(), + s_expected, + normed_fourier_expected, + d_s.detach().clone().cpu(), + d_normed_fourier.detach().clone().cpu() if d_normed_fourier is not None else None, + d_s_trunk_expected, + d_s_inputs_expected, + expected_param_grads, + module_kwargs, + ) diff --git a/tests/distributed/model/modules/test_dtensor_trunkv2.py b/tests/distributed/model/modules/test_dtensor_trunkv2.py new file mode 100644 index 000000000..d492bfa25 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_trunkv2.py @@ -0,0 +1,913 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor-based CP trunk modules for Boltz-2. + +This module tests the DTensor context-parallel trunk modules against the +serial implementations imported from trunkv2: +- DistogramModule +- BFactorModule +- ContactConditioning + +Verification checks (per module): + V1: single-proc FW input tensor values unchanged by FW and BW + V2: single-proc BW input tensor values unchanged by BW + V4a: multi-proc FW input tensor values unchanged by FW + V4b: multi-proc FW input tensor values unchanged after BW + V5: multi-proc BW input tensor values unchanged by BW + V8: multi-proc FW output tensor values close-to single-proc + V9: multi-proc FW input gradient values close-to single-proc + V10: multi-proc parameter gradient values close-to single-proc + V10b: replicated parameter gradients identical across all CP ranks +""" + +import math +from collections import OrderedDict + +import pytest +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.testing import assert_close + +from boltz.data import const +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.trunkv2 import BFactorModule as BFactorModuleDTensor +from boltz.distributed.model.modules.trunkv2 import ContactConditioning as ContactConditioningDTensor +from boltz.distributed.model.modules.trunkv2 import DistogramModule as DistogramModuleDTensor +from boltz.model.modules.trunkv2 import BFactorModule as SerialBFactorModule +from boltz.model.modules.trunkv2 import ContactConditioning as SerialContactConditioning +from boltz.model.modules.trunkv2 import DistogramModule as SerialDistogramModule +from boltz.testing.utils import ( + assert_all_identical, + assert_tensors_identical, + init_module_params_uniform, + init_tensors_uniform, + make_random_contact_conditioning_features, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + +SEED = 42 + + +def _assert_unchanged(actual, expected, *, serial=False): + """Shorthand for assert_tensors_identical with standard immutability kwargs. + + serial=True uses check_storage_offset=True (serial-side V1/V2 checks). + serial=False uses check_storage_offset=False (worker-side V4/V5 checks on DTensor locals). + """ + assert_tensors_identical( + actual, + expected, + check_stride=True, + check_grad=False, + check_grad_fn=False, + check_storage_pointer=False, + check_storage_offset=serial, + ) + + +def assert_dtensor_distogram( + rank: int, + input_example_on_host: Tensor, + output_ref_on_host: Tensor, + output_grad_example_on_host: Tensor, + input_grad_ref_on_host: Tensor, + parameter_grads_ref_on_host: OrderedDict[str, Tensor | None], + module_state_dict: dict, + token_z: int, + num_bins: int, + num_distograms: int, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: dict[str, str] | None = None, +): + """Worker function for distributed DTensor DistogramModule testing. + + Follows the Boltz-1x CP pattern: uses full_tensor() for comparisons, + assert_tensors_identical for binary identity checks, and grad.full_tensor() + for backward gradient comparison. + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + dist_manager = DistributedManager() + + _assert_dtensor_distogram_impl( + dist_manager, + input_example_on_host, + output_ref_on_host, + output_grad_example_on_host, + input_grad_ref_on_host, + parameter_grads_ref_on_host, + module_state_dict, + token_z, + num_bins, + num_distograms, + rank, + ) + DistributedManager.cleanup() + monkeypatch.undo() + + +def _assert_dtensor_distogram_impl( + dist_manager: DistributedManager, + input_example_on_host: Tensor, + output_ref_on_host: Tensor, + output_grad_example_on_host: Tensor, + input_grad_ref_on_host: Tensor, + parameter_grads_ref_on_host: OrderedDict[str, Tensor | None], + module_state_dict: dict, + token_z: int, + num_bins: int, + num_distograms: int, + rank: int, +) -> None: + """Inner implementation of distogram DTensor test (extracted for try/finally).""" + # Move inputs and reference outputs to device + input_example_on_device = input_example_on_host.detach().to(dist_manager.device).requires_grad_(True) + output_ref_on_device = output_ref_on_host.detach().to(dist_manager.device) + output_grad_example_on_device = output_grad_example_on_host.detach().to(dist_manager.device) + input_grad_ref_on_device = input_grad_ref_on_host.detach().to(dist_manager.device) + + # Create serial module (needed to wrap in DTensor module) + serial_module = SerialDistogramModule( + token_z=token_z, + num_bins=num_bins, + num_distograms=num_distograms, + ) + serial_module.load_state_dict(state_dict=module_state_dict) + serial_module = serial_module.to(dist_manager.device) + + # Create DTensor module + dtensor_module = DistogramModuleDTensor( + module=serial_module, + dist_manager=dist_manager, + distogram_comm=TransposeComm( + dist_manager.group["cp"], + dist_manager.layout_subgroups["cp"], + ), + ) + dtensor_module = dtensor_module.to(dist_manager.device) + dtensor_module = dtensor_module.train() + + # Create input DTensors with pairlike placements + pairlike_placements = (Shard(0), Shard(1), Shard(2)) + + input_as_dtensor = distribute_tensor( + input_example_on_device, + device_mesh=dist_manager.device_mesh_subgroups, + placements=pairlike_placements, + ).requires_grad_(input_example_on_device.requires_grad) + + output_grad_as_dtensor = distribute_tensor( + output_grad_example_on_device, + device_mesh=dist_manager.device_mesh_subgroups, + placements=pairlike_placements, + ).requires_grad_(output_grad_example_on_device.requires_grad) + + # V4a setup: clone input for immutability check + input_example_clone_as_dtensor = input_as_dtensor.detach().clone().requires_grad_(input_as_dtensor.requires_grad) + + # Forward pass + output_actual_as_dtensor = dtensor_module(input_as_dtensor) + + # V4a: FW input tensor values unchanged by FW (binary identity - atol=0, rtol=0) + _assert_unchanged(input_as_dtensor.full_tensor(), input_example_clone_as_dtensor.full_tensor(), serial=True) + + # V8: multi-proc FW output tensor values close-to single-proc + # Tolerance: fp32 default (atol=1e-5, rtol=1.3e-6). The distogram is + # always tested in fp32. The only numerical difference is summation order + # in symmetrize (z + z.T) which introduces ~N*eps ≈ 16*1.2e-7 ≈ 2e-6 error. + assert_close(output_actual_as_dtensor.full_tensor(), output_ref_on_device, atol=1e-5, rtol=1.3e-6) + + # Verify 5D output shape + assert ( + output_actual_as_dtensor.shape[3] == num_distograms + ), f"Expected dim 3 = {num_distograms}, got {output_actual_as_dtensor.shape[3]}" + assert ( + output_actual_as_dtensor.shape[4] == num_bins + ), f"Expected dim 4 = {num_bins}, got {output_actual_as_dtensor.shape[4]}" + + # V5 setup: save output for comparison + output_grad_example_clone_as_dtensor = ( + output_grad_as_dtensor.detach().clone().requires_grad_(output_grad_as_dtensor.requires_grad) + ) + output_actual_clone_as_dtensor = ( + output_actual_as_dtensor.detach().clone().requires_grad_(output_actual_as_dtensor.requires_grad) + ) + + # Backward pass + output_actual_as_dtensor.backward(output_grad_as_dtensor) + + # V4b: FW input tensor values unchanged after backward (binary identity - atol=0, rtol=0) + _assert_unchanged(input_as_dtensor.full_tensor(), input_example_clone_as_dtensor.full_tensor(), serial=True) + + # V5: BW input tensor values unchanged by BW (binary identity - atol=0, rtol=0) + _assert_unchanged(output_actual_as_dtensor.full_tensor(), output_actual_clone_as_dtensor.full_tensor(), serial=True) + _assert_unchanged( + output_grad_as_dtensor.full_tensor(), output_grad_example_clone_as_dtensor.full_tensor(), serial=True + ) + + # V9: multi-proc FW input gradient values close-to single-proc + # Use grad.full_tensor() because DTensor grad can be in Partial(Sum) placement + assert input_as_dtensor.grad is not None, "Input DTensor gradient is None - trivial equality guard failed" + assert input_grad_ref_on_device is not None, "Reference input gradient is None - test setup error" + input_grad_actual_full: Tensor = input_as_dtensor.grad.full_tensor() + assert_close(input_grad_actual_full, input_grad_ref_on_device, atol=1e-5, rtol=1.3e-6) + + # Non-vacuous: input gradient must be non-zero + assert input_grad_actual_full.abs().sum() > 0, "Input gradient is all-zero — backward did not propagate" + + # V10: multi-proc parameter gradient values close-to single-proc + param_names_checked = [] + for name, param in dtensor_module.named_parameters(): + if name not in parameter_grads_ref_on_host: + msg = f"Module parameter {name} not in parameter_grads_ref_on_host" + raise ValueError(msg) + + grad_ref = parameter_grads_ref_on_host[name] + + if param.grad is None and grad_ref is None: + msg = ( + f"Both actual and reference gradients are None for {name} - " + "trivial equality, test cannot verify correctness" + ) + raise ValueError(msg) + + if (param.grad is None) != (grad_ref is None): + msg = f"Inconsistent grad state for {name} on rank {rank}: result={param.grad}, ref={grad_ref}" + raise ValueError(msg) + + param_names_checked.append(name) + if grad_ref is not None: + # Use full_tensor() to reduce Partial(Sum) gradients + grad_actual = param.grad.full_tensor() if isinstance(param.grad, DTensor) else param.grad + param_name = name + assert_close( + grad_actual, + grad_ref.to(dist_manager.device), + atol=1e-5, + rtol=1.3e-6, + msg=lambda m, n=param_name: f"Rank {rank} {n} grad mismatch\n{m}", + ) + + # V10b: replicated parameter gradients identical across all CP ranks + assert_all_identical(grad_actual.detach(), dist_manager.group["cp"]) + + # Non-vacuous: parameter gradient must be non-zero + assert grad_actual.abs().sum() > 0, f"Rank {rank} {param_name} gradient is all-zero" + + assert ( + len(param_names_checked) >= 2 + ), f"Expected at least 2 parameters (weight, bias), but only checked: {param_names_checked}" + + +def get_example_input_and_reference_output( + grid_group_sizes: dict, + B: int, + num_tokens_per_device_grid_unit: int, + token_z: int, + num_bins: int, + num_distograms: int, + dtype_for_test: torch.dtype = torch.float32, + device_for_test: str = "cpu", + seed: int = SEED, +): + """Generate example input and reference output for testing.""" + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(seed) + + num_tokens = num_tokens_per_device_grid_unit * grid_group_sizes["cp"][0] + + min_init_val = -0.5 + max_init_val = 0.5 + + input_example = torch.empty( + (B, num_tokens, num_tokens, token_z), + device=device_for_test, + dtype=dtype_for_test, + requires_grad=True, + ) + init_tensors_uniform([input_example], low=min_init_val, high=max_init_val) + input_example_copy = input_example.detach().clone().requires_grad_(input_example.requires_grad) + + module = SerialDistogramModule(token_z, num_bins, num_distograms=num_distograms) + init_module_params_uniform(module, low=min_init_val, high=max_init_val) + module = module.to(device_for_test) + module_state_dict = module.state_dict() + + # Run serial forward + output_ref = module(input_example) + + # V1a: single-proc FW input tensor values unchanged + _assert_unchanged(input_example, input_example_copy, serial=True) + + output_grad_example = torch.empty_like(output_ref) + init_tensors_uniform([output_grad_example], low=min_init_val, high=max_init_val) + output_grad_example_copy = output_grad_example.detach().clone() + + # Serial backward + torch.autograd.backward([output_ref], [output_grad_example]) + + # V1b: single-proc FW input tensor values unchanged after backward + _assert_unchanged(input_example, input_example_copy, serial=True) + + # V2: single-proc BW input tensor values unchanged + _assert_unchanged(output_grad_example, output_grad_example_copy, serial=True) + + # Get parameter gradients + parameter_grads_ref = OrderedDict() + for name, param in module.named_parameters(): + if param.grad is not None: + parameter_grads_ref[name] = param.grad.detach().cpu().clone() + else: + parameter_grads_ref[name] = None + + # To host for output + input_example_on_host = input_example.detach().cpu().clone() + output_ref_on_host = output_ref.detach().cpu().clone() + output_grad_example_on_host = output_grad_example.detach().cpu().clone() + input_example_grad_on_host = input_example.grad.detach().cpu().clone() + + return ( + input_example_on_host, + output_ref_on_host, + output_grad_example_on_host, + input_example_grad_on_host, + parameter_grads_ref, + module_state_dict, + ) + + +@pytest.mark.parametrize( + "setup_env, num_distograms", + [ + # CUDA dp=1 cp=(1,1): serial-equivalent sanity check with D=1 (v1 fallback) + (((1, (1, 1)), True, "cuda", "ENV"), 1), + # CUDA dp=2 cp=(1,1): DP-only path (2 GPUs) + (((2, (1, 1)), True, "cuda", "ENV"), 1), + # CUDA dp=2 cp=(2,2): full DP+CP with multi-distogram (D=3) + (((2, (2, 2)), True, "cuda", "ENV"), 3), + # CPU dp=2 cp=(3,3): non-power-of-two CP baseline without GPUs + (((2, (3, 3)), True, "cpu", "ENV"), 1), + ], + indirect=("setup_env",), + ids=[ + "cuda-dp1-cp1x1-D1", + "cuda-dp2-cp1x1-D1", + "cuda-dp2-cp2x2-D3", + "cpu-dp2-cp3x3-D1", + ], +) +def test_dtensor_distogram_forward_backward( + num_distograms: int, + setup_env: dict, + seed: int = SEED, +): + """Test that DTensor DistogramModule matches serial Boltz-2 implementation.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + B = 2 + num_tokens_per_device_grid_unit = 8 + token_z = 32 + num_bins = 16 + + ( + input_example_on_host, + output_ref_on_host, + output_grad_example_on_host, + input_grad_ref_on_host, + parameter_grads_ref, + module_state_dict, + ) = get_example_input_and_reference_output( + grid_group_sizes, + B=B, + num_tokens_per_device_grid_unit=num_tokens_per_device_grid_unit, + token_z=token_z, + num_bins=num_bins, + num_distograms=num_distograms, + dtype_for_test=torch.float32, + device_for_test=device_type, + seed=seed, + ) + + spawn_multiprocessing( + assert_dtensor_distogram, + world_size, + input_example_on_host, + output_ref_on_host, + output_grad_example_on_host, + input_grad_ref_on_host, + parameter_grads_ref, + module_state_dict, + token_z, + num_bins, + num_distograms, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +# ====================================================================== # +# BFactorModule parity tests # +# ====================================================================== # + + +def _worker_bfactor_parity( + rank: int, + input_on_host: Tensor, + output_ref_on_host: Tensor, + grad_output_on_host: Tensor, + input_grad_ref_on_host: Tensor, + param_grads_ref: OrderedDict[str, Tensor | None], + state_dict: dict, + token_s: int, + num_bins: int, + dtype: torch.dtype, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: dict[str, str] | None = None, +): + """Worker: compare distributed BFactorModule against serial reference. + + Performs V4a, V4b, V5, V8, V9, V10, V10b checks. + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + dm = DistributedManager() + + serial = SerialBFactorModule(token_s=token_s, num_bins=num_bins) + serial = serial.to(device=dm.device, dtype=dtype) + serial.load_state_dict(state_dict) + + dist_mod = BFactorModuleDTensor(serial, device_mesh=dm.device_mesh_subgroups) + dist_mod.train() + + # Single representation placements: shard B on dp, N on cp0, replicate on cp1 + single_placements = (Shard(0), Shard(1)) + (Replicate(),) * (dm.device_mesh_subgroups.ndim - 2) + x = distribute_tensor( + input_on_host.to(device=dm.device, dtype=dtype), dm.device_mesh_subgroups, single_placements + ).requires_grad_(True) + + # V4a setup: clone input + x_clone = x.detach().clone().requires_grad_(x.requires_grad) + + out = dist_mod(x) + + # V4a: FW input unchanged + _assert_unchanged(x.to_local(), x_clone.to_local()) + + # V8: forward parity + fw_atol = 1e-2 if dtype == torch.bfloat16 else 1e-5 + fw_rtol = 1e-2 if dtype == torch.bfloat16 else 1.3e-6 + assert_close( + out.full_tensor(), + output_ref_on_host.to(device=dm.device, dtype=dtype), + atol=fw_atol, + rtol=fw_rtol, + ) + + # V5 setup: clone output and grad + grad_out = distribute_tensor( + grad_output_on_host.to(device=dm.device, dtype=dtype), dm.device_mesh_subgroups, single_placements + ) + out_clone = out.detach().clone().requires_grad_(out.requires_grad) + grad_out_clone = grad_out.detach().clone().requires_grad_(grad_out.requires_grad) + + out.backward(grad_out) + + # V4b: FW input unchanged after backward + _assert_unchanged(x.to_local(), x_clone.to_local()) + + # V5: BW inputs unchanged + _assert_unchanged(out.to_local(), out_clone.to_local()) + _assert_unchanged(grad_out.to_local(), grad_out_clone.to_local()) + + # V9: input gradient parity + grad_atol = 5e-2 if dtype == torch.bfloat16 else 1e-5 + grad_rtol = 5e-2 if dtype == torch.bfloat16 else 1.3e-6 + assert x.grad is not None, "Input gradient is None" + assert_close( + x.grad.full_tensor(), + input_grad_ref_on_host.to(device=dm.device, dtype=dtype), + atol=grad_atol, + rtol=grad_rtol, + ) + + # V10: parameter gradient parity + for name, param in dist_mod.named_parameters(): + ref = param_grads_ref.get(name) + assert param.grad is not None, f"Param {name} grad is None" + assert ref is not None, f"Reference grad for {name} is None" + actual = param.grad.full_tensor() if isinstance(param.grad, DTensor) else param.grad + assert_close( + actual, + ref.to(device=dm.device, dtype=dtype), + atol=grad_atol, + rtol=grad_rtol, + msg=lambda m, n=name: f"{n} grad mismatch\n{m}", + ) + + # V10b: replicated parameter gradients identical across all CP ranks + grad_full = actual.detach() + assert_all_identical(grad_full, dm.group["cp"]) + + # Non-vacuous: parameter gradients must be non-zero + assert actual.abs().sum() > 0, f"Param {name} gradient is all-zero" + + # Non-vacuous: input gradient must be non-zero + assert x.grad.full_tensor().abs().sum() > 0, "Input gradient is all-zero" + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype", + [ + # CUDA dp=1 cp=(1,1): serial-equivalent sanity check (1 GPU) + (((1, (1, 1)), True, "cuda", "ENV"), torch.float32), + # CUDA dp=1 cp=(2,2): fp32 baseline + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32), + # CUDA dp=2 cp=(2,2): DP + CP + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32), + # CUDA dp=2 cp=(1,1): bf16 mixed precision, 2-GPU + (((2, (1, 1)), True, "cuda", "ENV"), torch.bfloat16), + # CUDA dp=1 cp=(2,2): bf16 with actual CP, 4-GPU + (((1, (2, 2)), True, "cuda", "ENV"), torch.bfloat16), + # CPU dp=2 cp=(3,3): DP + non-power-of-two CP for CPU-only CI + (((2, (3, 3)), True, "cpu", "ENV"), torch.float32), + ], + indirect=("setup_env",), + ids=[ + "cuda-dp1-cp1x1-fp32", + "cuda-dp1-cp2x2-fp32", + "cuda-dp2-cp2x2-fp32", + "cuda-dp2-cp1x1-bf16", + "cuda-dp1-cp2x2-bf16", + "cpu-dp2-cp3x3-fp32", + ], +) +def test_dtensor_bfactor_forward_backward(setup_env, dtype: torch.dtype): + """BFactorModule: distributed output and gradients match serial reference.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size(device_type=device_type, world_size=world_size) + + B, token_s, num_bins = 2, 32, 8 + N = 8 * grid_group_sizes["cp"][0] + + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(SEED) + x = torch.randn(B, N, token_s, requires_grad=True) + serial = SerialBFactorModule(token_s=token_s, num_bins=num_bins) + init_module_params_uniform(serial, low=-0.5, high=0.5) + serial = serial.to(dtype=dtype) + state_dict = {k: v.cpu().clone() for k, v in serial.state_dict().items()} + + x_typed = x.to(dtype=dtype).detach().requires_grad_(True) + x_copy = x_typed.detach().clone().requires_grad_(True) + + out_ref = serial(x_typed) + + # V1a: serial FW input unchanged + _assert_unchanged(x_typed, x_copy, serial=True) + + grad_out = torch.randn_like(out_ref) + grad_out_copy = grad_out.detach().clone() + out_ref.backward(grad_out) + + # V1b: serial FW input unchanged after backward + _assert_unchanged(x_typed, x_copy, serial=True) + # V2: serial BW input (grad_out) unchanged + _assert_unchanged(grad_out, grad_out_copy, serial=True) + + param_grads = OrderedDict( + (n, p.grad.detach().cpu().clone()) for n, p in serial.named_parameters() if p.grad is not None + ) + + spawn_multiprocessing( + _worker_bfactor_parity, + world_size, + x_typed.detach().cpu(), + out_ref.detach().cpu(), + grad_out.detach().cpu(), + x_typed.grad.detach().cpu(), + param_grads, + state_dict, + token_s, + num_bins, + dtype, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +# ====================================================================== # +# ContactConditioning parity tests # +# ====================================================================== # + + +def _worker_contact_conditioning_parity( + rank: int, + cc_input_on_host: Tensor, + ct_input_on_host: Tensor, + output_ref_on_host: Tensor, + grad_output_on_host: Tensor, + cc_grad_ref_on_host: Tensor, + ct_grad_ref_on_host: Tensor, + param_grads_ref: OrderedDict[str, Tensor | None], + state_dict: dict, + token_z: int, + dtype: torch.dtype, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_map: dict[str, str] | None = None, +): + """Worker: compare distributed ContactConditioning against serial reference. + + Performs V4a, V4b, V5, V8, V9, V10, V10b checks. + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + dm = DistributedManager() + + serial = SerialContactConditioning(token_z=token_z, cutoff_min=0.0, cutoff_max=22.0) + serial = serial.to(device=dm.device, dtype=dtype) + serial.load_state_dict(state_dict) + + dist_mod = ContactConditioningDTensor(serial, device_mesh=dm.device_mesh_subgroups) + dist_mod.train() + + pair_placements = (Shard(0), Shard(1), Shard(2)) + + cc_dt = distribute_tensor( + cc_input_on_host.to(device=dm.device, dtype=dtype), dm.device_mesh_subgroups, pair_placements + ).requires_grad_(True) + ct_dt = distribute_tensor( + ct_input_on_host.to(device=dm.device, dtype=dtype), dm.device_mesh_subgroups, pair_placements + ).requires_grad_(True) + + # V4a setup: clone inputs + cc_dt_clone = cc_dt.detach().clone().requires_grad_(cc_dt.requires_grad) + ct_dt_clone = ct_dt.detach().clone().requires_grad_(ct_dt.requires_grad) + + out = dist_mod({"contact_conditioning": cc_dt, "contact_threshold": ct_dt}) + + # V4a: FW inputs unchanged + _assert_unchanged(cc_dt.to_local(), cc_dt_clone.to_local()) + _assert_unchanged(ct_dt.to_local(), ct_dt_clone.to_local()) + + # V8: forward parity + fw_atol = 1e-2 if dtype == torch.bfloat16 else 1e-5 + fw_rtol = 1e-2 if dtype == torch.bfloat16 else 1.3e-6 + assert_close( + out.full_tensor(), + output_ref_on_host.to(device=dm.device, dtype=dtype), + atol=fw_atol, + rtol=fw_rtol, + ) + + # V5 setup: clone output and grad + grad_out = distribute_tensor( + grad_output_on_host.to(device=dm.device, dtype=dtype), dm.device_mesh_subgroups, pair_placements + ) + out_clone = out.detach().clone().requires_grad_(out.requires_grad) + grad_out_clone = grad_out.detach().clone().requires_grad_(grad_out.requires_grad) + + out.backward(grad_out) + + # V4b: FW inputs unchanged after backward + _assert_unchanged(cc_dt.to_local(), cc_dt_clone.to_local()) + + # V5: BW inputs unchanged + _assert_unchanged(out.to_local(), out_clone.to_local()) + _assert_unchanged(grad_out.to_local(), grad_out_clone.to_local()) + + # V9: input gradient parity + # + # Tolerance rationale (fp32): the encoder.weight gradient is computed as + # dW = x.T @ grad_out, summing over K = B*N*N elements. Serial sums all K + # terms in one matmul; distributed partitions into K/W terms per rank then + # all-reduce-sums across W ranks. The different accumulation orders produce + # a relative error of O(sqrt(K) * eps). For cpu-dp2-cp3x3, K = 2*24*24 = + # 1152, giving sqrt(1152) * eps ≈ 4.0e-6. A 2x safety factor covers + # intermediate ops (Fourier embedding, masking arithmetic). + if dtype == torch.bfloat16: + grad_atol, grad_rtol = 5e-2, 5e-2 + else: + K = cc_input_on_host.shape[0] * cc_input_on_host.shape[1] * cc_input_on_host.shape[2] + fp32_eps = torch.finfo(torch.float32).eps + accum_rtol = 2.0 * math.sqrt(K) * fp32_eps + grad_rtol = max(1.3e-6, accum_rtol) + grad_atol = max(1e-5, 10.0 * accum_rtol) + assert cc_dt.grad is not None, "contact_conditioning input gradient is None" + assert_close( + cc_dt.grad.full_tensor(), + cc_grad_ref_on_host.to(device=dm.device, dtype=dtype), + atol=grad_atol, + rtol=grad_rtol, + ) + if ct_grad_ref_on_host is not None: + assert ct_dt.grad is not None, "contact_threshold input gradient is None" + assert_close( + ct_dt.grad.full_tensor(), + ct_grad_ref_on_host.to(device=dm.device, dtype=dtype), + atol=grad_atol, + rtol=grad_rtol, + ) + + # V10: parameter gradient parity (trainable params only) + for name, param in dist_mod.named_parameters(): + if not param.requires_grad: + continue + ref = param_grads_ref.get(name) + if ref is None: + continue + assert param.grad is not None, f"Param {name} grad is None" + actual = param.grad.full_tensor() if isinstance(param.grad, DTensor) else param.grad + assert_close( + actual, + ref.to(device=dm.device, dtype=dtype), + atol=grad_atol, + rtol=grad_rtol, + msg=lambda m, n=name: f"{n} grad mismatch\n{m}", + ) + + # V10b: replicated parameter gradients identical across all CP ranks + grad_full = actual.detach() + assert_all_identical(grad_full, dm.group["cp"]) + + # Non-vacuous: encoding_unspecified and encoding_unselected must have + # non-zero gradients to prove the UNSPECIFIED and UNSELECTED masking + # branches were exercised. A zero grad would mean the test data + # didn't include that contact type, making V10 trivially pass. + for enc_name in ("encoding_unspecified", "encoding_unselected"): + enc_param = dict(dist_mod.named_parameters())[enc_name] + assert enc_param.grad is not None, f"{enc_name} grad is None" + enc_grad = enc_param.grad.full_tensor() if isinstance(enc_param.grad, DTensor) else enc_param.grad + assert enc_grad.abs().sum() > 0, ( + f"{enc_name} gradient is all-zero — UNSPECIFIED/UNSELECTED masking " + f"branch was not exercised. Test data must include this contact type." + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env, dtype", + [ + # CUDA dp=1 cp=(1,1): serial-equivalent sanity check (1 GPU) + (((1, (1, 1)), True, "cuda", "ENV"), torch.float32), + # CUDA dp=1 cp=(2,2): fp32 baseline + (((1, (2, 2)), True, "cuda", "ENV"), torch.float32), + # CUDA dp=2 cp=(2,2): DP + CP + (((2, (2, 2)), True, "cuda", "ENV"), torch.float32), + # CUDA dp=2 cp=(1,1): bf16 mixed precision, 2-GPU + (((2, (1, 1)), True, "cuda", "ENV"), torch.bfloat16), + # CUDA dp=1 cp=(2,2): bf16 with actual CP, 4-GPU + (((1, (2, 2)), True, "cuda", "ENV"), torch.bfloat16), + # CPU dp=2 cp=(3,3): DP + non-power-of-two CP for CPU-only CI + (((2, (3, 3)), True, "cpu", "ENV"), torch.float32), + ], + indirect=("setup_env",), + ids=[ + "cuda-dp1-cp1x1-fp32", + "cuda-dp1-cp2x2-fp32", + "cuda-dp2-cp2x2-fp32", + "cuda-dp2-cp1x1-bf16", + "cuda-dp1-cp2x2-bf16", + "cpu-dp2-cp3x3-fp32", + ], +) +def test_dtensor_contact_conditioning_forward_backward(setup_env, dtype: torch.dtype): + """ContactConditioning: distributed output and gradients match serial reference.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size(device_type=device_type, world_size=world_size) + + B, token_z = 2, 16 + N = 8 * grid_group_sizes["cp"][0] + num_cc_types = len(const.contact_conditioning_info) + + with torch.random.fork_rng(devices=[], enabled=True): + torch.manual_seed(SEED) + + serial = SerialContactConditioning(token_z=token_z, cutoff_min=0.0, cutoff_max=22.0) + init_module_params_uniform(serial, low=-0.5, high=0.5) + serial = serial.to(dtype=dtype) + state_dict = {k: v.cpu().clone() for k, v in serial.state_dict().items()} + + cc_data, ct_data = make_random_contact_conditioning_features( + B, + N, + num_cc_types, + dtype=dtype, + seed=SEED, + ) + cc_input = cc_data.requires_grad_(True) + ct_input = ct_data.requires_grad_(True) + + cc_copy = cc_input.detach().clone().requires_grad_(True) + ct_copy = ct_input.detach().clone().requires_grad_(True) + + feats = {"contact_conditioning": cc_input, "contact_threshold": ct_input} + out_ref = serial(feats) + + # V1a: serial FW inputs unchanged + _assert_unchanged(cc_input, cc_copy, serial=True) + _assert_unchanged(ct_input, ct_copy, serial=True) + + grad_out = torch.randn_like(out_ref) + grad_out_copy = grad_out.detach().clone() + out_ref.backward(grad_out) + + # V1b: serial FW inputs unchanged after backward + _assert_unchanged(cc_input, cc_copy, serial=True) + _assert_unchanged(ct_input, ct_copy, serial=True) + # V2: serial BW input (grad_out) unchanged + _assert_unchanged(grad_out, grad_out_copy, serial=True) + + param_grads = OrderedDict() + for n, p in serial.named_parameters(): + if p.requires_grad and p.grad is not None: + param_grads[n] = p.grad.detach().cpu().clone() + + ct_grad = ct_input.grad.detach().cpu().clone() if ct_input.grad is not None else None + + spawn_multiprocessing( + _worker_contact_conditioning_parity, + world_size, + cc_input.detach().cpu(), + ct_input.detach().cpu(), + out_ref.detach().cpu(), + grad_out.detach().cpu(), + cc_input.grad.detach().cpu(), + ct_grad, + param_grads, + state_dict, + token_z, + dtype, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/model/modules/test_dtensor_utils.py b/tests/distributed/model/modules/test_dtensor_utils.py new file mode 100644 index 000000000..0d9cdfb38 --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_utils.py @@ -0,0 +1,747 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Unit tests for DTensor checkpoint conversion helpers.""" + +import socket +from collections import OrderedDict + +import pytest +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.utils import ( + OffloadActvCkptToCPU, + SDPAWithBiasBackend, + SetAttnPairBiasBackend, + SetAttnPairBiasShardwiseBackend, + SetTriAttnBackend, + TriAttnBackend, + _convert_serial_value_to_template_layout, + convert_distributed_checkpoint_to_serial_state_dict, + convert_dtensors_to_tensors, + convert_serial_checkpoint_to_distributed_state_dict, + has_dtensors, +) +from boltz.testing.utils import create_boltz2_model_init_params, spawn_multiprocessing + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +@pytest.fixture +def single_rank_dist_manager(monkeypatch): + # DistributedManager cleanup currently destroys process groups but does not + # always clear singleton state; force reset so each test starts clean. + DistributedManager._state = {} + + port = str(_find_free_port()) + monkeypatch.setenv("MASTER_ADDR", "127.0.0.1") + monkeypatch.setenv("MASTER_PORT", port) + monkeypatch.setenv("WORLD_SIZE", "1") + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("LOCAL_RANK", "0") + + grid_group_sizes = OrderedDict(dp=1, cp=(1, 1)) + backend = DistributedManager.backend_for_device()["cpu"] + assert backend is not None, "Gloo backend must be available for CPU DTensor tests" + DistributedManager.initialize(grid_group_sizes, device_type="cpu", backend=backend) + manager = DistributedManager() + yield manager + DistributedManager.cleanup() + DistributedManager._state = {} + + +def _as_replicated_dtensor(tensor: torch.Tensor, manager: DistributedManager) -> DTensor: + return distribute_tensor( + tensor.to(manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=(Replicate(), Replicate(), Replicate()), + ) + + +def test_has_dtensors_detects_nested_dtensors(single_rank_dist_manager: DistributedManager): + """Goals: has_dtensors finds DTensors inside nested dicts/lists.""" + dtensor = _as_replicated_dtensor(torch.randn(2, 3), single_rank_dist_manager) + nested = {"a": [torch.ones(1), {"b": dtensor}], "c": (1, 2, 3)} + assert has_dtensors(nested) + assert not has_dtensors({"a": [1, 2], "b": {"c": torch.tensor([1.0])}}) + + +def test_convert_dtensors_to_tensors_recursively(single_rank_dist_manager: DistributedManager): + """Goals: convert_dtensors_to_tensors strips DTensor metadata through nested containers.""" + dtensor = _as_replicated_dtensor(torch.randn(2, 3), single_rank_dist_manager) + nested = {"x": dtensor, "y": [dtensor], "z": ({"k": dtensor},)} + + converted = convert_dtensors_to_tensors(nested) + + assert isinstance(converted["x"], torch.Tensor) + assert not isinstance(converted["x"], DTensor) + assert not has_dtensors(converted) + torch.testing.assert_close(converted["x"], dtensor.to_local()) + torch.testing.assert_close(converted["y"][0], dtensor.to_local()) + torch.testing.assert_close(converted["z"][0]["k"], dtensor.to_local()) + + +def test_checkpoint_roundtrip_serial_to_distributed_to_serial(single_rank_dist_manager: DistributedManager): + """Goals: serial→distributed→serial roundtrip preserves tensor values.""" + serial_weight = torch.randn(4, 4) + serial_bias = torch.randn(4) + + state_template = { + "layer.weight": _as_replicated_dtensor(torch.zeros_like(serial_weight), single_rank_dist_manager), + "layer.bias": torch.zeros_like(serial_bias), + } + checkpoint = { + "state_dict": { + "layer.weight": serial_weight.clone(), + "layer.bias": serial_bias.clone(), + } + } + + distributed_state = convert_serial_checkpoint_to_distributed_state_dict( + checkpoint=checkpoint, + strict=True, + state_dict_template=state_template, + ) + + assert isinstance(distributed_state["layer.weight"], DTensor) + assert not isinstance(distributed_state["layer.bias"], DTensor) + assert has_dtensors(distributed_state) + + roundtrip_state = convert_distributed_checkpoint_to_serial_state_dict({"state_dict": distributed_state}) + assert not has_dtensors(roundtrip_state) + torch.testing.assert_close(roundtrip_state["layer.weight"], serial_weight) + torch.testing.assert_close(roundtrip_state["layer.bias"], serial_bias) + + +def test_checkpoint_conversion_strict_key_mismatch_raises(single_rank_dist_manager: DistributedManager): + """Goals: strict mode raises KeyError when checkpoint and template keys differ.""" + checkpoint = {"state_dict": {"foo": torch.tensor([1.0])}} + template = {"bar": torch.tensor([1.0])} + + with pytest.raises(KeyError, match="State-dict keys do not match template keys"): + convert_serial_checkpoint_to_distributed_state_dict( + checkpoint=checkpoint, + strict=True, + state_dict_template=template, + ) + + +# --------------------------------------------------------------------------- +# Error-path tests (parametrized) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "checkpoint,match", + [ + ({"meta": 1}, "state_dict"), + ({"state_dict": [1, 2, 3]}, "must be a mapping"), + ], + ids=["missing-state-dict", "non-mapping-state-dict"], +) +def test_to_serial_error_paths(checkpoint, match): + """Goals: convert_distributed_checkpoint_to_serial_state_dict rejects invalid input.""" + with pytest.raises((KeyError, TypeError), match=match): + convert_distributed_checkpoint_to_serial_state_dict(checkpoint) + + +@pytest.mark.parametrize( + "checkpoint,template,error_type,match", + [ + ({"meta": 1}, {"k": torch.tensor([1.0])}, KeyError, "state_dict"), + ({"state_dict": {"k": torch.tensor([1.0])}}, None, ValueError, "state_dict_template is required"), + ({"state_dict": "not_a_dict"}, {"k": torch.tensor([1.0])}, TypeError, "must be a mapping"), + ], + ids=["missing-state-dict", "no-template", "non-mapping-state-dict"], +) +def test_to_distributed_error_paths(checkpoint, template, error_type, match): + """Goals: convert_serial_checkpoint_to_distributed_state_dict rejects invalid input.""" + with pytest.raises(error_type, match=match): + convert_serial_checkpoint_to_distributed_state_dict( + checkpoint=checkpoint, + state_dict_template=template, + ) + + +# --------------------------------------------------------------------------- +# Non-strict mode: extra keys in checkpoint are preserved +# --------------------------------------------------------------------------- + + +def test_to_distributed_non_strict_preserves_extra_keys(): + """Extra checkpoint keys not in the template are passed through in non-strict mode.""" + extra_tensor = torch.tensor([42.0]) + checkpoint = { + "state_dict": { + "in_template": torch.tensor([1.0]), + "extra_key": extra_tensor.clone(), + } + } + template = {"in_template": torch.tensor([0.0])} + + result = convert_serial_checkpoint_to_distributed_state_dict( + checkpoint=checkpoint, + state_dict_template=template, + strict=False, + ) + + assert "in_template" in result + assert "extra_key" in result + torch.testing.assert_close(result["extra_key"], extra_tensor) + + +def test_to_distributed_non_strict_ignores_missing_template_keys(): + """Template keys absent from checkpoint are silently skipped in non-strict mode.""" + checkpoint = {"state_dict": {"a": torch.tensor([1.0])}} + template = {"a": torch.tensor([0.0]), "b": torch.tensor([0.0])} + + result = convert_serial_checkpoint_to_distributed_state_dict( + checkpoint=checkpoint, + state_dict_template=template, + strict=False, + ) + + assert "a" in result + assert "b" not in result + + +# --------------------------------------------------------------------------- +# _convert_serial_value_to_template_layout: shape mismatch +# --------------------------------------------------------------------------- + + +def test_shape_mismatch_raises_value_error(single_rank_dist_manager: DistributedManager): + """ValueError when serial tensor shape does not match DTensor template shape.""" + template_dtensor = _as_replicated_dtensor(torch.zeros(4, 4), single_rank_dist_manager) + wrong_shape_tensor = torch.randn(3, 5) + + with pytest.raises(ValueError, match="does not match template shape"): + _convert_serial_value_to_template_layout(wrong_shape_tensor, template_dtensor) + + +def test_dtensor_to_dtensor_passthrough(single_rank_dist_manager: DistributedManager): + """Goals: DTensor value with DTensor template is returned as-is (no re-distribution).""" + dtensor = _as_replicated_dtensor(torch.randn(3, 3), single_rank_dist_manager) + template_dtensor = _as_replicated_dtensor(torch.zeros(3, 3), single_rank_dist_manager) + + result = _convert_serial_value_to_template_layout(dtensor, template_dtensor) + + assert isinstance(result, DTensor) + assert result is dtensor # should be the exact same object + + +def _parallel_assert_sharded_template_checkpoint_conversion(rank: int, payload): + grid_group_sizes, device_type, backend, env_per_rank, serial_weight, serial_bias = payload + + monkeypatch = pytest.MonkeyPatch() + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + DistributedManager._state = {} + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + template_weight = distribute_tensor( + torch.zeros_like(serial_weight, device=manager.device), + device_mesh=manager.device_mesh_subgroups, + placements=(Replicate(), Shard(0), Shard(1)), + ) + template_bias = torch.zeros_like(serial_bias, device=manager.device) + template = { + "layer.weight": template_weight, + "layer.bias": template_bias, + } + checkpoint = { + "state_dict": { + "layer.weight": serial_weight.clone(), + "layer.bias": serial_bias.clone(), + } + } + + converted = convert_serial_checkpoint_to_distributed_state_dict( + checkpoint=checkpoint, + strict=True, + state_dict_template=template, + ) + + weight_dtensor = converted["layer.weight"] + assert isinstance(weight_dtensor, DTensor) + assert weight_dtensor.placements == template_weight.placements + torch.testing.assert_close(weight_dtensor.full_tensor().cpu(), serial_weight) + + cp_layout = manager.layout_subgroups["cp"] + cp_rank = manager.group_rank["cp"] + i, j = cp_layout.unravel(cp_rank) + expected_local = torch.chunk(serial_weight, cp_layout.shape[0], dim=0)[i] + expected_local = torch.chunk(expected_local, cp_layout.shape[1], dim=1)[j] + torch.testing.assert_close(weight_dtensor.to_local().cpu(), expected_local) + torch.testing.assert_close(converted["layer.bias"].cpu(), serial_bias) + + serialized = convert_distributed_checkpoint_to_serial_state_dict({"state_dict": converted}) + assert not has_dtensors(serialized) + # Sharded DTensors must serialize as full global tensors for topology portability. + torch.testing.assert_close(serialized["layer.weight"], serial_weight) + torch.testing.assert_close(serialized["layer.bias"], serial_bias) + + # Regression guard: distributed -> serial -> distributed roundtrip for sharded layout. + roundtrip = convert_serial_checkpoint_to_distributed_state_dict( + checkpoint={"state_dict": serialized}, + strict=True, + state_dict_template=template, + ) + roundtrip_weight = roundtrip["layer.weight"] + assert isinstance(roundtrip_weight, DTensor) + assert roundtrip_weight.placements == template_weight.placements + torch.testing.assert_close(roundtrip_weight.full_tensor().cpu(), serial_weight) + torch.testing.assert_close(roundtrip_weight.to_local().cpu(), expected_local) + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp1-cp2x2"], +) +def test_sharded_template_checkpoint_conversion_multi_rank(setup_env): + """Goals: serial→distributed conversion with Shard placements distributes data correctly across ranks.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + serial_weight = torch.arange(16, dtype=torch.float32).reshape(4, 4) + serial_bias = torch.arange(4, dtype=torch.float32) + payload = (grid_group_sizes, device_type, backend, env_per_rank, serial_weight, serial_bias) + spawn_multiprocessing(_parallel_assert_sharded_template_checkpoint_conversion, world_size, payload) + + +# --------------------------------------------------------------------------- +# SetTriAttnBackend +# --------------------------------------------------------------------------- + + +def _parallel_assert_set_triattn_backend(rank, env_per_rank, triattn_backend, boltz2_params): + """Worker: verify SetTriAttnBackend targets only PairformerLayer instances.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + from boltz.distributed.model.layers.pairformer import PairformerLayer + from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed + from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 + + grid_group_sizes = {"dp": 1, "cp": (1, 1)} + DistributedManager.initialize(grid_group_sizes, device_type="cuda", backend="nccl") + manager = DistributedManager() + + serial_model = SerialBoltz2(**boltz2_params).to(device=manager.device).eval() + dist_model = Boltz2Distributed(serial_model, manager).eval() + + pairformer_layers_before = [] + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, PairformerLayer): + pairformer_layers_before.append(name) + assert ( + submodule.triattn_backend == TriAttnBackend.REFERENCE + ), f"{name}: expected REFERENCE before setter, got {submodule.triattn_backend}" + + assert len(pairformer_layers_before) > 0, "Model must contain at least one PairformerLayer" + + dist_model.apply(SetTriAttnBackend(triattn_backend)) + + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, PairformerLayer): + assert ( + submodule.triattn_backend == triattn_backend + ), f"{name}: expected {triattn_backend} after setter, got {submodule.triattn_backend}" + else: + assert not hasattr(submodule, "triattn_backend") or isinstance(submodule, PairformerLayer), ( + f"Non-PairformerLayer module {name} ({type(submodule).__name__}) " + f"unexpectedly has triattn_backend attribute" + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), False, "cuda", "ENV")], + indirect=True, + ids=["cuda-dp1-cp1x1"], +) +@pytest.mark.parametrize( + "triattn_backend", + [TriAttnBackend.CUEQ, TriAttnBackend.TRIFAST], + ids=lambda b: b.value, +) +def test_set_triattn_backend(setup_env, triattn_backend): + """SetTriAttnBackend sets triattn_backend only on PairformerLayer instances.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + _parallel_assert_set_triattn_backend, + world_size, + env_per_rank, + triattn_backend, + create_boltz2_model_init_params(use_large_model=False), + ) + + +# --------------------------------------------------------------------------- +# SetAttnPairBiasBackend +# --------------------------------------------------------------------------- + + +def _parallel_assert_set_attn_pair_bias_backend(rank, env_per_rank, sdpa_backend, boltz2_params): + """Worker: verify SetAttnPairBiasBackend targets only AttentionPairBias instances.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + from boltz.distributed.model.layers.attention import AttentionPairBias, AttentionPairBiasShardwise + from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed + from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 + + grid_group_sizes = {"dp": 1, "cp": (1, 1)} + DistributedManager.initialize(grid_group_sizes, device_type="cuda", backend="nccl") + manager = DistributedManager() + + serial_model = SerialBoltz2(**boltz2_params).to(device=manager.device).eval() + dist_model = Boltz2Distributed(serial_model, manager).eval() + + attn_modules_before = [] + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, AttentionPairBias): + attn_modules_before.append(name) + assert submodule.sdpa_with_bias_backend != sdpa_backend, ( + f"{name}: default backend already matches {sdpa_backend}, " + "test cannot verify the setter changes anything" + ) + + assert len(attn_modules_before) > 0, "Model must contain at least one AttentionPairBias" + + shardwise_backends_before = {} + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, AttentionPairBiasShardwise): + shardwise_backends_before[name] = submodule.sdpa_with_bias_backend + + dist_model.apply(SetAttnPairBiasBackend(sdpa_backend)) + + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, AttentionPairBias): + assert submodule.sdpa_with_bias_backend == sdpa_backend, ( + f"{name}: expected {sdpa_backend} after setter, " f"got {submodule.sdpa_with_bias_backend}" + ) + if isinstance(submodule, AttentionPairBiasShardwise): + assert submodule.sdpa_with_bias_backend == shardwise_backends_before[name], ( + f"AttentionPairBiasShardwise {name} was unexpectedly changed by " f"SetAttnPairBiasBackend" + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), False, "cuda", "ENV")], + indirect=True, + ids=["cuda-dp1-cp1x1"], +) +@pytest.mark.parametrize( + "sdpa_backend", + [SDPAWithBiasBackend.TORCH_FLEX_ATTN], + ids=lambda b: b.value, +) +def test_set_attn_pair_bias_backend(setup_env, sdpa_backend): + """SetAttnPairBiasBackend sets sdpa_with_bias_backend only on AttentionPairBias instances.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + _parallel_assert_set_attn_pair_bias_backend, + world_size, + env_per_rank, + sdpa_backend, + create_boltz2_model_init_params(use_large_model=False), + ) + + +# --------------------------------------------------------------------------- +# SetAttnPairBiasShardwiseBackend +# --------------------------------------------------------------------------- + + +def _parallel_assert_set_attn_pair_bias_shardwise_backend(rank, env_per_rank, sdpa_backend, boltz2_params): + """Worker: verify SetAttnPairBiasShardwiseBackend targets only AttentionPairBiasShardwise instances.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + from boltz.distributed.model.layers.attention import AttentionPairBias, AttentionPairBiasShardwise + from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed + from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 + + grid_group_sizes = {"dp": 1, "cp": (1, 1)} + DistributedManager.initialize(grid_group_sizes, device_type="cuda", backend="nccl") + manager = DistributedManager() + + serial_model = SerialBoltz2(**boltz2_params).to(device=manager.device).eval() + dist_model = Boltz2Distributed(serial_model, manager).eval() + + shardwise_modules_before = [] + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, AttentionPairBiasShardwise): + shardwise_modules_before.append(name) + assert submodule.sdpa_with_bias_backend != sdpa_backend, ( + f"{name}: default backend already matches {sdpa_backend}, " + "test cannot verify the setter changes anything" + ) + + assert len(shardwise_modules_before) > 0, "Model must contain at least one AttentionPairBiasShardwise" + + ring_backends_before = {} + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, AttentionPairBias): + ring_backends_before[name] = submodule.sdpa_with_bias_backend + + dist_model.apply(SetAttnPairBiasShardwiseBackend(sdpa_backend)) + + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, AttentionPairBiasShardwise): + assert submodule.sdpa_with_bias_backend == sdpa_backend, ( + f"{name}: expected {sdpa_backend} after setter, " f"got {submodule.sdpa_with_bias_backend}" + ) + if isinstance(submodule, AttentionPairBias): + assert submodule.sdpa_with_bias_backend == ring_backends_before[name], ( + f"AttentionPairBias {name} was unexpectedly changed by " f"SetAttnPairBiasShardwiseBackend" + ) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), False, "cuda", "ENV")], + indirect=True, + ids=["cuda-dp1-cp1x1"], +) +@pytest.mark.parametrize( + "sdpa_backend", + [SDPAWithBiasBackend.TORCH_FLEX_ATTN, SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION], + ids=lambda b: b.value, +) +def test_set_attn_pair_bias_shardwise_backend(setup_env, sdpa_backend): + """SetAttnPairBiasShardwiseBackend sets sdpa_with_bias_backend only on AttentionPairBiasShardwise instances.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + _parallel_assert_set_attn_pair_bias_shardwise_backend, + world_size, + env_per_rank, + sdpa_backend, + create_boltz2_model_init_params(use_large_model=False), + ) + + +# --------------------------------------------------------------------------- +# OffloadActvCkptToCPU +# --------------------------------------------------------------------------- + +_ALL_OFFLOAD_TYPES = ("DiffusionTransformer", "MSAModule", "PairformerModule") + + +def _parallel_assert_offload_actv_ckpt_to_cpu(rank, env_per_rank, target_names, boltz2_params): + """Worker: verify OffloadActvCkptToCPU targets only the requested module types.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + from boltz.distributed.model.layers.pairformer import PairformerModule + from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed + from boltz.distributed.model.modules.transformers import DiffusionTransformer + from boltz.distributed.model.modules.trunkv2 import MSAModule + from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 + + name_to_cls = { + "DiffusionTransformer": DiffusionTransformer, + "MSAModule": MSAModule, + "PairformerModule": PairformerModule, + } + target_classes = tuple(name_to_cls[n] for n in target_names) + all_classes = tuple(name_to_cls.values()) + + grid_group_sizes = {"dp": 1, "cp": (1, 1)} + DistributedManager.initialize(grid_group_sizes, device_type="cuda", backend="nccl") + manager = DistributedManager() + + serial_model = SerialBoltz2(**boltz2_params).to(device=manager.device).eval() + dist_model = Boltz2Distributed(serial_model, manager).eval() + + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, all_classes): + assert ( + not submodule.cpu_offloading + ), f"{name} ({type(submodule).__name__}): cpu_offloading should be False before setter" + submodule.activation_checkpointing = True + + dist_model.apply(OffloadActvCkptToCPU(set(target_names))) + + found_targeted = {cls: 0 for cls in target_classes} + for name, submodule in dist_model.named_modules(): + if isinstance(submodule, target_classes): + assert ( + submodule.cpu_offloading + ), f"{name} ({type(submodule).__name__}): cpu_offloading should be True after setter" + for cls in target_classes: + if isinstance(submodule, cls): + found_targeted[cls] += 1 + elif isinstance(submodule, all_classes): + assert not submodule.cpu_offloading, ( + f"{name} ({type(submodule).__name__}): non-targeted module should still have " f"cpu_offloading=False" + ) + + for cls, count in found_targeted.items(): + assert count > 0, f"Model must contain at least one {cls.__name__} but found none" + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), False, "cuda", "ENV")], + indirect=True, + ids=["cuda-dp1-cp1x1"], +) +@pytest.mark.parametrize( + "target_names", + [ + ["DiffusionTransformer"], + ["MSAModule", "PairformerModule"], + list(_ALL_OFFLOAD_TYPES), + ], + ids=["score_model_only", "msa_and_pairformer", "all_three"], +) +def test_offload_actv_ckpt_to_cpu(setup_env, target_names): + """OffloadActvCkptToCPU sets cpu_offloading=True only on targeted module types.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + _parallel_assert_offload_actv_ckpt_to_cpu, + world_size, + env_per_rank, + target_names, + create_boltz2_model_init_params(use_large_model=False), + ) + + +def test_offload_actv_ckpt_to_cpu_invalid_type(): + """OffloadActvCkptToCPU rejects unrecognised module type names.""" + with pytest.raises(ValueError, match="Invalid module type"): + OffloadActvCkptToCPU({"InvalidModule"}) + + +def test_offload_actv_ckpt_to_cpu_empty_list(): + """OffloadActvCkptToCPU rejects an empty module_types list.""" + with pytest.raises(ValueError, match="must be non-empty"): + OffloadActvCkptToCPU(set()) + + +def _parallel_assert_offload_actv_ckpt_rejects_no_ckpt(rank, env_per_rank, boltz2_params): + """Worker: verify OffloadActvCkptToCPU raises when activation_checkpointing is off.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + from boltz.distributed.model.layers.pairformer import PairformerModule + from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed + from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 + + grid_group_sizes = {"dp": 1, "cp": (1, 1)} + DistributedManager.initialize(grid_group_sizes, device_type="cuda", backend="nccl") + manager = DistributedManager() + + serial_model = SerialBoltz2(**boltz2_params).to(device=manager.device).eval() + dist_model = Boltz2Distributed(serial_model, manager).eval() + + for _name, submodule in dist_model.named_modules(): + if isinstance(submodule, PairformerModule): + submodule.activation_checkpointing = False + break + + with pytest.raises(ValueError, match="activation_checkpointing is not enabled"): + dist_model.apply(OffloadActvCkptToCPU({"PairformerModule"})) + + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + [((1, (1, 1)), False, "cuda", "ENV")], + indirect=True, + ids=["cuda-dp1-cp1x1"], +) +def test_offload_actv_ckpt_to_cpu_rejects_no_ckpt(setup_env): + """OffloadActvCkptToCPU raises when activation_checkpointing is disabled on a target.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + spawn_multiprocessing( + _parallel_assert_offload_actv_ckpt_rejects_no_ckpt, + world_size, + env_per_rank, + create_boltz2_model_init_params(use_large_model=False), + ) diff --git a/tests/distributed/model/modules/test_dtensor_utils_center_random_augmentation.py b/tests/distributed/model/modules/test_dtensor_utils_center_random_augmentation.py new file mode 100644 index 000000000..0b7e3d98b --- /dev/null +++ b/tests/distributed/model/modules/test_dtensor_utils_center_random_augmentation.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DTensor center_random_augmentation. + +Adapted from Boltz-1x CP tests. Verifies centering, random augmentation, +and consistency across CP ranks. +""" + +import pytest +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.modules.utils import ( + center_random_augmentation, +) +from boltz.model.modules.utils import ( + center_random_augmentation as center_random_augmentation_serial, +) +from boltz.testing.utils import assert_all_identical, assert_tensors_identical, seed_by_rank, spawn_multiprocessing + + +def assert_center_random_augmentation(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + centering, + augmentation, + s_trans, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh: DeviceMesh = manager.device_mesh_subgroups + replicate_group = device_mesh.get_group("cp_axis_1") + + n_samples_per_rank = 2 + batch_size = n_samples_per_rank * device_mesh.get_group("dp").size() + n_atoms = 5 * device_mesh.get_group("cp_axis_0").size() + + seed_by_rank(0, seed=42) + + atom_coords_gen_global = torch.randn( + (batch_size, n_atoms, 3), + dtype=torch.float32, + device=manager.device, + ) + + atom_coords = distribute_tensor( + atom_coords_gen_global, + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + + atom_mask_gen_global = torch.randint( + 0, + 2, + (batch_size, n_atoms), + dtype=torch.bool, + device=manager.device, + ) + + atom_mask = distribute_tensor( + atom_mask_gen_global, + device_mesh=device_mesh, + placements=(Shard(0), Shard(1), Replicate()), + ) + + atom_coords_copy = atom_coords.detach().clone() + atom_mask_copy = atom_mask.detach().clone() + + # all cp ranks should have same seed but + # different dp ranks should have different seeds + seed_by_rank(manager.group_rank["dp"]) + + results = center_random_augmentation( + atom_coords, + atom_mask, + augmentation=augmentation, + centering=centering, + s_trans=s_trans, + return_roto=augmentation, + ) + # no modification to the original tensor + assert_tensors_identical(atom_coords_copy.to_local(), atom_coords.to_local()) + assert_tensors_identical(atom_mask_copy.to_local(), atom_mask.to_local()) + + if augmentation: + atom_coords_augmented, random_R = results + else: + atom_coords_augmented = results + + # check if consistent across replicate ranks + atom_coords_augmented_local = atom_coords_augmented.to_local() + assert_all_identical(atom_coords_augmented_local, group=replicate_group) + + # check if mean is 0 + if centering and (not augmentation or s_trans == 0.0): + atom_mask_global = atom_mask.full_tensor().unsqueeze(-1) + assert_all_identical(atom_mask_global, group=manager.group["world"]) + + atom_coords_augmented_full = atom_coords_augmented.full_tensor() + assert_all_identical(atom_coords_augmented_full, group=manager.group["world"]) + + centroids = (atom_coords_augmented_full * atom_mask_global).sum(dim=1) / atom_mask_global.sum(dim=1) + + torch.testing.assert_close(centroids, torch.zeros_like(centroids)) + + if augmentation: + # Verify DTensor augmentation matches serial augmentation on global data. + # The V2 serial center_random_augmentation does not support return_roto, + # so we compare coordinates only. The rotation matrix consistency is + # verified by the assert_all_identical check above. + seed_by_rank(manager.group_rank["dp"]) + i_sample_begin = manager.group_rank["dp"] * n_samples_per_rank + i_sample_end = i_sample_begin + n_samples_per_rank + atom_coords_global = atom_coords.full_tensor()[i_sample_begin:i_sample_end] + atom_mask_global = atom_mask.full_tensor()[i_sample_begin:i_sample_end] + atom_coords_augmented_global_expected = center_random_augmentation_serial( + atom_coords_global, + atom_mask_global, + augmentation=augmentation, + centering=centering, + s_trans=s_trans, + ) + assert_all_identical(atom_coords_augmented_global_expected, group=manager.group["cp"]) + + atom_coords_augmented_global_result = atom_coords_augmented.full_tensor()[i_sample_begin:i_sample_end] + torch.testing.assert_close(atom_coords_augmented_global_result, atom_coords_augmented_global_expected) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +@pytest.mark.parametrize( + "config", + [ + (True, False, 1.0), + (False, False, 1.0), + (True, True, 2.0), + ], + ids=lambda x: f"centering:{x[0]}, augmentation:{x[1]}, s_trans:{x[2]}", +) +def test_center_random_augmentation( + setup_env, + config: tuple[bool, bool, float], +): + """Test DTensor center_random_augmentation vs serial equivalence.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + centering, augmentation, s_trans = config + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + centering, + augmentation, + s_trans, + ) + spawn_multiprocessing(assert_center_random_augmentation, world_size, payload) diff --git a/tests/distributed/model/optim/__init__.py b/tests/distributed/model/optim/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/tests/distributed/model/optim/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/tests/distributed/model/optim/test_dtensor_ema.py b/tests/distributed/model/optim/test_dtensor_ema.py new file mode 100644 index 000000000..26d1db38a --- /dev/null +++ b/tests/distributed/model/optim/test_dtensor_ema.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for DistributedEMA – the DTensor-aware EMA callback. + +All tests exercise DistributedEMA with DTensor parameters to validate +numerical parity with the base EMA, save/load serialisation, and the +eval-swap lifecycle. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +import pytorch_lightning as pl +import torch +from torch import nn +from torch.distributed.tensor import DTensor, distribute_module + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.optim.ema import DistributedEMA +from boltz.model.optim.ema import EMA +from boltz.testing.utils import assert_all_identical, assert_tensors_identical, spawn_multiprocessing + + +def _distribute_model(model: pl.LightningModule, manager: DistributedManager) -> None: + """Distribute all sub-modules' parameters as replicated DTensors on the CP mesh.""" + for _name, child in model.named_children(): + distribute_module(child, manager.device_mesh_subgroups) + + +class _TinyModule(pl.LightningModule): + """Minimal model for EMA tests.""" + + def __init__(self) -> None: + super().__init__() + self.layer = nn.Linear(4, 2, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layer(x) + + def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: + return self(batch).sum() + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +# --------------------------------------------------------------------------- +# Distributed tests (require process group) +# --------------------------------------------------------------------------- + + +def _parallel_assert_ema_comprehensive(rank: int, payload: tuple[Any, ...]) -> None: + """Comprehensive EMA lifecycle with DTensors. + + Tests in sequence: init → multi-step parity with serial EMA → + cross-rank identity → save (DTensor→plain) → load (plain→DTensor + realignment) → eval swap lifecycle (replace/forward/restore). + """ + grid_group_sizes, device_type, backend, env_per_rank = payload + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Create distributed + serial models with identical init. + with torch.random.fork_rng(devices=[]): + torch.manual_seed(42) + model = _TinyModule() + model.to(manager.device) + _distribute_model(model, manager) + + with torch.random.fork_rng(devices=[]): + torch.manual_seed(42) + serial_model = _TinyModule() + serial_model.to(manager.device) + + for param in model.state_dict().values(): + assert isinstance(param, DTensor) + + # Init EMA on both. + dist_ema = DistributedEMA(decay=0.9, eval_with_ema=True, warm_start=True) + serial_ema = EMA(decay=0.9, warm_start=True) + mock_trainer = MagicMock() + dist_ema.on_train_start(mock_trainer, model) + serial_ema.on_train_start(mock_trainer, serial_model) + + for w in dist_ema._ema_weights.values(): + assert isinstance(w, DTensor) + + # Run 5 EMA steps with matching perturbations. + for step in range(5): + with torch.random.fork_rng(devices=[]): + torch.manual_seed(100 + step) + deltas = [torch.randn_like(p) * 0.3 for p in serial_model.parameters()] + + with torch.no_grad(): + for p, d in zip(serial_model.parameters(), deltas): + p.add_(d) + serial_sd = serial_model.state_dict() + for n, p in model.named_parameters(): + (p.data.to_local() if isinstance(p.data, DTensor) else p).copy_(serial_sd[n]) + + serial_ema._cur_step = dist_ema._cur_step = step + mock_trainer.global_step = step + 1 + serial_ema.on_train_batch_end(mock_trainer, serial_model, None, None, 0) + dist_ema.on_train_batch_end(mock_trainer, model, None, None, 0) + + # Bitwise numerical parity with serial. + # Use full_tensor() for DTensor-vs-serial comparison (convention). + for key in serial_ema._ema_weights: + dw = dist_ema._ema_weights[key] + assert isinstance(dw, DTensor) + assert_tensors_identical(dw.full_tensor(), serial_ema._ema_weights[key]) + + # EMA actually had an effect: EMA weights must differ from the + # current model weights (EMA lags behind the perturbed model). + for key in dist_ema._ema_weights: + model_w = model.state_dict()[key].to_local() + ema_w = dist_ema._ema_weights[key].to_local() + assert not torch.equal(model_w, ema_w), ( + f"EMA weight '{key}' is identical to model weight after 5 steps — " "EMA may not be accumulating" + ) + + # Cross-rank identity. + world_group = torch.distributed.distributed_c10d._get_default_group() + for ema_w in dist_ema._ema_weights.values(): + assert_all_identical(ema_w.to_local(), world_group) + + # Save: DTensor EMA weights → plain tensors. + checkpoint: dict[str, Any] = {} + dist_ema.on_save_checkpoint(mock_trainer, model, checkpoint) + assert "ema" in checkpoint + for w in checkpoint["ema"]["ema_weights"].values(): + assert not isinstance(w, DTensor) + + # Load into fresh EMA: plain → DTensor realignment. + fresh_ema = DistributedEMA(decay=0.9, eval_with_ema=True) + fresh_ema.on_load_checkpoint(mock_trainer, model, checkpoint) + fresh_ema.on_train_start(mock_trainer, model) + assert fresh_ema._cur_step == dist_ema._cur_step + for key, loaded_w in fresh_ema._ema_weights.items(): + assert isinstance(loaded_w, DTensor) + assert loaded_w.placements == model.state_dict()[key].placements + torch.testing.assert_close(loaded_w.to_local(), dist_ema._ema_weights[key].to_local()) + + # Eval swap lifecycle. + training_weights = {k: v.to_local().clone() for k, v in model.state_dict().items()} + ema_snapshot = {k: v.to_local().clone() for k, v in fresh_ema._ema_weights.items()} + + fresh_ema.on_validation_start(mock_trainer, model) + + # Backup must be on CPU to avoid doubling GPU memory during validation. + for k, buf in fresh_ema._weights_buffer.items(): + assert buf.device.type == "cpu", ( + f"Weights buffer '{k}' should be on CPU to save GPU memory, " f"but is on {buf.device}" + ) + assert not isinstance(buf, DTensor), f"Weights buffer '{k}' should be a plain tensor on CPU, not a DTensor" + + for k, v in model.state_dict().items(): + assert isinstance(v, DTensor), f"DTensor semantics lost during replace for '{k}'" + assert_tensors_identical(v.to_local(), ema_snapshot[k]) + + fresh_ema.on_validation_end(mock_trainer, model) + for k, v in model.state_dict().items(): + assert isinstance(v, DTensor), f"DTensor semantics lost during restore for '{k}'" + assert_tensors_identical(v.to_local(), training_weights[k]) + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +def _parallel_assert_ema_inference_mode(rank: int, payload: tuple[Any, ...]) -> None: + """EMA replace/restore under torch.inference_mode(True). + + PyTorch >=2.10 disallows version-counter manipulation on inference + tensors. Lightning's trainer.predict() wraps the workflow in + inference_mode(True), so replace_model_weights and + restore_original_weights must handle this. + """ + grid_group_sizes, device_type, backend, env_per_rank = payload + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + with torch.random.fork_rng(devices=[]): + torch.manual_seed(42) + model = _TinyModule() + model.to(manager.device) + _distribute_model(model, manager) + + ema = DistributedEMA(decay=0.9, eval_with_ema=True) + mock_trainer = MagicMock() + ema.on_train_start(mock_trainer, model) + + # Perturb EMA weights so they differ from model weights + with torch.no_grad(): + for w in ema._ema_weights.values(): + w.add_(0.5) + + training_weights = {k: v.to_local().clone() for k, v in model.state_dict().items()} + ema_snapshot = {k: v.to_local().clone() for k, v in ema._ema_weights.items()} + + with torch.inference_mode(True): + ema.replace_model_weights(model) + for k, v in model.state_dict().items(): + assert isinstance(v, DTensor), f"DTensor lost for '{k}' under inference_mode replace" + assert_tensors_identical(v.to_local(), ema_snapshot[k]) + + ema.restore_original_weights(model) + for k, v in model.state_dict().items(): + assert isinstance(v, DTensor), f"DTensor lost for '{k}' under inference_mode restore" + assert_tensors_identical(v.to_local(), training_weights[k]) + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +def _parallel_assert_ema_shape_mismatch_raises(rank: int, payload: tuple[Any, ...]) -> None: + """_realign_weights_to_model raises ValueError on shape mismatch.""" + grid_group_sizes, device_type, backend, env_per_rank = payload + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + with torch.random.fork_rng(devices=[]): + torch.manual_seed(42) + model = _TinyModule() + model.to(manager.device) + _distribute_model(model, manager) + + try: + DistributedEMA._realign_weights_to_model( + {"layer.weight": torch.randn(99, 99), "layer.bias": torch.randn(2)}, model + ) + raise AssertionError("Expected ValueError for shape mismatch") # noqa: TRY301 + except ValueError as e: + assert "shape mismatch" in str(e).lower(), f"Unexpected error message: {e}" + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +# --------------------------------------------------------------------------- +# Distributed test parametrisation +# --------------------------------------------------------------------------- + + +_EMA_TOPOLOGIES = [ + ((1, (2, 2)), True, "cpu", "ENV"), + ((2, (1, 1)), True, "cuda", "ENV"), +] +_EMA_IDS = ["cpu-dp1-cp2x2", "cuda-dp2-cp1x1"] + + +@pytest.mark.parametrize("setup_env", _EMA_TOPOLOGIES, indirect=("setup_env",), ids=_EMA_IDS) +def test_ema_comprehensive_lifecycle(setup_env): + """Goals: full EMA lifecycle — parity, save/load roundtrip, eval swap, cross-rank identity.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + spawn_multiprocessing( + _parallel_assert_ema_comprehensive, world_size, (grid_group_sizes, device_type, backend, env_per_rank) + ) + + +@pytest.mark.parametrize("setup_env", _EMA_TOPOLOGIES, indirect=("setup_env",), ids=_EMA_IDS) +def test_ema_inference_mode(setup_env): + """Goals: replace/restore under inference_mode(True) — no version_counter error.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + spawn_multiprocessing( + _parallel_assert_ema_inference_mode, world_size, (grid_group_sizes, device_type, backend, env_per_rank) + ) + + +@pytest.mark.parametrize("setup_env", _EMA_TOPOLOGIES, indirect=("setup_env",), ids=_EMA_IDS) +def test_ema_shape_mismatch_raises(setup_env): + """Goals: _realign_weights_to_model raises ValueError when EMA shape != model shape.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + spawn_multiprocessing( + _parallel_assert_ema_shape_mismatch_raises, world_size, (grid_group_sizes, device_type, backend, env_per_rank) + ) diff --git a/tests/distributed/model/validation/test_dtensor_get_clash_metrics.py b/tests/distributed/model/validation/test_dtensor_get_clash_metrics.py new file mode 100644 index 000000000..d21758f86 --- /dev/null +++ b/tests/distributed/model/validation/test_dtensor_get_clash_metrics.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Standalone tests for DistributedValidator.get_clash_metrics at CP=(2,2). + +Tests: + - test_clash_score_counts_and_fraction: Triton clash_score kernel sanity check. + - test_dtensor_get_clash_metrics: Distributed get_clash_metrics vs serial + compute_chain_clashes, verifying the full DTensor gather path at CP=(2,2) + with both dp=1 (4 GPUs) and dp=2 (8 GPUs). +""" + +from __future__ import annotations + +import pytest +import torch +from torch.distributed.tensor import distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.loss.validation import clash_score +from boltz.distributed.model.validation.rcsb import DistributedRCSBValidator +from boltz.distributed.model.validation.utils import gather_along_cp +from boltz.model.loss.inference import compute_chain_clashes +from boltz.testing.utils import distribute_atom_features, get_feature_placements, random_features, spawn_multiprocessing + +_ATOM_KEYS = {"atom_pad_mask", "atom_to_token", "ref_element"} +_TOKEN_KEYS = {"asym_id"} +_placements = get_feature_placements(atom_keys=_ATOM_KEYS, token_keys=_TOKEN_KEYS) +SINGLE_REPR = _placements["single"] + +N_SAMPLES = 2 + + +# --------------------------------------------------------------------------- +# clash_score Triton kernel test +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_clash_score_counts_and_fraction(): + """Verify clash_score Triton kernel counts and fraction computation.""" + device = torch.device("cuda") + clash_cutoff = 2.0 + multiplicity = 2 + + coords_repr = torch.tensor( + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0], [9.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [6.0, 0.0, 0.0], [9.0, 0.0, 0.0]], + ], + device=device, + dtype=torch.float32, + ) + token_pad_mask = torch.tensor([[True, True, True, False]], device=device) + + clash_atoms_count, clash_atoms_fraction = clash_score( + coords_repr=coords_repr, + token_pad_mask=token_pad_mask, + multiplicity=multiplicity, + clash_cutoff=clash_cutoff, + ) + + expected_count = torch.tensor([2, 0], device=device, dtype=clash_atoms_count.dtype) + expected_fraction = torch.tensor([2.0 / 3.0, 0.0], device=device, dtype=clash_atoms_fraction.dtype) + + torch.testing.assert_close(clash_atoms_count, expected_count) + torch.testing.assert_close(clash_atoms_fraction, expected_fraction) + + +# --------------------------------------------------------------------------- +# Distributed get_clash_metrics test +# --------------------------------------------------------------------------- + + +def _make_clash_test_data(n_tok, n_atom, batch_size=1, seed=42): + """Create test data for clash metrics.""" + atoms_per_tok = n_atom // n_tok + rng = torch.Generator().manual_seed(seed) + + feats = random_features( + size_batch=batch_size, + n_tokens=n_tok, + n_atoms=n_atom, + n_msa=1, + atom_counts_per_token_range=(atoms_per_tok, atoms_per_tok), + device=torch.device("cpu"), + float_value_range=(-1.0, 1.0), + selected_keys=["atom_to_token", "atom_pad_mask", "atom_counts_per_token", "asym_id", "ref_element"], + rng=rng, + ) + + batch: dict = {} + for k, v in feats.items(): + batch[k] = v.to(torch.float32) if v.is_floating_point() else v + batch["atom_to_token"] = batch["atom_to_token"].to(torch.float32) + + sample_coords = torch.randn(batch_size * N_SAMPLES, n_atom, 3, generator=rng) + + edges = [] + for t in range(n_tok): + for a in range(atoms_per_tok - 1): + edges.append([t * atoms_per_tok + a, t * atoms_per_tok + a + 1]) + if t < n_tok - 1: + edges.append([t * atoms_per_tok + atoms_per_tok - 1, (t + 1) * atoms_per_tok]) + edge_tensor = torch.tensor(edges, dtype=torch.long).T if edges else torch.empty(2, 0, dtype=torch.long) + batch["connections_edge_index"] = [edge_tensor.clone() for _ in range(batch_size)] + + batch["chain_symmetries"] = [] + for b in range(batch_size): + unique_asym = batch["asym_id"][b].unique().tolist() + batch["chain_symmetries"].append([[(int(cid), 0, "A", "PROTEIN", 0)] for cid in unique_asym]) + + out = {"sample_atom_coords": sample_coords} + return batch, out + + +def _distribute_clash_data(batch_serial, out_serial, manager): + """Distribute batch/out as DTensors for clash metrics. + + Atom features (atom_to_token, atom_pad_mask, ref_element) and + sample_atom_coords are distributed via ``distribute_atom_features`` + with intersperse padding. Token features use ``distribute_tensor``. + Non-sharded list features are sliced to the local DP rank's batch element. + """ + mesh = manager.device_mesh_subgroups + device = manager.device + dp_rank = manager.group_rank["dp"] + placements_token = _placements["token_features"] + + B = batch_serial["atom_pad_mask"].shape[0] + coords = out_serial["sample_atom_coords"] + coords_unflat = coords.unflatten(0, (B, N_SAMPLES)) + + inputs_atom = { + "atom_counts_per_token": batch_serial["atom_counts_per_token"], + "atom_to_token": batch_serial["atom_to_token"], + "atom_pad_mask": batch_serial["atom_pad_mask"], + "ref_element": batch_serial["ref_element"], + } + for i in range(N_SAMPLES): + inputs_atom[f"sample_atom_coords_{i}"] = coords_unflat[:, i] + + sample_cp = {f"sample_atom_coords_{i}": _placements["cp_single"] for i in range(N_SAMPLES)} + sample_dp_cp = {f"sample_atom_coords_{i}": _placements["single"] for i in range(N_SAMPLES)} + + feats_atom = distribute_atom_features( + inputs=inputs_atom, + placements_cp=_placements["cp_atom_features"] | sample_cp, + placements_dp_cp=_placements["atom_features"] | sample_dp_cp, + device_mesh=mesh, + cp_group=manager.group["cp"], + multiplicities={"sample_atom_coords": N_SAMPLES}, + ) + + batch_dt: dict = { + "atom_to_token": feats_atom.pop("atom_to_token"), + "atom_pad_mask": feats_atom.pop("atom_pad_mask"), + "ref_element": feats_atom.pop("ref_element"), + } + + batch_dt["asym_id"] = distribute_tensor( + batch_serial["asym_id"].to(device), device_mesh=mesh, placements=placements_token["asym_id"] + ) + batch_dt["chain_symmetries"] = [batch_serial["chain_symmetries"][dp_rank]] + batch_dt["connections_edge_index"] = [batch_serial["connections_edge_index"][dp_rank].to(device)] + + out_dt = {"sample_atom_coords": feats_atom.pop("sample_atom_coords")} + return batch_dt, out_dt + + +def _parallel_clash_worker(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_results_per_batch, + batch_host, + out_host, + ) = payload + + mp = pytest.MonkeyPatch() + if env_per_rank is not None: + for k, v in env_per_rank.items(): + mp.setenv(k, f"{rank}" if v == "" else v) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + dp_rank = manager.group_rank["dp"] + + dist_validator = DistributedRCSBValidator( + val_names=["RCSB"], + confidence_prediction=False, + physicalism_metrics=True, + ) + dist_validator.to(manager.device) + + batch_dt, out_dt = _distribute_clash_data(batch_host, out_host, manager) + + batch_gathered = { + "asym_id": gather_along_cp(batch_dt["asym_id"]), + "atom_pad_mask": gather_along_cp(batch_dt["atom_pad_mask"]), + } + out_gathered = { + "sample_atom_coords": gather_along_cp(out_dt["sample_atom_coords"]), + } + + pair_clash_dict, pair_total_dict = dist_validator.get_clash_metrics( + batch_dt, + out_dt, + batch_gathered, + out_gathered, + ) + + expected_clash, expected_total = serial_results_per_batch[dp_rank] + for key in expected_clash: + torch.testing.assert_close( + pair_clash_dict[key].cpu(), + expected_clash[key], + msg=f"Clash mismatch on rank {rank} (dp={dp_rank}), key={key}", + ) + torch.testing.assert_close( + pair_total_dict[key].cpu(), + expected_total[key], + msg=f"Clash total mismatch on rank {rank} (dp={dp_rank}), key={key}", + ) + + DistributedManager.cleanup() + mp.undo() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["dp1-cp2x2", "dp2-cp2x2"], +) +def test_dtensor_get_clash_metrics(setup_env): + """Distributed get_clash_metrics matches serial compute_chain_clashes at CP=(2,2).""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + cp0 = grid_group_sizes["cp"][0] + num_dp = grid_group_sizes["dp"] + n_tok = 8 * cp0 + n_atom = 4 * n_tok + + batch, out = _make_clash_test_data(n_tok, n_atom, batch_size=num_dp, seed=42) + + serial_results_per_batch = [] + for b in range(num_dp): + b_feats = { + "atom_to_token": batch["atom_to_token"][b : b + 1], + "atom_pad_mask": batch["atom_pad_mask"][b : b + 1], + "ref_element": batch["ref_element"][b : b + 1], + "asym_id": batch["asym_id"][b : b + 1], + "chain_symmetries": [batch["chain_symmetries"][b]], + "connections_edge_index": [batch["connections_edge_index"][b]], + } + b_coords = out["sample_atom_coords"][b * N_SAMPLES : (b + 1) * N_SAMPLES] + cd, td = compute_chain_clashes(pred_atom_coords=b_coords, feats=b_feats) + serial_results_per_batch.append(({k: v.cpu() for k, v in cd.items()}, {k: v.cpu() for k, v in td.items()})) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_results_per_batch, + batch, + out, + ) + spawn_multiprocessing(_parallel_clash_worker, world_size, payload) diff --git a/tests/distributed/model/validation/test_dtensor_get_pb_metrics.py b/tests/distributed/model/validation/test_dtensor_get_pb_metrics.py new file mode 100644 index 000000000..f71f3e02a --- /dev/null +++ b/tests/distributed/model/validation/test_dtensor_get_pb_metrics.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Standalone test for DistributedValidator.get_pb_metrics at CP=(2,2). + +Verifies that the distributed override produces identical results to the +serial compute_pb_geometry_metrics, compute_stereo_metrics, and +compute_pb_flatness_metrics by gathering DTensor features and comparing. +Tests both dp=1 (4 GPUs) and dp=2 (8 GPUs). + +Test data includes a real PHE (phenylalanine) ligand with CCD-derived +geometry features (66 atom-pair edges, 1 chiral center, 1 aromatic 6-ring) +so that all three PB metric types produce non-trivial results. +""" + +from __future__ import annotations + +import pytest +import torch +from torch.distributed.tensor import distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.validation.rcsb import DistributedRCSBValidator +from boltz.distributed.model.validation.utils import gather_along_cp +from boltz.model.loss.inference import ( + compute_pb_flatness_metrics, + compute_pb_geometry_metrics, + compute_stereo_metrics, +) +from boltz.testing.utils import ( + LIGAND_KEYS, + distribute_atom_features, + get_feature_placements, + make_pb_test_data, + spawn_multiprocessing, +) + +_ATOM_KEYS = {"atom_pad_mask", "atom_to_token"} +_TOKEN_KEYS = {"asym_id", "mol_type"} +_placements = get_feature_placements(atom_keys=_ATOM_KEYS, token_keys=_TOKEN_KEYS) +SINGLE_REPR = _placements["single"] + +N_SAMPLES = 2 + + +def _distribute_pb_data(batch_serial, out_serial, manager): + """Distribute batch/out as DTensors for PB metrics. + + Atom features (atom_to_token, atom_pad_mask) and sample_atom_coords + are distributed via ``distribute_atom_features`` with intersperse + padding. Token features use ``distribute_tensor``. + Non-sharded list features are sliced to the local DP rank's batch element. + """ + mesh = manager.device_mesh_subgroups + device = manager.device + dp_rank = manager.group_rank["dp"] + placements_token = _placements["token_features"] + + B = batch_serial["atom_pad_mask"].shape[0] + coords = out_serial["sample_atom_coords"] + coords_unflat = coords.unflatten(0, (B, N_SAMPLES)) + + inputs_atom = { + "atom_counts_per_token": batch_serial["atom_counts_per_token"], + "atom_to_token": batch_serial["atom_to_token"], + "atom_pad_mask": batch_serial["atom_pad_mask"], + } + for i in range(N_SAMPLES): + inputs_atom[f"sample_atom_coords_{i}"] = coords_unflat[:, i] + + sample_cp = {f"sample_atom_coords_{i}": _placements["cp_single"] for i in range(N_SAMPLES)} + sample_dp_cp = {f"sample_atom_coords_{i}": _placements["single"] for i in range(N_SAMPLES)} + + feats_atom = distribute_atom_features( + inputs=inputs_atom, + placements_cp=_placements["cp_atom_features"] | sample_cp, + placements_dp_cp=_placements["atom_features"] | sample_dp_cp, + device_mesh=mesh, + cp_group=manager.group["cp"], + multiplicities={"sample_atom_coords": N_SAMPLES}, + ) + + batch_dt: dict = { + "atom_to_token": feats_atom.pop("atom_to_token"), + "atom_pad_mask": feats_atom.pop("atom_pad_mask"), + } + + batch_dt["asym_id"] = distribute_tensor( + batch_serial["asym_id"].to(device), device_mesh=mesh, placements=placements_token["asym_id"] + ) + batch_dt["mol_type"] = distribute_tensor( + batch_serial["mol_type"].to(device), device_mesh=mesh, placements=placements_token["mol_type"] + ) + + for k in LIGAND_KEYS: + if k in batch_serial: + batch_dt[k] = [batch_serial[k][dp_rank].to(device)] + + out_dt = {"sample_atom_coords": feats_atom.pop("sample_atom_coords")} + return batch_dt, out_dt + + +def _parallel_pb_worker(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_results_per_batch, + batch_host, + out_host, + ) = payload + + mp = pytest.MonkeyPatch() + if env_per_rank is not None: + for k, v in env_per_rank.items(): + mp.setenv(k, f"{rank}" if v == "" else v) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + dp_rank = manager.group_rank["dp"] + + dist_validator = DistributedRCSBValidator( + val_names=["RCSB"], + confidence_prediction=False, + physicalism_metrics=True, + ) + dist_validator.to(manager.device) + + batch_dt, out_dt = _distribute_pb_data(batch_host, out_host, manager) + + batch_gathered = { + "asym_id": gather_along_cp(batch_dt["asym_id"]), + "mol_type": gather_along_cp(batch_dt["mol_type"]), + } + out_gathered = { + "sample_atom_coords": gather_along_cp(out_dt["sample_atom_coords"]), + } + + pb_failure_dict, pb_total_dict = dist_validator.get_pb_metrics( + batch_dt, + out_dt, + batch_gathered, + out_gathered, + ) + + expected_failure, expected_total = serial_results_per_batch[dp_rank] + for key in expected_failure: + torch.testing.assert_close( + pb_failure_dict[key].cpu(), + expected_failure[key], + msg=f"PB failure mismatch on rank {rank} (dp={dp_rank}), key={key}", + ) + torch.testing.assert_close( + pb_total_dict[key].cpu(), + expected_total[key], + msg=f"PB total mismatch on rank {rank} (dp={dp_rank}), key={key}", + ) + + DistributedManager.cleanup() + mp.undo() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["dp1-cp2x2", "dp2-cp2x2"], +) +def test_dtensor_get_pb_metrics(setup_env, canonical_mols_dir): + """Distributed get_pb_metrics matches serial PB computation at CP=(2,2). + + Uses real PHE CCD geometry (66 edges, 1 chiral centre, 1 aromatic 6-ring) + so that all three PB metric families produce non-trivial totals. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + cp0 = grid_group_sizes["cp"][0] + num_dp = grid_group_sizes["dp"] + n_tok = 8 * cp0 + n_atom = 4 * n_tok + + batch, out = make_pb_test_data( + n_tok, + n_atom, + mols_dir=str(canonical_mols_dir), + batch_size=num_dp, + n_samples=N_SAMPLES, + seed=42, + ) + + serial_results_per_batch = [] + for b in range(num_dp): + b_feats: dict = { + "atom_to_token": batch["atom_to_token"][b : b + 1], + "atom_pad_mask": batch["atom_pad_mask"][b : b + 1], + "asym_id": batch["asym_id"][b : b + 1], + "mol_type": batch["mol_type"][b : b + 1], + } + b_feats.update({k: [batch[k][b]] for k in LIGAND_KEYS}) + + b_coords = out["sample_atom_coords"][b * N_SAMPLES : (b + 1) * N_SAMPLES] + (bl, ba, ic, nl) = compute_pb_geometry_metrics(pred_atom_coords=b_coords, feats=b_feats) + (cav, ca, sbv, sb) = compute_stereo_metrics(pred_atom_coords=b_coords, feats=b_feats) + (a5v, a5r, a6v, a6r, dbv, db) = compute_pb_flatness_metrics(pred_atom_coords=b_coords, feats=b_feats) + + assert nl.sum() > 0, "num_ligands should be > 0 (PHE present)" + assert ca.sum() > 0, "num_chiral_atoms should be > 0 (PHE has 1 chiral center)" + assert a6r.sum() > 0, "num_aromatic_6_rings should be > 0 (PHE has 1 aromatic ring)" + + failure = { + "bond_length": bl.cpu(), + "bond_angle": ba.cpu(), + "internal_clash": ic.cpu(), + "atom_chirality": cav.cpu(), + "bond_stereochemistry": sbv.cpu(), + "ring_5_flatness": a5v.cpu(), + "ring_6_flatness": a6v.cpu(), + "double_bond_flatness": dbv.cpu(), + } + total = { + "bond_length": nl.cpu(), + "bond_angle": nl.cpu(), + "internal_clash": nl.cpu(), + "atom_chirality": ca.cpu(), + "bond_stereochemistry": sb.cpu(), + "ring_5_flatness": a5r.cpu(), + "ring_6_flatness": a6r.cpu(), + "double_bond_flatness": db.cpu(), + } + serial_results_per_batch.append((failure, total)) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_results_per_batch, + batch, + out, + ) + spawn_multiprocessing(_parallel_pb_worker, world_size, payload) diff --git a/tests/distributed/model/validation/test_dtensor_rcsb_validator.py b/tests/distributed/model/validation/test_dtensor_rcsb_validator.py new file mode 100644 index 000000000..0f17906e2 --- /dev/null +++ b/tests/distributed/model/validation/test_dtensor_rcsb_validator.py @@ -0,0 +1,900 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""End-to-end integration tests for DistributedRCSBValidator. + +Tests that the distributed validation pipeline produces identical metric +values to the serial RCSBValidator at CP=2 and CP=4. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +from torch import Tensor +from torch.distributed.tensor import DTensor, distribute_tensor + +from boltz.data import const +from boltz.distributed.comm import TransposeComm +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.validation.rcsb import DistributedRCSBValidator +from boltz.model.validation.rcsb import RCSBValidator +from boltz.testing.utils import ( + get_feature_placements, + random_features, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + + +class _ConfigNamespace(SimpleNamespace): + """SimpleNamespace with dict-like .get() for OmegaConf compat.""" + + def get(self, key, default=None): + return getattr(self, key, default) + + +_ATOM_KEYS = {"atom_pad_mask", "atom_to_token", "atom_resolved_mask", "coords", "ref_element", "token_to_rep_atom"} +_TOKEN_KEYS = {"mol_type", "asym_id", "token_pad_mask", "token_index", "token_disto_mask"} +_placements = get_feature_placements(atom_keys=_ATOM_KEYS, token_keys=_TOKEN_KEYS) + +SINGLE_REPR = _placements["single"] +PAIR_REPR = _placements["pair"] +ENSEMBLE_REPR = _placements["atom_features"]["coords"] + + +def _make_mock_model(device="cpu", confidence_prediction=False, diffusion_samples=1, symmetry_correction=True): + """Build a lightweight mock LightningModule for validator tests.""" + model = SimpleNamespace( + val_group_mapper={0: {"label": "RCSB", "symmetry_correction": symmetry_correction}}, + validation_args=_ConfigNamespace( + recycling_steps=0, + sampling_steps=1, + diffusion_samples=diffusion_samples, + ), + confidence_prediction=confidence_prediction, + aggregate_distogram=True, + min_dist=2.0, + max_dist=22.0, + num_bins=8, + num_distograms=1, + device=device, + dp_group=None, + token_level_confidence=True, + _logged={}, + ) + + def _log(name, value, **kwargs): + model._logged[name] = value + + model.log = _log + + def _get_true_coordinates( + batch, + out, + diffusion_samples, + symmetry_correction, + expand_to_diffusion_samples=True, + ): + K, L = batch["coords"].shape[1:3] + mask = batch["atom_resolved_mask"] + tc = batch["coords"].squeeze(0) + if expand_to_diffusion_samples: + tc = tc.repeat((diffusion_samples, 1, 1)).reshape( + diffusion_samples, + K, + L, + 3, + ) + mask = mask.repeat_interleave(diffusion_samples, dim=0) + else: + mask = mask.squeeze(0) + return { + "true_coords": tc, + "true_coords_resolved_mask": mask, + "rmsds": 0, + "best_rmsd_recall": 0, + "best_rmsd_precision": 0, + } + + model.get_true_coordinates = _get_true_coordinates + return model + + +def _make_test_data( + n_tok, + n_atom, + num_bins, + seed=42, + n_samples=1, + confidence=False, + physicalism=False, +): + """Generate deterministic batch and output dicts for validator tests.""" + B, K, D = 1, 1, 1 + atoms_per_tok = n_atom // n_tok + assert n_atom == n_tok * atoms_per_tok + + rng = torch.Generator().manual_seed(seed) + + selected_keys = [ + "atom_pad_mask", + "atom_to_token", + "atom_resolved_mask", + "coords", + "mol_type", + "asym_id", + "token_pad_mask", + "token_index", + "token_disto_mask", + "disto_target", + "contact_conditioning", + ] + if confidence: + selected_keys += ["token_to_rep_atom", "r_set_to_rep_atom", "frames_idx"] + if physicalism: + selected_keys += ["ref_element"] + + feats = random_features( + size_batch=B, + n_tokens=n_tok, + n_atoms=n_atom, + n_msa=1, + atom_counts_per_token_range=(atoms_per_tok, atoms_per_tok), + device=torch.device("cpu"), + float_value_range=(-1.0, 1.0), + selected_keys=selected_keys, + num_disto_bins=num_bins, + rng=rng, + ) + + batch = {} + for k, v in feats.items(): + if v.is_floating_point(): + batch[k] = v.to(torch.float32) + else: + batch[k] = v + + # atom_to_token must be float for einsum in factored_lddt_loss + batch["atom_to_token"] = batch["atom_to_token"].to(torch.float32) + + # disto_target: (B, N, N, bins) -> (B, N, N, K, bins) for v2 distogram loss + batch["disto_target"] = batch["disto_target"].unsqueeze(3) + + # disto_coords_ensemble: needed by serial and distributed compute_disto_lddt + batch["disto_coords_ensemble"] = torch.randn(B, K * n_tok, 3, generator=rng) + + batch["idx_dataset"] = torch.tensor([0]) + + if n_samples == 1: + sample_coords = batch["coords"][:, 0, :, :].clone() + else: + sample_coords = torch.randn(n_samples, n_atom, 3, generator=rng) + + pdisto = torch.randn(B, n_tok, n_tok, D, num_bins, generator=rng) + out = {"sample_atom_coords": sample_coords, "pdistogram": pdisto} + + if confidence: + # Pad r_set_to_rep_atom to (B, n_tok, n_atom) for diagonal sharding + r_set = batch["r_set_to_rep_atom"] + n_r = r_set.shape[1] + if n_r < n_tok: + pad = torch.zeros(B, n_tok - n_r, n_atom, dtype=r_set.dtype, device=r_set.device) + batch["r_set_to_rep_atom"] = torch.cat([r_set, pad], dim=1) + + batch["frame_resolved_mask"] = torch.ones(B, n_tok, dtype=torch.bool) + + out["plddt"] = torch.randn(n_samples, n_tok, generator=rng).sigmoid() + out["pde"] = torch.randn(n_samples, n_tok, n_tok, generator=rng) + out["pae"] = torch.randn(n_samples, n_tok, n_tok, generator=rng) + for key in ( + "complex_plddt", + "complex_iplddt", + "complex_pde", + "complex_ipde", + "ptm", + "iptm", + "ligand_iptm", + "protein_iptm", + ): + out[key] = torch.randn(n_samples, generator=rng) + + if physicalism: + edges = [] + for t in range(n_tok): + for a in range(atoms_per_tok - 1): + edges.append([t * atoms_per_tok + a, t * atoms_per_tok + a + 1]) + if t < n_tok - 1: + edges.append([t * atoms_per_tok + atoms_per_tok - 1, (t + 1) * atoms_per_tok]) + if edges: + batch["connections_edge_index"] = [torch.tensor(edges, dtype=torch.long).T] + else: + batch["connections_edge_index"] = [torch.empty(2, 0, dtype=torch.long)] + + batch["chain_symmetries"] = [[[(0, 0, "A", "PROTEIN", 0)]]] + + for key in ( + "ligand_edge_index", + "ligand_stereo_bond_index", + "ligand_planar_double_bond_index", + ): + batch[key] = [torch.empty(2, 0, dtype=torch.long)] + for key in ( + "ligand_edge_lower_bounds", + "ligand_edge_upper_bounds", + ): + batch[key] = [torch.empty(0)] + for key in ( + "ligand_edge_bond_mask", + "ligand_edge_angle_mask", + "ligand_chiral_check_mask", + "ligand_chiral_atom_orientations", + "ligand_stereo_check_mask", + "ligand_stereo_bond_orientations", + ): + batch[key] = [torch.empty(0, dtype=torch.bool)] + batch["ligand_chiral_atom_index"] = [torch.empty(4, 0, dtype=torch.long)] + batch["ligand_aromatic_5_ring_index"] = [torch.empty(5, 0, dtype=torch.long)] + batch["ligand_aromatic_6_ring_index"] = [torch.empty(6, 0, dtype=torch.long)] + + return batch, out + + +def _extract_all_metrics(validator): + """Collect all non-NaN metric values from a validator into a flat dict.""" + results = {} + for gname, mlist in validator.folding_metrics.items(): + for ds_idx in range(len(mlist)): + for mname, metric in mlist[ds_idx].items(): + val = metric.compute() + if not torch.isnan(val): + results[f"{gname}/{ds_idx}/{mname}"] = val.item() + if validator.physicalism_metrics: + for gname, mlist in validator.physicalism_metrics.items(): + for ds_idx in range(len(mlist)): + for mname, metric in mlist[ds_idx].items(): + val = metric.compute() + if not torch.isnan(val): + results[f"phys/{gname}/{ds_idx}/{mname}"] = val.item() + if hasattr(validator, "confidence_metrics"): + for gname, mlist in validator.confidence_metrics.items(): + for ds_idx in range(len(mlist)): + for mname, metric in mlist[ds_idx].items(): + val = metric.compute() + if not torch.isnan(val): + results[f"conf/{gname}/{ds_idx}/{mname}"] = val.item() + return results + + +def _parallel_distributed_test(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_metrics, + batch_host, + out_host, + confidence_prediction, + physicalism_metrics, + expand_to_diffusion_samples, + ) = payload + + n_samples = 2 if confidence_prediction else 1 + + mp = pytest.MonkeyPatch() + if env_per_rank is not None: + for k, v in env_per_rank.items(): + mp.setenv(k, f"{rank}" if v == "" else v) + + DistributedManager.initialize( + grid_group_sizes, + device_type=device_type, + backend=backend, + ) + manager = DistributedManager() + comm = TransposeComm(manager.group["cp"], manager.layout_subgroups["cp"]) + + # --- distribute batch/out as DTensors (inlined) --- + mesh = manager.device_mesh_subgroups + device = manager.device + cp0_size = mesh.get_group("cp_axis_0").size() + cp0_rank = mesh.get_coordinate()[1] + + placements_token = _placements["token_features"] + placements_atom = _placements["atom_features"] + + def _dt(tensor, placements): + return distribute_tensor( + tensor.to(device), + device_mesh=mesh, + placements=placements, + ) + + DIAG_SHARDED_KEYS = {"atom_to_token", "token_to_rep_atom", "r_set_to_rep_atom"} + + batch_dtensor = {} + for key, val in batch_host.items(): + if key == "idx_dataset": + batch_dtensor[key] = val.to(device) + continue + if key in DIAG_SHARDED_KEYS: + if key == "r_set_to_rep_atom" and cp0_size > 1: + # Reorder r_set rows so each shard's row range only contains + # elements whose representative atom falls within that shard's + # column range. + B, nr, nc = val.shape + nr_l = nr // cp0_size + nc_l = nc // cp0_size + valid_mask = val[0].any(dim=-1) + atom_indices = val[0].argmax(dim=-1) + shard_of_row = atom_indices // nc_l + reordered = torch.zeros_like(val) + for s in range(cp0_size): + shard_rows = val[:, (shard_of_row == s) & valid_mask] + n = min(shard_rows.shape[1], nr_l) + reordered[:, s * nr_l : s * nr_l + n] = shard_rows[:, :n] + val = reordered + + B, nr, nc = val.shape + nr_l = nr // cp0_size + nc_l = nc // cp0_size + local_block = ( + val[ + :, + cp0_rank * nr_l : (cp0_rank + 1) * nr_l, + cp0_rank * nc_l : (cp0_rank + 1) * nc_l, + ] + .contiguous() + .to(device) + ) + batch_dtensor[key] = DTensor.from_local( + local_block, + device_mesh=mesh, + placements=SINGLE_REPR, + ) + continue + if key in placements_atom: + batch_dtensor[key] = _dt(val, placements_atom[key]) + continue + if key in placements_token: + batch_dtensor[key] = _dt(val, placements_token[key]) + continue + if key in ("disto_target", "contact_conditioning"): + batch_dtensor[key] = _dt(val, PAIR_REPR) + continue + if isinstance(val, Tensor) and val.ndim >= 2: + batch_dtensor[key] = _dt(val, SINGLE_REPR) + elif isinstance(val, Tensor): + batch_dtensor[key] = val.to(device) + else: + batch_dtensor[key] = val + + out_dtensor = { + "sample_atom_coords": _dt(out_host["sample_atom_coords"], SINGLE_REPR), + "pdistogram": _dt(out_host["pdistogram"], PAIR_REPR), + } + for key in ("plddt",): + if key in out_host: + out_dtensor[key] = _dt(out_host[key], SINGLE_REPR) + for key in ("pde", "pae"): + if key in out_host: + out_dtensor[key] = _dt(out_host[key], PAIR_REPR) + for key in ( + "complex_plddt", + "complex_iplddt", + "complex_pde", + "complex_ipde", + "ptm", + "iptm", + "ligand_iptm", + "protein_iptm", + ): + if key in out_host: + out_dtensor[key] = out_host[key].to(device) + # --- end distribute --- + + dist_validator = DistributedRCSBValidator( + val_names=["RCSB"], + confidence_prediction=confidence_prediction, + physicalism_metrics=physicalism_metrics, + rmsd_metrics=True, + clash_score_metrics=True, + ) + dist_validator.to(manager.device) + + dist_model = _make_mock_model( + device=str(manager.device), + confidence_prediction=confidence_prediction, + diffusion_samples=n_samples, + ) + dist_model.dp_group = manager.device_mesh_subgroups.get_group(0) + + dist_validator.common_val_step( + model=dist_model, + batch=batch_dtensor, + out=out_dtensor, + idx_dataset=0, + expand_to_diffusion_samples=expand_to_diffusion_samples, + transpose_comm=comm, + ) + + dist_metrics = _extract_all_metrics(dist_validator) + + # Metrics that only exist in the distributed validator (not in serial). + _DISTRIBUTED_ONLY_PREFIXES = ("rmsd/", "clash_score/") + + for key in serial_metrics: + assert key in dist_metrics, f"Metric {key!r} missing in distributed result (rank {rank})" + torch.testing.assert_close( + torch.tensor(dist_metrics[key]), + torch.tensor(serial_metrics[key]), + msg=lambda m: f"Metric {key!r} mismatch on rank {rank}: {m}", + ) + + for key in dist_metrics: + if any(key.startswith(p) for p in _DISTRIBUTED_ONLY_PREFIXES): + continue + assert key in serial_metrics, f"Unexpected metric {key!r} in distributed (rank {rank})" + + DistributedManager.cleanup() + mp.undo() + + +def _parallel_epoch_end_test(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + confidence_prediction, + physicalism_metrics, + ) = payload + + mp = pytest.MonkeyPatch() + if env_per_rank is not None: + for k, v in env_per_rank.items(): + mp.setenv(k, f"{rank}" if v == "" else v) + + DistributedManager.initialize( + grid_group_sizes, + device_type=device_type, + backend=backend, + ) + manager = DistributedManager() + + dv = DistributedRCSBValidator( + val_names=["RCSB"], + confidence_prediction=confidence_prediction, + physicalism_metrics=physicalism_metrics, + rmsd_metrics=True, + clash_score_metrics=True, + ) + dv.to(manager.device) + + model = _make_mock_model( + device=str(manager.device), + confidence_prediction=confidence_prediction, + ) + model.dp_group = manager.device_mesh_subgroups.get_group(0) + + logged_kwargs = [] + orig_log = model.log + + def _capture(*args, **kwargs): + logged_kwargs.append(dict(kwargs)) + return orig_log(*args, **kwargs) + + model.log = _capture + dv.on_epoch_end(model=model) + + DistributedManager.cleanup() + mp.undo() + + +@pytest.mark.parametrize( + "setup_env, confidence_prediction, physicalism_metrics", + [ + (((1, (2, 2)), True, "cuda", "ENV"), True, True), + ], + indirect=("setup_env",), + ids=[ + "dp1-cp2x2-conf-phys", + ], +) +@pytest.mark.parametrize("expand_to_diffusion_samples", [True, False]) +def test_distributed_rcsb_validator_matches_serial( + setup_env, confidence_prediction, physicalism_metrics, expand_to_diffusion_samples +): + """Distributed validator metrics match serial.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + cp_axis_0 = grid_group_sizes["cp"][0] + n_tok = 8 * cp_axis_0 + n_atom = 32 * cp_axis_0 + num_bins = 8 + + # --- run serial reference (inlined) --- + n_samples = 2 if confidence_prediction else 1 + batch, out = _make_test_data( + n_tok, + n_atom, + num_bins, + seed=42, + n_samples=n_samples, + confidence=confidence_prediction, + physicalism=physicalism_metrics, + ) + model = _make_mock_model( + device="cpu", + confidence_prediction=confidence_prediction, + diffusion_samples=n_samples, + ) + validator = RCSBValidator( + val_names=["RCSB"], + confidence_prediction=confidence_prediction, + physicalism_metrics=physicalism_metrics, + ) + validator.common_val_step( + model=model, + batch=batch, + out=out, + idx_dataset=0, + expand_to_diffusion_samples=expand_to_diffusion_samples, + ) + serial_metrics = _extract_all_metrics(validator) + # --- end serial reference --- + + batch_host, out_host = _make_test_data( + n_tok, + n_atom, + num_bins, + seed=42, + n_samples=n_samples, + confidence=confidence_prediction, + physicalism=physicalism_metrics, + ) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + serial_metrics, + batch_host, + out_host, + confidence_prediction, + physicalism_metrics, + expand_to_diffusion_samples, + ) + spawn_multiprocessing(_parallel_distributed_test, world_size, payload) + + +@pytest.mark.parametrize( + "setup_env, confidence_prediction, physicalism_metrics", + [ + (((1, (2, 2)), True, "cpu", "ENV"), True, True), + (((2, (1, 1)), True, "cuda", "ENV"), True, True), + ], + indirect=("setup_env",), + ids=[ + "cpu-cp2x2-conf-phys", + "cuda-dp2-cp1x1-conf-phys", + ], +) +def test_distributed_rcsb_validator_epoch_end_sync(setup_env, confidence_prediction, physicalism_metrics): + """common_on_epoch_end wraps model.log with sync_dist_group.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + confidence_prediction, + physicalism_metrics, + ) + spawn_multiprocessing(_parallel_epoch_end_test, world_size, payload) + + +def _parallel_dp_all_reduce_test(rank, payload): + """Worker for test_dp_all_reduce_epoch_end. + + Each DP rank updates the disto_loss MeanMetric with a rank-specific + value. After common_on_epoch_end (which calls _dp_all_reduce_metrics), + the logged disto_loss should be the global weighted mean across DP ranks. + """ + grid_group_sizes, device_type, backend, env_per_rank = payload + + mp = pytest.MonkeyPatch() + if env_per_rank is not None: + for k, v in env_per_rank.items(): + mp.setenv(k, f"{rank}" if v == "" else v) + + DistributedManager.initialize( + grid_group_sizes, + device_type=device_type, + backend=backend, + ) + manager = DistributedManager() + dp_rank = manager.device_mesh_subgroups.get_coordinate()[0] + + dv = DistributedRCSBValidator( + val_names=["RCSB"], + confidence_prediction=True, + physicalism_metrics=True, + rmsd_metrics=True, + clash_score_metrics=True, + ) + dv.to(manager.device) + + rank_values = [10.0, 20.0] + rank_weights = [1.0, 3.0] + local_val = torch.tensor(rank_values[dp_rank], device=manager.device) + local_weight = torch.tensor(rank_weights[dp_rank], device=manager.device) + dv.folding_metrics["disto_loss"][0]["disto_loss"].update(local_val, local_weight) + + model = _make_mock_model(device=str(manager.device)) + model.dp_group = manager.device_mesh_subgroups.get_group(0) + + dv.on_epoch_end(model=model) + + expected_global_mean = sum(v * w for v, w in zip(rank_values, rank_weights)) / sum(rank_weights) + + logged_disto = model._logged.get("val/disto_loss") + assert logged_disto is not None, f"Rank {rank}: val/disto_loss not logged" + + logged_val = logged_disto.item() if isinstance(logged_disto, Tensor) else float(logged_disto) + torch.testing.assert_close( + torch.tensor(logged_val), + torch.tensor(expected_global_mean), + msg=lambda m: f"Rank {rank}: DP all-reduce disto_loss mismatch: {m}", + ) + + DistributedManager.cleanup() + mp.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (1, 1)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp1x1"], +) +def test_dp_all_reduce_epoch_end(setup_env): + """DP all-reduce in common_on_epoch_end produces correct global means.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + payload = (grid_group_sizes, device_type, backend, env_per_rank) + spawn_multiprocessing(_parallel_dp_all_reduce_test, world_size, payload) + + +def test_distributed_rcsb_validator_mro(): + """MRO gives DistributedValidator precedence over RCSBValidator.""" + from boltz.distributed.model.validation.validator import DistributedValidator + from boltz.model.validation.validator import Validator + + mro = DistributedRCSBValidator.__mro__ + dv_idx = mro.index(DistributedValidator) + rcsb_idx = mro.index(RCSBValidator) + v_idx = mro.index(Validator) + + assert dv_idx < rcsb_idx < v_idx, f"MRO wrong: DV@{dv_idx}, RCSB@{rcsb_idx}, V@{v_idx}" + + assert DistributedRCSBValidator.common_val_step is DistributedValidator.common_val_step + assert DistributedRCSBValidator.common_on_epoch_end is DistributedValidator.common_on_epoch_end + assert DistributedRCSBValidator._dp_all_reduce_metrics is DistributedValidator._dp_all_reduce_metrics + assert DistributedRCSBValidator.on_epoch_end is RCSBValidator.on_epoch_end + + +def _parallel_epoch_end_logged_metric_names_test(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + confidence_prediction, + physicalism_metrics, + rmsd_metrics, + clash_score_metrics, + ) = payload + + mp = pytest.MonkeyPatch() + if env_per_rank is not None: + for k, v in env_per_rank.items(): + mp.setenv(k, f"{rank}" if v == "" else v) + + DistributedManager.initialize( + grid_group_sizes, + device_type=device_type, + backend=backend, + ) + manager = DistributedManager() + + dv = DistributedRCSBValidator( + val_names=["RCSB"], + confidence_prediction=confidence_prediction, + physicalism_metrics=physicalism_metrics, + rmsd_metrics=rmsd_metrics, + clash_score_metrics=clash_score_metrics, + ) + dv.to(manager.device) + + model = _make_mock_model( + device=str(manager.device), + confidence_prediction=confidence_prediction, + ) + model.dp_group = manager.device_mesh_subgroups.get_group(0) + + logged_names = [] + + def _capture_log(name, value, **kwargs): + logged_names.append(name) + + model.log = _capture_log + dv.on_epoch_end(model=model) + + # --- build expected logged names (inlined) --- + expected = set() + + for m_ in [*const.out_types, "pocket_ligand_protein", "contact_protein_protein"]: + expected.add(f"val/lddt_{m_}") + expected.add(f"val/disto_lddt_{m_}") + expected.add(f"val/complex_lddt_{m_}") + expected.add("val/lddt") + expected.add("val/disto_lddt") + expected.add("val/complex_lddt") + expected.add("val/disto_loss") + if rmsd_metrics: + expected.add("val/rmsd") + if clash_score_metrics: + expected.add("val/clash_atoms_count") + expected.add("val/clash_atoms_fraction") + + if confidence_prediction: + for m in const.out_single_types: + expected.add(f"val/MAE_plddt_{m}") + + out_types_no_modified = [m for m in const.out_types if m != "modified"] + for m_ in out_types_no_modified: + expected.add(f"val/MAE_pde_{m_}") + expected.add(f"val/MAE_pae_{m_}") + + conf_lddt_keys = [ + "top1_lddt", + "iplddt_top1_lddt", + "ipde_top1_lddt", + "pde_top1_lddt", + "ptm_top1_lddt", + "iptm_top1_lddt", + "ligand_iptm_top1_lddt", + "protein_iptm_top1_lddt", + "avg_lddt", + ] + for conf_key in conf_lddt_keys: + for m_ in out_types_no_modified: + expected.add(f"val/{conf_key}_{m_}") + + if physicalism_metrics: + clash_keys = [f"asym_{m_}" for m_ in const.clash_types] + [f"sym_{m_}" for m_ in const.out_single_types] + pb_keys = [ + "bond_length", + "bond_angle", + "internal_clash", + "atom_chirality", + "bond_stereochemistry", + "ring_5_flatness", + "ring_6_flatness", + "double_bond_flatness", + ] + for m in clash_keys: + expected.add(f"val/clash_{m}") + for m in pb_keys: + expected.add(f"val/pb_{m}") + + if confidence_prediction: + prefixes = [ + "top1", + "iplddt_top1", + "pde_top1", + "ipde_top1", + "ptm_top1", + "iptm_top1", + "ligand_iptm_top1", + "protein_iptm_top1", + "avg", + ] + for prefix in prefixes: + for m in clash_keys: + expected.add(f"val/{prefix}_clash_{m}") + for m in pb_keys: + expected.add(f"val/{prefix}_pb_{m}") + # --- end build expected --- + + logged_set = set(logged_names) + + missing = expected - logged_set + extra = logged_set - expected + assert not missing, f"Missing metrics: {sorted(missing)}" + assert not extra, f"Unexpected metrics: {sorted(extra)}" + + DistributedManager.cleanup() + mp.undo() + + +@pytest.mark.parametrize( + "setup_env, confidence_prediction, physicalism_metrics, rmsd_metrics, clash_score_metrics", + [ + (((1, (1, 1)), True, "cuda", "ENV"), False, False, False, False), + (((1, (1, 1)), True, "cuda", "ENV"), True, False, False, False), + (((1, (1, 1)), True, "cuda", "ENV"), False, True, False, False), + (((1, (1, 1)), True, "cuda", "ENV"), False, False, True, False), + (((1, (1, 1)), True, "cuda", "ENV"), False, False, False, True), + (((1, (1, 1)), True, "cuda", "ENV"), True, True, True, True), + ], + indirect=("setup_env",), + ids=["fold-only", "confidence", "physicalism", "rmsd", "clash-score", "all"], +) +def test_epoch_end_logged_metric_names( + setup_env, confidence_prediction, physicalism_metrics, rmsd_metrics, clash_score_metrics +): + """All expected metric names appear in model.log during on_epoch_end.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + skip_if_cuda_not_avail_or_device_count_less_than_word_size( + device_type=device_type, + world_size=world_size, + ) + + payload = ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + confidence_prediction, + physicalism_metrics, + rmsd_metrics, + clash_score_metrics, + ) + spawn_multiprocessing(_parallel_epoch_end_logged_metric_names_test, world_size, payload) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/test_dtensor_boltz2_train.py b/tests/distributed/test_dtensor_boltz2_train.py new file mode 100644 index 000000000..e276138e7 --- /dev/null +++ b/tests/distributed/test_dtensor_boltz2_train.py @@ -0,0 +1,2798 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""End-to-end Boltz-2 distributed training integration test via ``train()``. + +Calls the real ``train()`` entrypoint with real Boltz-2 training data and a +small model config, exercising the full pipeline: config loading → distributed +manager → distributed data module → distributed model wrapping → Trainer.fit +→ checkpoint. + +The only monkeypatches are: +- ``_cleanup_distributed → lambda: None`` (process group safety for tests) + +This test mirrors the pattern in ``test_dtensor_predict.py`` (real data, +real checkpoint) and ``test_dtensor_stop_and_go.py`` (``train()`` entrypoint +monkeypatching). +""" + +import copy +import functools +import importlib.util as _importlib_util +import math +import random as stdlib_random +import shutil +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +import boltz.distributed.model.modules.diffusion as dist_diffusion_module +import boltz.distributed.train as train_module +import boltz.model.loss.diffusionv2 as serial_loss_v2_module +import boltz.model.modules.diffusionv2 as serial_diffusion_v2_module +from boltz.data.module.trainingv2 import ( + Boltz2TrainingDataModule, +) +from boltz.data.module.trainingv2 import ( + TrainingDataset as SerialTrainingDataset, +) +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed +from boltz.distributed.model.modules.diffusion import AtomDiffusion as DistAtomDiffusionV2 +from boltz.distributed.model.modules.utils import SDPAWithBiasBackend, TriAttnBackend +from boltz.distributed.predict import run_predict +from boltz.distributed.testing.utils import setup_mock_training_datamodule_config +from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 +from boltz.model.modules.diffusionv2 import AtomDiffusion as SerialAtomDiffusionV2 +from boltz.model.validation.rcsb import RCSBValidator +from boltz.testing.utils import ( + SetModuleInfValues, + concat_data, + distribute_atom_features, + get_feature_placements, + init_module_params_glorot, + init_tensors_uniform, + seed_by_rank, + spawn_multiprocessing, +) + + +class _Unset(Enum): + """Sentinel to distinguish 'not provided' from ``None``.""" + + UNSET = auto() + + +_UNSET = _Unset.UNSET + + +@dataclass +class TrainTestConfig: + """Unified config for writing YAML training configs in integration tests. + + Defaults match the most common usage (the general distributed training + function with 6 call sites). E2E parity callers override the fields + that differ. + """ + + config_path: Path + output_dir: Path + test_data_dir: Path + mol_dir: Path + + mode: str = "distributed" + + size_dp: int = 1 + size_cp: int = 1 + + accelerator: str = "cpu" + max_epochs: int = 1 + limit_train_batches: int = 2 + precision: str = "FP32" + num_sanity_val_steps: int = 0 + limit_val_batches: int | None = None + gradient_clip_val: float = 10.0 + + model: dict[str, Any] | None = None + pretrained: str | None = None + resume: str | None = None + + ema: bool = True + ema_decay: float = 0.999 + + batch_size: int = 1 + samples_per_epoch: int = 4 + max_tokens: int = 384 + max_atoms: int = 3456 + max_seqs: int = 128 + return_train_symmetries: bool = True + split: str | None = None + overfit: int | None = None + extra_dataset_overrides: dict[str, Any] | None = None + pop_target_keys: bool = False + + validate_structure: bool = False + validation_only: bool = False + seed: int = 42 + + v2: bool = False + strict_loading: bool = True + wandb: dict[str, Any] | None | _Unset = _UNSET + save_top_k: int = -1 + disable_checkpoint: bool = False + + +def _write_train_config(cfg: TrainTestConfig) -> None: + """Write a YAML training config from a :class:`TrainTestConfig`. + + Supports both ``mode="distributed"`` (default) and ``mode="serial"``. + """ + prod_yaml = Path(__file__).resolve().parents[2] / "scripts" / "train" / "configs" / "structurev2.yaml" + data_dict = OmegaConf.to_container(OmegaConf.load(prod_yaml).data, resolve=False) + + if cfg.pop_target_keys: + data_dict.pop("_target_", None) + + data_dict["datasets"] = [data_dict["datasets"][0]] + ds = data_dict["datasets"][0] + if cfg.pop_target_keys: + ds.pop("_target_", None) + + ds["target_dir"] = str(cfg.test_data_dir) + ds["msa_dir"] = str(cfg.test_data_dir / "msa") + ds["template_dir"] = None + ds["split"] = str(cfg.split) if cfg.split else None + ds["prob"] = 1.0 + + if cfg.extra_dataset_overrides: + for k, v in cfg.extra_dataset_overrides.items(): + ds[k] = v + + data_dict["samples_per_epoch"] = cfg.samples_per_epoch + data_dict["num_workers"] = 0 + data_dict["pin_memory"] = False + data_dict["use_templates"] = False + data_dict["return_train_symmetries"] = cfg.return_train_symmetries + data_dict["batch_size"] = cfg.batch_size + data_dict["max_tokens"] = cfg.max_tokens + data_dict["max_atoms"] = cfg.max_atoms + data_dict["max_seqs"] = cfg.max_seqs + data_dict["pad_to_max_tokens"] = True + data_dict["pad_to_max_atoms"] = True + data_dict["pad_to_max_seqs"] = True + data_dict["msa_sampling_training"] = False + data_dict["moldir"] = str(cfg.mol_dir) + + if cfg.overfit is not None: + data_dict["overfit"] = cfg.overfit + + if cfg.model is not None: + model_dict = cfg.model + else: + model_dict = { + "_target_": "boltz.model.models.boltz2.Boltz2", + "atom_s": 4, + "atom_z": 4, + "token_s": 4, + "token_z": 4, + "num_bins": 64, + "atom_feature_dim": 388, + "atoms_per_window_queries": 32, + "atoms_per_window_keys": 128, + "ema": cfg.ema, + "ema_decay": cfg.ema_decay, + "confidence_prediction": False, + "affinity_prediction": False, + "structure_prediction_training": True, + "use_templates": False, + "validate_structure": cfg.validate_structure, + "predict_bfactor": False, + "bond_type_feature": False, + "embedder_args": { + "atom_encoder_depth": 1, + "atom_encoder_heads": 1, + "activation_checkpointing": False, + }, + "msa_args": { + "msa_s": 4, + "msa_blocks": 1, + "msa_dropout": 0.0, + "z_dropout": 0.0, + "use_paired_feature": True, + }, + "pairformer_args": { + "num_blocks": 1, + "num_heads": 1, + "dropout": 0.0, + "v2": True, + }, + "score_model_args": { + "sigma_data": 16.0, + "dim_fourier": 4, + "atom_encoder_depth": 1, + "atom_encoder_heads": 1, + "token_transformer_depth": 1, + "token_transformer_heads": 1, + "atom_decoder_depth": 1, + "atom_decoder_heads": 1, + "activation_checkpointing": False, + "conditioning_transition_layers": 1, + }, + "diffusion_process_args": { + "coordinate_augmentation": False, + }, + "diffusion_loss_args": {}, + "training_args": { + "recycling_steps": 2, + "sampling_steps": 2, + "diffusion_multiplicity": 1, + "diffusion_samples": 1, + "diffusion_loss_weight": 1.0, + "distogram_loss_weight": 0.3, + "confidence_loss_weight": 0.0, + "bfactor_loss_weight": 0.0, + "symmetry_correction": False, + "adam_beta_1": 0.9, + "adam_beta_2": 0.95, + "adam_eps": 1e-8, + "lr_scheduler": "af3", + "base_lr": 1e-3, + "max_lr": 1e-3, + "lr_warmup_no_steps": 10, + "lr_start_decay_after_n_steps": 100, + "lr_decay_every_n_steps": 50000, + "lr_decay_factor": 0.95, + "weight_decay": 0.0, + }, + "validation_args": { + "recycling_steps": 0, + "sampling_steps": 2, + "diffusion_samples": 1, + "symmetry_correction": False, + "clash_cutoff": None, + }, + } + + config: dict[str, Any] = { + "data": data_dict, + "model": model_dict, + "output": str(cfg.output_dir), + "pretrained": cfg.pretrained, + "trainer": { + "accelerator": cfg.accelerator, + "devices": 1, + "max_epochs": cfg.max_epochs, + "limit_train_batches": cfg.limit_train_batches, + "enable_progress_bar": False, + "enable_model_summary": False, + "num_sanity_val_steps": cfg.num_sanity_val_steps, + "gradient_clip_val": cfg.gradient_clip_val, + }, + } + + if cfg.limit_val_batches is not None: + config["trainer"]["limit_val_batches"] = cfg.limit_val_batches + + if cfg.mode == "serial": + config["trainer"]["precision"] = 32 if cfg.precision == "FP32" else cfg.precision + config["v2"] = cfg.v2 + config["disable_checkpoint"] = cfg.disable_checkpoint + config["save_top_k"] = cfg.save_top_k + config["strict_loading"] = cfg.strict_loading + config["validation_only"] = cfg.validation_only + if cfg.wandb is not _UNSET: + config["wandb"] = cfg.wandb + else: + config["resume"] = cfg.resume + config["parallel_size"] = {"size_dp": cfg.size_dp, "size_cp": cfg.size_cp} + config["precision"] = cfg.precision + config["find_unused_parameters"] = False + config["save_top_k"] = cfg.save_top_k + config["disable_checkpoint"] = cfg.disable_checkpoint + config["debug"] = False + config["validation_only"] = cfg.validation_only + config["seed"] = cfg.seed + config["checkpoint"] = { + "monitor": None, + "save_last": True, + "every_n_epochs": 1, + } + config["triattn_backend"] = "reference" + config["sdpa_with_bias_backend"] = "reference" + config["sdpa_with_bias_shardwise_backend"] = "reference" + + cfg.config_path.parent.mkdir(parents=True, exist_ok=True) + OmegaConf.save(OmegaConf.create(config), cfg.config_path) + + +def _parallel_assert_boltz2_train(rank: int, payload: tuple[Any, ...]) -> None: + """Multi-rank worker: call real train() and verify completion.""" + env_per_rank, config_path, output_dir = payload + output_dir = Path(output_dir) + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + # Only monkeypatch cleanup — the model and data factories use their real implementations. + monkeypatch.setattr(train_module, "_cleanup_distributed", lambda: None) + DistributedManager._state = {} + + train_module.train(str(config_path), []) + + # Barrier BEFORE rank-0-only assertions: if rank 0's assertions fail, + # it raises before any trailing barrier, leaving other ranks stuck in + # an NCCL wait. Syncing first avoids this deadlock. + torch.distributed.barrier() + + # Assert: checkpoint was written (rank 0 only — file I/O, no collectives) + ckpt_path = output_dir / "last.ckpt" + if rank == 0: + assert ckpt_path.exists(), f"Rank {rank}: checkpoint not found at {ckpt_path}" + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + assert ckpt.get("global_step", 0) > 0, "global_step is 0 — training did not run" + assert "state_dict" in ckpt, "checkpoint missing state_dict" + assert "optimizer_states" in ckpt and ckpt["optimizer_states"], "checkpoint missing optimizer_states" + + # EMA: if model config has ema=True, checkpoint must contain EMA state + # saved as plain tensors (not DTensors) for checkpoint portability. + ckpt_hp = ckpt.get("hyper_parameters", {}) + if ckpt_hp.get("ema", False): + assert "ema" in ckpt, "checkpoint missing EMA state despite ema=True" + assert "ema_weights" in ckpt["ema"], "EMA state must include ema_weights" + assert "cur_step" in ckpt["ema"], "EMA state must include cur_step" + for ema_key, ema_val in ckpt["ema"]["ema_weights"].items(): + assert isinstance(ema_val, torch.Tensor) and not isinstance( + ema_val, DTensor + ), f"EMA weight '{ema_key}' must be saved as plain torch.Tensor, got {type(ema_val).__name__}" + + +def _worker_validation_parity( + rank: int, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_per_rank: dict[str, Any], + dist_config_path: str, + serial_metrics: dict, + sigmas_global_host: torch.Tensor, + noise_global_host: torch.Tensor, + atom_counts_per_token_host: torch.Tensor, + cached_samples_path: str, + seed: int, +) -> None: + """Multi-rank worker: distributed validate() then compare with serial metrics. + + 1. Applies module-level monkeypatches (noise, data, smooth_lddt) + 2. Calls ``train_module.train(dist_config_path, [])`` in validation_only mode + 3. Captures validation metrics from ``trainer.validate()`` + 4. Compares all validation metrics with the serial reference + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + monkeypatch.setattr(train_module, "_cleanup_distributed", lambda: None) + DistributedManager._state = {} + + _apply_cached_getitem(monkeypatch, cached_samples_path) + + # Deterministic sigmas for noise schedule + def _dist_noise_dist(self, bs, dtype=torch.float32): + s = sigmas_global_host.to(device=self.device_mesh.device_type, dtype=dtype)[:bs] + return distribute_tensor(s, self.device_mesh, (Shard(0), Replicate(), Replicate())) + + monkeypatch.setattr(DistAtomDiffusionV2, "noise_distribution", _dist_noise_dist) + + # Deterministic noise via distribute_atom_features (intersperse padding) + _noise_dt_cache: list[DTensor | None] = [None] + _noise_computed = [False] + + def _compute_noise_dt_once(device_mesh, dtype): + if _noise_computed[0]: + return + _noise_computed[0] = True + manager = DistributedManager() + _io_keys = {"noise"} + _placements = get_feature_placements( + atom_keys=set(), + model_io_keys=_io_keys, + model_io_fp32_keys=set(), + ) + size_batch = atom_counts_per_token_host.shape[0] + multiplicity_val = noise_global_host.shape[0] // size_batch + noise_unflat = noise_global_host.unflatten(0, (size_batch, multiplicity_val)) + inputs_io = {"atom_counts_per_token": atom_counts_per_token_host.clone()} + for i_mul in range(multiplicity_val): + inputs_io[f"noise_{i_mul}"] = noise_unflat[:, i_mul].to(dtype=dtype) + placements_cp_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["cp_model_io"].items() for i_mul in range(multiplicity_val) + } + placements_cp = _placements["cp_atom_features"] | placements_cp_model_io_mul + placements_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["model_io"].items() for i_mul in range(multiplicity_val) + } + placements_dp_cp = placements_model_io_mul + io_feats = distribute_atom_features( + inputs=inputs_io, + placements_cp=placements_cp, + placements_dp_cp=placements_dp_cp, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={"noise": multiplicity_val}, + ) + _noise_dt_cache[0] = io_feats.pop("noise").to(dtype=dtype) + + _dist_in_val = [False] + + def _det_create_randn(shape, device_mesh, placements, dtype=torch.float32, scale=1.0): + if _dist_in_val[0]: + from boltz.distributed.utils import create_distributed_randn as _real_create_randn + + return _real_create_randn(shape, device_mesh, placements, dtype=dtype, scale=0.0) + from boltz.testing.utils import pad_to_length as _pad + + _compute_noise_dt_once(device_mesh, dtype) + n = _noise_dt_cache[0] + if n.dtype != dtype: + n = n.to(dtype=dtype) + if len(shape) > 1 and n.shape[1] < shape[1]: + n = _pad(n, dim=1, length=shape[1]) + return n * scale + + monkeypatch.setattr(dist_diffusion_module, "create_distributed_randn", _det_create_randn) + + _orig_dist_val_step = Boltz2Distributed.validation_step + + def _dist_val_step_wrapper(self_model, batch, batch_idx): + _dist_in_val[0] = True + try: + return _orig_dist_val_step(self_model, batch, batch_idx) + finally: + _dist_in_val[0] = False + + monkeypatch.setattr(Boltz2Distributed, "validation_step", _dist_val_step_wrapper) + + # Skip RMSD (not compared; serial uses a different code path) + import boltz.distributed.model.validation.validator as _dist_validator_mod + + def _rmsd_noop(*args, **kwargs): + return torch.tensor(0.0), None, None + + monkeypatch.setattr(_dist_validator_mod, "weighted_minimum_rmsd_single", _rmsd_noop) + + # Capture metrics from trainer.validate() + _captured_metrics: dict[str, float] = {} + _orig_validate = pl.Trainer.validate + + def _capturing_validate(self, *args, **kwargs): + result = _orig_validate(self, *args, **kwargs) + for k, v in self.callback_metrics.items(): + if isinstance(v, DTensor): + _captured_metrics[k] = v.full_tensor().detach().cpu().item() + elif isinstance(v, torch.Tensor): + _captured_metrics[k] = v.detach().cpu().item() + else: + _captured_metrics[k] = v + return result + + monkeypatch.setattr(pl.Trainer, "validate", _capturing_validate) + + train_module.train(dist_config_path, []) + + # --- Compare metrics --- + # Atom-level LDDT: atol=2e-4, rtol=0.005 (forward-pass accumulation order). + # Token-level metrics (disto_lddt_*, disto_loss): default tolerance. + # Trailing underscore omitted so the global weighted-average "val/lddt" + # (not just "val/lddt_*") also gets the relaxed tolerance. + _forward_dependent_prefixes = ("val/lddt", "val/complex_lddt", "val/clash", "val/pb", "val/rmsd") + _lddt_keys_compared = [] + if serial_metrics: + for k in serial_metrics: + if k in _captured_metrics: + got = torch.tensor(_captured_metrics[k]) + exp = torch.tensor(serial_metrics[k]) + if any(k.startswith(p) for p in _forward_dependent_prefixes): + torch.testing.assert_close( + got, + exp, + atol=2e-4, + rtol=0.005, + msg=lambda m: f"Rank {rank}: metric '{k}' mismatch: {m}", + ) + else: + torch.testing.assert_close( + got, + exp, + msg=lambda m: f"Rank {rank}: metric '{k}' mismatch: {m}", + ) + if "lddt" in k: + _lddt_keys_compared.append(k) + + assert _lddt_keys_compared, ( + f"Rank {rank}: no validation LDDT metrics were compared — test is vacuous. " + f"Serial keys: {sorted(serial_metrics)}, dist keys: {sorted(_captured_metrics)}" + ) + for required_metric in ("val/lddt", "val/disto_lddt", "val/complex_lddt"): + assert ( + required_metric in _captured_metrics + ), f"Rank {rank}: distributed metrics missing '{required_metric}' — available: {sorted(_captured_metrics)}" + + torch.distributed.barrier() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp2x2"], +) +def test_boltz2_validation_parity( + setup_env, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + tmp_path, +): + """Serial-vs-DTensor validation metric parity via ``train()`` validation_only. + + Runs both serial and distributed ``train()`` in validation_only mode + (no training, no backward pass) on the same pretrained model and data, + then compares all logged validation metrics. Model weights come from a + deterministic pretrained checkpoint; noise and data are controlled via + monkeypatches so the only source of difference is the serial-vs-distributed + forward pass in FP32. + + Token-level metrics (``val/disto_lddt_*``, ``val/disto_loss``) match at + default tolerance because they depend only on the confidence head logits, + which are deterministic given identical weights. + + Atom-level LDDT (``val/lddt_*``, ``val/complex_lddt_*``) uses + ``atol=2e-4, rtol=0.005``. These metrics depend on diffusion-sampled + coordinates whose FP32 accumulation order differs between serial + (full-sequence attention) and distributed (split attention + padding for + uniform DP shapes). Without training, these forward-pass-only differences + are smaller and more stable than in the e2e training test, allowing ~2.5x + tighter tolerances than the e2e test's ``atol=5e-4, rtol=0.02``. + Observed max absdiff is ~9e-5 (``lddt_intra_ligand``). + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + dtype = torch.float32 + seed = 42 + multiplicity = 2 + B = 2 + max_tokens = 256 + W = 32 # atoms_per_window_queries + size_cp = grid_group_sizes["cp"][0] * grid_group_sizes["cp"][1] + atom_align = math.lcm(W, size_cp) + max_atoms = ((max_tokens * 10 + atom_align - 1) // atom_align) * atom_align + max_seqs = 16 + scale_glorot = 0.05 + + # --- Merge all 4 samples with train/val split --- + training_data_dir, split_file = _setup_training_data_all_4_e2e( + tmp_path / "training_data", test_cp_training_base_data_dir_boltz2 + ) + + # --- Create pretrained checkpoint with deterministic init --- + seed_by_rank(0, seed=seed) + model_dict = _e2e_model_dict(multiplicity=multiplicity, validate_structure=True) + model_dict.pop("_target_") + model_dict.pop("validators", None) + _val_validators = [RCSBValidator(val_names=["RCSB"], confidence_prediction=False, physicalism_metrics=True)] + pretrained_model = SerialBoltz2(**model_dict, validators=_val_validators) + init_module_params_glorot(pretrained_model, gain=scale_glorot) + pretrained_model.apply(SetModuleInfValues()) + pretrained_model.structure_module.coordinate_augmentation = False + pretrained_model = pretrained_model.to(dtype=dtype) + + pretrained_path = tmp_path / "pretrained.ckpt" + torch.save( + { + "state_dict": pretrained_model.state_dict(), + "pytorch-lightning_version": pl.__version__, + "hyper_parameters": pretrained_model.hparams, + }, + pretrained_path, + ) + + # --- Pre-load and cache individual samples to disk --- + _tmp_mp = pytest.MonkeyPatch() + _apply_e2e_deterministic_getitem(_tmp_mp, base_seed=seed) + _preload_cfg = setup_mock_training_datamodule_config(training_data_dir) + _preload_cfg.batch_size = B + _preload_cfg.samples_per_epoch = B + _preload_cfg.moldir = str(canonical_mols_dir) + _preload_cfg.return_train_symmetries = False + _preload_cfg.msa_sampling_training = False + _preload_cfg.max_tokens = max_tokens + _preload_cfg.max_atoms = max_atoms + _preload_cfg.max_seqs = max_seqs + for _ds in _preload_cfg.datasets: + _ds.filters = None + _ds.split = str(split_file) + _ds.symmetry_correction = False + seed_by_rank(0, seed=seed) + _preload_dm = Boltz2TrainingDataModule(cfg=_preload_cfg) + _preload_ds = _preload_dm._train_set + _cached_samples = {i: _preload_ds[i] for i in range(B)} + cached_samples_path = tmp_path / "cached_samples.pt" + torch.save(_cached_samples, cached_samples_path) + + _preload_dl = _preload_dm.train_dataloader() + _preload_batch = next(iter(_preload_dl)) + atom_counts_per_token_host = _preload_batch["atom_counts_per_token"].detach().cpu() + atom_pad_mask_host = _preload_batch["atom_pad_mask"].detach().cpu() + _tmp_mp.undo() + + # --- Pre-generate deterministic noise (masked by atom_pad_mask) --- + seed_by_rank(0, seed=seed) + sigmas_global = pretrained_model.structure_module.noise_distribution(B * multiplicity).to(dtype=dtype) + noise_global = torch.empty(B * multiplicity, max_atoms, 3, dtype=dtype) + init_tensors_uniform([noise_global], low=-scale_glorot, high=scale_glorot) + _mask_mul = atom_pad_mask_host[:, :, None].repeat_interleave(multiplicity, 0).to(dtype=dtype) + noise_global = noise_global * _mask_mul + + sigmas_global_host = sigmas_global.detach().cpu() + noise_global_host = noise_global.detach().cpu() + + # --- Write serial config (validation_only) --- + serial_output_dir = tmp_path / "serial_output" + serial_output_dir.mkdir(parents=True, exist_ok=True) + serial_config_path = tmp_path / "serial_config.yaml" + _e2e_ds_overrides = { + "filters": None, + "moldir": None, + "symmetry_correction": False, + "val_group": "RCSB", + "use_train_subset": None, + "override_bfactor": False, + "override_method": None, + } + _write_train_config( + TrainTestConfig( + config_path=serial_config_path, + output_dir=serial_output_dir, + test_data_dir=training_data_dir, + mol_dir=canonical_mols_dir, + mode="serial", + accelerator="gpu", + validation_only=True, + pretrained=str(pretrained_path), + model=_e2e_model_dict(multiplicity=multiplicity, validate_structure=True), + batch_size=B, + samples_per_epoch=B, + max_tokens=max_tokens, + max_atoms=max_atoms, + max_seqs=max_seqs, + return_train_symmetries=False, + split=str(split_file), + pop_target_keys=True, + extra_dataset_overrides=_e2e_ds_overrides, + v2=True, + strict_loading=False, + wandb=None, + save_top_k=0, + disable_checkpoint=True, + ) + ) + + # --- Apply serial monkeypatches --- + serial_mp = pytest.MonkeyPatch() + _apply_cached_getitem(serial_mp, cached_samples_path) + + _orig_serial_boltz2_init = SerialBoltz2.__init__ + + @functools.wraps(_orig_serial_boltz2_init) + def _init_with_validators(self, *args, **kwargs): + if kwargs.get("validate_structure", False) and not kwargs.get("validators"): + kwargs["validators"] = [ + RCSBValidator(val_names=["RCSB"], confidence_prediction=False, physicalism_metrics=True) + ] + _orig_serial_boltz2_init(self, *args, **kwargs) + + serial_mp.setattr(SerialBoltz2, "__init__", _init_with_validators) + + serial_mp.setattr( + SerialAtomDiffusionV2, + "noise_distribution", + lambda self, bs: sigmas_global[:bs].to(device=self.zero.device), + ) + + # Deterministic noise: pre-generated for training, zero for validation + _serial_in_val = [False] + + def _serial_randn_like(t): + if _serial_in_val[0]: + return torch.zeros_like(t) + return noise_global[: t.shape[0], : t.shape[1]].to(device=t.device, dtype=t.dtype) + + serial_mp.setattr(serial_diffusion_v2_module.torch, "randn_like", _serial_randn_like) + + _orig_serial_randn = serial_diffusion_v2_module.torch.randn + + def _serial_randn(*args, **kwargs): + if _serial_in_val[0]: + kwargs.pop("generator", None) + return torch.zeros(*args, **kwargs) + return _orig_serial_randn(*args, **kwargs) + + serial_mp.setattr(serial_diffusion_v2_module.torch, "randn", _serial_randn) + + _orig_compute_random_augmentation = serial_diffusion_v2_module.compute_random_augmentation + + def _identity_augmentation_during_val(multiplicity_arg, s_trans=1.0, device=None, dtype=torch.float32): + if _serial_in_val[0]: + R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(multiplicity_arg, -1, -1) + tr = torch.zeros(multiplicity_arg, 1, 3, device=device, dtype=dtype) + return R, tr + return _orig_compute_random_augmentation(multiplicity_arg, s_trans=s_trans, device=device, dtype=dtype) + + serial_mp.setattr(serial_diffusion_v2_module, "compute_random_augmentation", _identity_augmentation_during_val) + + _orig_serial_val_step = SerialBoltz2.validation_step + + def _serial_val_step_wrapper(self_model, batch, batch_idx): + _serial_in_val[0] = True + try: + return _orig_serial_val_step(self_model, batch, batch_idx) + finally: + _serial_in_val[0] = False + + serial_mp.setattr(SerialBoltz2, "validation_step", _serial_val_step_wrapper) + + serial_mp.setattr(serial_loss_v2_module, "smooth_lddt_loss", _smooth_lddt_loss_dense_e2e) + serial_mp.setattr(serial_diffusion_v2_module, "smooth_lddt_loss", _smooth_lddt_loss_dense_e2e) + + # Capture metrics from trainer.validate() + serial_captured_metrics: dict[str, float] = {} + _orig_validate = pl.Trainer.validate + + def _serial_capturing_validate(self, *args, **kwargs): + result = _orig_validate(self, *args, **kwargs) + for k, v in self.callback_metrics.items(): + if isinstance(v, torch.Tensor): + serial_captured_metrics[k] = v.detach().cpu().item() + else: + serial_captured_metrics[k] = v + return result + + serial_mp.setattr(pl.Trainer, "validate", _serial_capturing_validate) + + # --- Run serial validation --- + _serial_train_mod = _load_serial_train_module() + _serial_train_mod.train(str(serial_config_path), []) + + # Non-vacuous guard + serial_lddt_keys = [ + k for k in serial_captured_metrics if k.startswith("val/lddt_") or k.startswith("val/disto_lddt_") + ] + assert serial_lddt_keys, ( + f"Serial run produced no validation LDDT metrics — test is vacuous. " + f"Available metrics: {sorted(serial_captured_metrics)}" + ) + + serial_mp.undo() + + # --- Write distributed config (validation_only) --- + dp = grid_group_sizes["dp"] + cp0, cp1 = grid_group_sizes["cp"] + size_cp = cp0 * cp1 + dist_output_dir = tmp_path / "dist_output" + dist_output_dir.mkdir(parents=True, exist_ok=True) + dist_config_path = tmp_path / "dist_config.yaml" + _write_train_config( + TrainTestConfig( + config_path=dist_config_path, + output_dir=dist_output_dir, + test_data_dir=training_data_dir, + mol_dir=canonical_mols_dir, + size_dp=dp, + size_cp=size_cp, + accelerator="gpu", + validation_only=True, + pretrained=str(pretrained_path), + model=_e2e_model_dict(multiplicity=multiplicity, validate_structure=True, distributed=True), + batch_size=1, + samples_per_epoch=dp, + max_tokens=max_tokens, + max_atoms=max_atoms, + max_seqs=max_seqs, + return_train_symmetries=False, + split=str(split_file), + pop_target_keys=True, + extra_dataset_overrides=_e2e_ds_overrides, + ) + ) + + # --- Spawn distributed workers --- + spawn_multiprocessing( + _worker_validation_parity, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + str(dist_config_path), + serial_captured_metrics, + sigmas_global_host, + noise_global_host, + atom_counts_per_token_host, + str(cached_samples_path), + seed, + ) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + # DP-only: exercises DP correctness on 2-GPU workstations + ((2, (1, 1)), True, "cuda", "ENV"), + # DP + CP: catches integration issues between DP and CP + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp1x1", "cuda-dp2-cp2x2"], +) +def test_boltz2_train_entrypoint( + setup_env, + tmp_path, + test_cp_training_data_dir_boltz2, + canonical_mols_dir, +): + """End-to-end Boltz-2 training through the real train() entrypoint. + + Exercises the full pipeline with a small model and real training data: + config → Hydra instantiate → _create_distributed_model (Boltz2Distributed) + → _create_distributed_data_module (Boltz2TrainingDataModule) → Trainer.fit + → checkpoint. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + output_dir = tmp_path / "boltz2_train_output" + config_path = tmp_path / "boltz2_train_config.yaml" + + _write_train_config( + TrainTestConfig( + config_path=config_path, + output_dir=output_dir, + test_data_dir=test_cp_training_data_dir_boltz2, + mol_dir=canonical_mols_dir, + size_dp=grid_group_sizes["dp"], + size_cp=math.prod(grid_group_sizes["cp"]), + accelerator="gpu" if device_type == "cuda" else "cpu", + ) + ) + + payload = (env_per_rank, str(config_path), str(output_dir)) + spawn_multiprocessing(_parallel_assert_boltz2_train, world_size, payload) + + +def _load_production_model_config(*, reduce_depth: bool = True) -> dict[str, Any]: + """Load model config from the production structurev2.yaml. + + Returns the model dict with _target_ keys suitable for Hydra instantiation. + Overrides training args for fast test execution (1 recycling step, 1 + diffusion sample, etc.) and disables features not yet distributed + (confidence, affinity, templates, validation). + + Parameters + ---------- + reduce_depth + If True (default), reduce pairformer/transformer depth to fit in + 32 GiB GPUs. Set to False on clusters with >=64 GiB GPUs to test + with production-depth layers. + """ + config_yaml = Path(__file__).resolve().parents[2] / "scripts" / "train" / "configs" / "structurev2.yaml" + full_config = OmegaConf.load(config_yaml) + model_dict = OmegaConf.to_container(full_config.model, resolve=False) + + # Disable features not yet distributed + model_dict["confidence_prediction"] = False + model_dict["affinity_prediction"] = False + model_dict["use_templates"] = False + model_dict["validate_structure"] = False + + # Remove validators (not needed for training-only test) + model_dict.pop("validators", None) + + # Remove Hydra interpolations that reference ${data.*} or ${model.*} + # and replace with concrete values + model_dict["conditioning_cutoff_min"] = 4.0 + model_dict["conditioning_cutoff_max"] = 20.0 + + # Fast training settings + model_dict["training_args"] = { + "recycling_steps": 1, + "sampling_steps": 2, + "diffusion_multiplicity": 1, + "diffusion_samples": 1, + "diffusion_loss_weight": 1.0, + "distogram_loss_weight": 0.3, + "confidence_loss_weight": 0.0, + "bfactor_loss_weight": 0.0, + "symmetry_correction": False, + "adam_beta_1": 0.9, + "adam_beta_2": 0.95, + "adam_eps": 1e-8, + "lr_scheduler": "af3", + "base_lr": 1e-3, + "max_lr": 1e-3, + "lr_warmup_no_steps": 10, + "lr_start_decay_after_n_steps": 100, + "lr_decay_every_n_steps": 50000, + "lr_decay_factor": 0.95, + "weight_decay": 0.0, + } + model_dict["validation_args"] = { + "recycling_steps": 0, + "sampling_steps": 2, + "diffusion_samples": 1, + "symmetry_correction": False, + } + + if reduce_depth: + # Reduce model depth to fit in 32 GiB GPU memory under DP=2. + # The checkpoint is loaded with strict=False so missing blocks are + # fine — we still exercise the pretrained loading path and verify + # the matching layers load correctly. + model_dict["pairformer_args"]["num_blocks"] = 2 + model_dict["score_model_args"]["token_transformer_depth"] = 2 + model_dict["msa_args"]["msa_blocks"] = 1 + + model_dict["diffusion_process_args"] = model_dict.get("diffusion_process_args", {}) + model_dict["diffusion_process_args"]["coordinate_augmentation"] = False + + if reduce_depth: + # With reduced depth the model fits without activation checkpointing, + # and disabling it makes the test faster. + for key in ("embedder_args", "msa_args", "pairformer_args", "score_model_args"): + if key in model_dict and isinstance(model_dict[key], dict): + model_dict[key]["activation_checkpointing"] = False + if "template_args" in model_dict: + model_dict["template_args"]["activation_checkpointing"] = False + + return model_dict + + +@pytest.mark.slow +@pytest.mark.parametrize( + ("setup_env", "reduce_depth"), + [ + (((2, (1, 1)), True, "cuda", "ENV"), True), + # Full-depth uses CP to distribute activations across 4 GPUs, + # mirroring production topology and avoiding per-GPU OOM. + (((1, (2, 2)), True, "cuda", "ENV"), False), + ], + indirect=["setup_env"], + ids=["cuda-dp2-cp1x1-reduced", "cuda-dp1-cp2x2-full"], +) +def test_boltz2_finetune_from_checkpoint( + setup_env, + reduce_depth, + tmp_path, + test_cp_training_data_dir_boltz2, + canonical_mols_dir, + get_model_ckpt_v2, +): + """End-to-end Boltz-2 finetune from real checkpoint through train(). + + Loads the real Boltz-2 checkpoint via ``pretrained`` config, exercises + production-width model layers with real training data under BF16 mixed + precision. The ``reduce_depth=True`` variant uses dp=2 with reduced + pairformer/transformer depth (fast, 2 GPUs). The ``reduce_depth=False`` + variant uses dp=1, cp=2x2 to shard the full-depth model across 4 GPUs, + mirroring how production deployments use CP for large models. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + output_dir = tmp_path / "boltz2_finetune_output" + config_path = tmp_path / "boltz2_finetune_config.yaml" + + _write_train_config( + TrainTestConfig( + config_path=config_path, + output_dir=output_dir, + test_data_dir=test_cp_training_data_dir_boltz2, + mol_dir=canonical_mols_dir, + size_dp=grid_group_sizes["dp"], + size_cp=math.prod(grid_group_sizes["cp"]), + accelerator="gpu", + limit_train_batches=1, + pretrained=str(get_model_ckpt_v2), + model=_load_production_model_config(reduce_depth=reduce_depth), + precision="BF16_MIXED", + ) + ) + + payload = (env_per_rank, str(config_path), str(output_dir)) + spawn_multiprocessing(_parallel_assert_boltz2_train, world_size, payload) + + +# --------------------------------------------------------------------------- +# Stop-and-go checkpoint resume test for Boltz-2 +# --------------------------------------------------------------------------- + + +def _parallel_assert_boltz2_stop_and_go(rank: int, payload: tuple[Any, ...]) -> None: + """Verify checkpoint resume correctness through the real ``train()`` entrypoint. + + Runs two ``train()`` calls: + 1. Stop/go stage 1 — 1 epoch, checkpoint produced + 2. Stop/go stage 2 — resume from checkpoint, train to epoch 2 + + Verifies: + - Stage 1 checkpoint contains valid model state and optimizer state + - Stage 2 successfully resumes from checkpoint (no errors) + - Final checkpoint has correct epoch/step counters (epoch 2, step > stage 1) + - Final checkpoint has the same state_dict keys as stage 1 + - Weights changed between stage 1 and final (training actually happened) + - Optimizer state is populated with correct structure + + Note: Exact weight parity between continuous and stop/go runs is not + tested because ``train.py`` uses different seed offsets on resume + (``seed + rank + epoch*1000 + step``), which changes diffusion noise. + The distogram stop-and-go test (``test_dtensor_stop_and_go.py``) covers + exact parity because its model has no stochastic diffusion. + """ + ( + env_per_rank, + stage1_config_path, + stage2_config_path, + output_dir, + ) = payload + output_dir = Path(output_dir) + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + # Only suppress cleanup — use the real model and data factories. + monkeypatch.setattr(train_module, "_cleanup_distributed", lambda: None) + DistributedManager._state = {} + + # ---- Stage 1: 1 epoch, checkpoint produced. ---- + train_module.train(str(stage1_config_path), []) + ckpt_path = output_dir / "last.ckpt" + assert ckpt_path.exists(), f"Rank {rank}: stage 1 checkpoint not found at {ckpt_path}" + + stage1_ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + stage1_epoch = stage1_ckpt["epoch"] + stage1_step = stage1_ckpt["global_step"] + stage1_keys = set(stage1_ckpt["state_dict"].keys()) + assert stage1_step > 0, f"Rank {rank}: stage 1 global_step is 0 — training did not run" + assert "state_dict" in stage1_ckpt, f"Rank {rank}: stage 1 missing state_dict" + assert ( + "optimizer_states" in stage1_ckpt and stage1_ckpt["optimizer_states"] + ), f"Rank {rank}: stage 1 missing optimizer_states" + + # Save stage 1 weights for change detection + stage1_weights = {k: v.clone() for k, v in stage1_ckpt["state_dict"].items()} + + # ---- Stage 2: resume from checkpoint to epoch 2. ---- + train_module.train(str(stage2_config_path), []) + + # ---- Verify final checkpoint. ---- + final_ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + final_epoch = final_ckpt["epoch"] + final_step = final_ckpt["global_step"] + final_keys = set(final_ckpt["state_dict"].keys()) + + # 1) Epoch and step advanced beyond stage 1. + assert ( + final_epoch > stage1_epoch + ), f"Rank {rank}: Final epoch ({final_epoch}) should be > stage 1 epoch ({stage1_epoch})" + assert final_step > stage1_step, f"Rank {rank}: Final step ({final_step}) should be > stage 1 step ({stage1_step})" + + # 2) State dict keys match (model architecture is preserved). + assert final_keys == stage1_keys, ( + f"Rank {rank}: state_dict key mismatch between stage 1 and final. " + f"Extra: {final_keys - stage1_keys}, Missing: {stage1_keys - final_keys}" + ) + + # 3) Weights changed (training actually happened in stage 2). + weights_differ = any(not torch.equal(final_ckpt["state_dict"][k], stage1_weights[k]) for k in stage1_keys) + assert ( + weights_differ + ), f"Rank {rank}: No weights changed between stage 1 and final — stage 2 training may not have run" + + # 4) Optimizer state is valid. + assert ( + "optimizer_states" in final_ckpt and final_ckpt["optimizer_states"] + ), f"Rank {rank}: final checkpoint missing optimizer_states" + opt_state = final_ckpt["optimizer_states"][0]["state"] + assert len(opt_state) > 0, f"Rank {rank}: optimizer state is empty" + + # 5) Optimizer state keys are FQN strings (not legacy integers). + opt_state_keys = list(opt_state.keys()) + assert all( + isinstance(k, str) for k in opt_state_keys + ), f"Rank {rank}: optimizer state keys should be FQN strings, got {[type(k).__name__ for k in opt_state_keys[:3]]}" + + # 6) EMA state is preserved across resume. + if "ema" in stage1_ckpt: + assert "ema_weights" in stage1_ckpt["ema"], "Stage 1 EMA must include ema_weights" + assert "cur_step" in stage1_ckpt["ema"], "Stage 1 EMA must include cur_step" + for ema_key, ema_val in stage1_ckpt["ema"]["ema_weights"].items(): + assert isinstance(ema_val, torch.Tensor) and not isinstance( + ema_val, DTensor + ), f"Stage 1 EMA weight '{ema_key}' must be plain torch.Tensor, got {type(ema_val).__name__}" + + assert "ema" in final_ckpt, f"Rank {rank}: final checkpoint missing EMA state after resume" + assert final_ckpt["ema"]["cur_step"] > stage1_ckpt["ema"]["cur_step"], ( + f"Rank {rank}: EMA cur_step did not advance " + f"(stage1={stage1_ckpt['ema']['cur_step']}, final={final_ckpt['ema']['cur_step']})" + ) + assert set(final_ckpt["ema"]["ema_weights"].keys()) == set( + stage1_ckpt["ema"]["ema_weights"].keys() + ), f"Rank {rank}: EMA weight keys changed between stage 1 and final" + + torch.distributed.barrier() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + # DP-only: checkpoint resume correctness on 2-GPU systems + ((2, (1, 1)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp1x1"], +) +def test_boltz2_stop_and_go( + setup_env, + tmp_path, + test_cp_training_data_dir_boltz2, + canonical_mols_dir, +): + """Stop-and-go checkpoint resume correctness for Boltz-2 via ``train()``. + + Verifies that training 1 epoch, checkpointing, then resuming to epoch 2 + produces a valid final state: correct epoch/step counters, preserved + model architecture (state_dict keys), weights that changed (training + happened), and valid optimizer state. + + Note: exact weight parity with a continuous 2-epoch run is not tested + because ``train.py`` uses different seed offsets on resume. + + Uses the real ``train()`` entrypoint with real Boltz-2 training data + and a small model config. Only ``_cleanup_distributed`` is + monkeypatched (for process group safety in test harness). + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + size_dp = grid_group_sizes["dp"] + size_cp = math.prod(grid_group_sizes["cp"]) + accelerator = "gpu" if device_type == "cuda" else "cpu" + + stopgo_dir = tmp_path / "stopgo" + + common_kwargs: dict[str, Any] = { + "test_data_dir": test_cp_training_data_dir_boltz2, + "mol_dir": canonical_mols_dir, + "size_dp": size_dp, + "size_cp": size_cp, + "accelerator": accelerator, + } + + # Stage 1: 1 epoch with checkpoint. + stage1_config = stopgo_dir / "config_stage1.yaml" + _write_train_config( + TrainTestConfig( + config_path=stage1_config, + output_dir=stopgo_dir, + **common_kwargs, + ) + ) + + # Stage 2: resume from checkpoint, train to epoch 2. + stage2_config = stopgo_dir / "config_stage2.yaml" + _write_train_config( + TrainTestConfig( + config_path=stage2_config, + output_dir=stopgo_dir, + max_epochs=2, + resume=str(stopgo_dir / "last.ckpt"), + **common_kwargs, + ) + ) + + payload = ( + env_per_rank, + str(stage1_config), + str(stage2_config), + str(stopgo_dir), + ) + spawn_multiprocessing(_parallel_assert_boltz2_stop_and_go, world_size, payload) + + +# --------------------------------------------------------------------------- +# Train → checkpoint → distributed inference pipeline test +# --------------------------------------------------------------------------- + + +def _parallel_assert_train_to_inference( + rank: int, + env_per_rank: dict[str, Any], + checkpoint_path: str, + data_dir: str, + mol_dir: str, + out_dir: str, + size_dp: int, + size_cp: int, + accelerator: str, +) -> None: + """Worker: run distributed inference using a checkpoint from training.""" + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + run_predict( + data=data_dir, + out_dir=out_dir, + mol_dir=mol_dir, + checkpoint=checkpoint_path, + size_dp=size_dp, + size_cp=size_cp, + accelerator=accelerator, + recycling_steps=0, + sampling_steps=2, + diffusion_samples=1, + seed=42, + input_format="preprocessed", + use_templates=False, + confidence_prediction=False, + triattn_backend=TriAttnBackend.REFERENCE, + sdpa_with_bias_backend=SDPAWithBiasBackend.REFERENCE, + sdpa_with_bias_shardwise_backend=SDPAWithBiasBackend.REFERENCE, + ) + + out_path = Path(out_dir) + data_stem = Path(data_dir).stem + results_dir = out_path / f"boltz_results_{data_stem}" + + rank_cp = rank % size_cp + if rank_cp == 0: + assert results_dir.exists(), f"Rank {rank}: results dir {results_dir} not found" + cif_files = list(results_dir.rglob("*.cif")) + assert len(cif_files) > 0, f"Rank {rank}: no CIF output files in {results_dir}" + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (1, 1)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp1x1"], +) +def test_boltz2_train_checkpoint_to_inference( + setup_env, + tmp_path, + test_cp_training_data_dir_boltz2, + canonical_mols_dir, +): + """Train a small Boltz-2 model, then run inference with the saved checkpoint. + + Verifies the complete training-to-inference pipeline: + 1. Trains with real data for 1 epoch → checkpoint + 2. Loads the checkpoint via ``run_predict`` for distributed inference + 3. Verifies that CIF output files are produced + + This ensures checkpoints saved by ``train()`` (via ``BoltzContextParallelStrategy``) + are compatible with the ``run_predict`` inference pipeline. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) + if gpu_mem_gb < 100: + pytest.skip( + f"GPU has {gpu_mem_gb:.0f}GB memory; inference with CP=1 on real-data " + "structures requires >80GB (outer_product_mean pair tensor OOM)" + ) + + size_dp = grid_group_sizes["dp"] + size_cp = math.prod(grid_group_sizes["cp"]) + + # ---- Step 1: Train and produce checkpoint ---- + training_output_dir = tmp_path / "training" + config_path = tmp_path / "train_config.yaml" + + _write_train_config( + TrainTestConfig( + config_path=config_path, + output_dir=training_output_dir, + test_data_dir=test_cp_training_data_dir_boltz2, + mol_dir=canonical_mols_dir, + size_dp=size_dp, + size_cp=size_cp, + accelerator="gpu", + ) + ) + + train_payload = (env_per_rank, str(config_path), str(training_output_dir)) + spawn_multiprocessing(_parallel_assert_boltz2_train, world_size, train_payload) + + checkpoint_path = training_output_dir / "last.ckpt" + assert checkpoint_path.exists(), f"Training checkpoint not found at {checkpoint_path}" + + # ---- Step 2: Run inference with the trained checkpoint ---- + inference_output_dir = tmp_path / "inference" + + spawn_multiprocessing( + _parallel_assert_train_to_inference, + world_size, + env_per_rank, + str(checkpoint_path), + str(test_cp_training_data_dir_boltz2), + str(canonical_mols_dir), + str(inference_output_dir), + size_dp, + size_cp, + "gpu", + ) + + +# --------------------------------------------------------------------------- +# E2E serial-vs-DTensor training parity test +# --------------------------------------------------------------------------- + + +def _load_serial_train_module(): + """Lazily import serial train.py (a script, not a package module). + + Deferred to call time to avoid executing the module during pytest + collection, which can pollute global state and cause failures when + many tests are collected. + """ + path = Path(__file__).resolve().parents[2] / "scripts" / "train" / "train.py" + spec = _importlib_util.spec_from_file_location("_serial_train", str(path)) + mod = _importlib_util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def _smooth_lddt_loss_dense_e2e( + pred_coords, + true_coords, + is_nucleotide, + coords_mask=None, + nucleic_acid_cutoff=30.0, + other_cutoff=15.0, + multiplicity=1, + **kwargs, +): + """Dense pairwise distance smooth_lddt for serial-distributed backward parity. + + Aligns the backward autograd graph between serial (sparse) and distributed + (dense CDIST) computation paths. + """ + compute_dtype = torch.promote_types(pred_coords.dtype, torch.float32) + N = pred_coords.shape[1] + lddt = [] + for i in range(true_coords.shape[0]): + true_dists = torch.cdist(true_coords[i], true_coords[i]) + is_nuc_i = is_nucleotide[i // multiplicity] + mask_i = coords_mask[i // multiplicity] + is_nuc_pair = is_nuc_i.unsqueeze(-1).expand(-1, is_nuc_i.shape[-1]) + mask = is_nuc_pair * (true_dists < nucleic_acid_cutoff).to(compute_dtype) + mask += (1 - is_nuc_pair) * (true_dists < other_cutoff).to(compute_dtype) + mask *= 1 - torch.eye(N, device=pred_coords.device) + mask *= mask_i.unsqueeze(-1) + mask *= mask_i.unsqueeze(-2) + diff = pred_coords[i].unsqueeze(0) - pred_coords[i].unsqueeze(1) + pred_dists = (diff * diff).sum(-1).add(1e-30).sqrt() + dist_diff = (true_dists - pred_dists).abs() + eps = ( + torch.sigmoid(0.5 - dist_diff) + + torch.sigmoid(1.0 - dist_diff) + + torch.sigmoid(2.0 - dist_diff) + + torch.sigmoid(4.0 - dist_diff) + ) * 0.25 + lddt_i = (eps * mask).sum() / (mask.sum() + 1e-5) + lddt.append(lddt_i) + return 1 - sum(lddt) / len(lddt) + + +def _setup_training_data_7z64_8b2e_e2e(out_dir: Path, base_data_dir: Path) -> Path: + """Merge 7z64 and 8b2e processed data into a single training directory.""" + names = ["7z64", "8b2e"] + source_dirs = [base_data_dir / f"processed_{name}" for name in names] + merged = concat_data(out_dir, *source_dirs) + records_dir = merged / "records" + records_dir.mkdir(parents=True, exist_ok=True) + copied: set[str] = set() + for src in source_dirs: + for rf in (src / "records").glob("*.json"): + if rf.name in copied: + raise ValueError(f"Duplicate record file {rf.name}") + shutil.copy(rf, records_dir / rf.name) + copied.add(rf.name) + return merged + + +def _setup_training_data_all_4_e2e(out_dir: Path, base_data_dir: Path) -> tuple[Path, Path]: + """Merge all 4 samples with a train/val split. + + Training: 7ylz, 8b2e. Validation: 7z64, 8ayv. + Returns ``(merged_data_dir, val_split_file_path)``. + """ + names = ["7ylz", "7z64", "8ayv", "8b2e"] + source_dirs = [base_data_dir / f"processed_{name}" for name in names] + merged = concat_data(out_dir, *source_dirs) + records_dir = merged / "records" + records_dir.mkdir(parents=True, exist_ok=True) + copied: set[str] = set() + for src in source_dirs: + for rf in (src / "records").glob("*.json"): + if rf.name in copied: + raise ValueError(f"Duplicate record file {rf.name}") + shutil.copy(rf, records_dir / rf.name) + copied.add(rf.name) + split_file = out_dir / "val_split.txt" + split_file.write_text("7z64\n8ayv\n") + return merged, split_file + + +def _e2e_model_dict( + *, + ema: bool = True, + ema_decay: float = 0.999, + multiplicity: int = 2, + validate_structure: bool = False, + distributed: bool = False, +) -> dict[str, Any]: + """Small model config dict for the E2E training parity test. + + Matches ``create_boltz2_model_init_params(use_large_model=False)`` with + recycling disabled, coordinate augmentation off, and configurable + ``validate_structure``. + + When ``distributed=True``, validators use + :class:`DistributedRCSBValidator` (the DTensor-aware variant required + by ``Boltz2Distributed``). + """ + d: dict[str, Any] = { + "_target_": "boltz.model.models.boltz2.Boltz2", + "atom_s": 4, + "atom_z": 4, + "token_s": 4, + "token_z": 4, + "num_bins": 64, + "atom_feature_dim": 388, + "atoms_per_window_queries": 32, + "atoms_per_window_keys": 128, + "ema": ema, + "ema_decay": ema_decay, + "confidence_prediction": False, + "affinity_prediction": False, + "structure_prediction_training": True, + "use_templates": False, + "validate_structure": validate_structure, + "predict_bfactor": False, + "bond_type_feature": False, + "no_random_recycling_training": True, + "embedder_args": { + "atom_encoder_depth": 1, + "atom_encoder_heads": 1, + "activation_checkpointing": False, + }, + "msa_args": { + "msa_s": 4, + "msa_blocks": 1, + "msa_dropout": 0.0, + "z_dropout": 0.0, + "use_paired_feature": True, + }, + "pairformer_args": { + "num_blocks": 1, + "num_heads": 1, + "dropout": 0.0, + "v2": True, + }, + "score_model_args": { + "sigma_data": 16.0, + "dim_fourier": 4, + "atom_encoder_depth": 1, + "atom_encoder_heads": 1, + "token_transformer_depth": 1, + "token_transformer_heads": 1, + "atom_decoder_depth": 1, + "atom_decoder_heads": 1, + "activation_checkpointing": False, + "conditioning_transition_layers": 1, + }, + "diffusion_process_args": { + "coordinate_augmentation": False, + }, + "diffusion_loss_args": {}, + "training_args": { + "recycling_steps": 0, + "sampling_steps": -1, + "diffusion_multiplicity": multiplicity, + "diffusion_samples": -1, + "diffusion_loss_weight": 1.0, + "distogram_loss_weight": 0.3, + "confidence_loss_weight": 0.0, + "bfactor_loss_weight": 0.0, + "symmetry_correction": False, + "adam_beta_1": 0.9, + "adam_beta_2": 0.95, + "adam_eps": 1e-8, + "lr_scheduler": "af3", + "base_lr": 1e-4, + "max_lr": 1e-4, + "lr_warmup_no_steps": 10, + "lr_start_decay_after_n_steps": 100, + "lr_decay_every_n_steps": 50000, + "lr_decay_factor": 0.95, + "weight_decay": 0.0, + }, + "validation_args": { + "recycling_steps": 0, + "sampling_steps": 2, + "diffusion_samples": 1, + "symmetry_correction": False, + }, + } + if validate_structure: + d["num_val_datasets"] = 1 + _validator_target = ( + "boltz.distributed.model.validation.rcsb.DistributedRCSBValidator" + if distributed + else "boltz.model.validation.rcsb.RCSBValidator" + ) + d["validators"] = [ + { + "_target_": _validator_target, + "val_names": ["RCSB"], + "confidence_prediction": False, + "physicalism_metrics": True, + } + ] + return d + + +def _apply_e2e_deterministic_getitem(monkeypatch, base_seed: int = 42) -> None: + """Patch ``TrainingDataset.__getitem__`` at the class level for deterministic data.""" + original_getitem = SerialTrainingDataset.__getitem__ + + _getitem_call_count = [0] + + def _wrapped_getitem(self, idx): + _getitem_call_count[0] += 1 + np.random.seed(base_seed + idx) + torch.manual_seed(base_seed + idx) + stdlib_random.seed(base_seed + idx) + _original_np_choice = np.random.choice + _call_count = [0] + _num_samples = len(self.samples[0]) + + def _deterministic_choice(a, p=None, **kwargs): + _call_count[0] += 1 + result = _original_np_choice(a, p=p, **kwargs) + if _call_count[0] == 1: + return 0 + elif _call_count[0] == 2: + return idx % _num_samples + return result + + np.random.choice = _deterministic_choice + try: + return original_getitem(self, idx) + finally: + np.random.choice = _original_np_choice + + monkeypatch.setattr(SerialTrainingDataset, "__getitem__", _wrapped_getitem) + + +def _apply_cached_getitem(monkeypatch, cache_path: str | Path) -> None: + """Replace ``TrainingDataset.__getitem__`` with a disk-backed cache lookup. + + The cache file (created during pre-load) maps integer sample indices to + feature dicts, guaranteeing identical features across serial and + distributed data pipelines regardless of data-processing RNG state. + """ + _cache: dict[int, dict] = {} + + def _cached_getitem(self, idx): + if not _cache: + _cache.update(torch.load(str(cache_path), map_location="cpu", weights_only=False)) + return _cache[idx] + + monkeypatch.setattr(SerialTrainingDataset, "__getitem__", _cached_getitem) + + +def _worker_e2e_training_parity( + rank: int, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_per_rank: dict[str, Any], + dist_config_path: str, + dist_output_dir: str, + serial_ckpt_path: str, + pretrained_ckpt_path: str, + serial_metrics: dict, + sigmas_global_host: torch.Tensor, + noise_global_host: torch.Tensor, + atom_counts_per_token_host: torch.Tensor, + cached_samples_path: str, + seed: int, +) -> None: + """Multi-rank worker: distributed train() then compare with serial checkpoint. + + 1. Applies module-level monkeypatches (noise, data, smooth_lddt, DoublePrecision) + 2. Calls ``train_module.train(dist_config_path, [])`` + 3. Loads both serial and distributed checkpoints + 4. Compares state_dict, EMA weights, and logged metrics + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + monkeypatch.setattr(train_module, "_cleanup_distributed", lambda: None) + DistributedManager._state = {} + + # --- Deterministic data loading via cached samples --- + _apply_cached_getitem(monkeypatch, cached_samples_path) + + # --- Deterministic noise (distributed) --- + # Noise must go through distribute_atom_features (intersperse padding) so + # that noise[i] at position i in the distributed tensor corresponds to the + # same atom as coords[i]. Plain distribute_tensor would keep serial ordering + # while coords use the intersperse-padded ordering, causing a mismatch. + def _dist_noise_dist(self, bs, dtype=torch.float32): + s = sigmas_global_host.to(device=self.device_mesh.device_type, dtype=dtype)[:bs] + return distribute_tensor(s, self.device_mesh, (Shard(0), Replicate(), Replicate())) + + monkeypatch.setattr(DistAtomDiffusionV2, "noise_distribution", _dist_noise_dist) + + _noise_dt_cache: list[DTensor | None] = [None] + _noise_computed = [False] + + def _compute_noise_dt_once(device_mesh, dtype): + """Compute noise DTensor via distribute_atom_features (intersperse padding).""" + if _noise_computed[0]: + return + _noise_computed[0] = True + + manager = DistributedManager() + _io_keys = {"noise"} + _placements = get_feature_placements( + atom_keys=set(), + model_io_keys=_io_keys, + model_io_fp32_keys=set(), + ) + + size_batch = atom_counts_per_token_host.shape[0] + multiplicity_val = noise_global_host.shape[0] // size_batch + noise_unflat = noise_global_host.unflatten(0, (size_batch, multiplicity_val)) + + inputs_io = {"atom_counts_per_token": atom_counts_per_token_host.clone()} + for i_mul in range(multiplicity_val): + inputs_io[f"noise_{i_mul}"] = noise_unflat[:, i_mul].to(dtype=dtype) + + placements_cp_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["cp_model_io"].items() for i_mul in range(multiplicity_val) + } + placements_cp = _placements["cp_atom_features"] | placements_cp_model_io_mul + placements_model_io_mul = { + f"{k}_{i_mul}": v for k, v in _placements["model_io"].items() for i_mul in range(multiplicity_val) + } + placements_dp_cp = placements_model_io_mul + + io_feats = distribute_atom_features( + inputs=inputs_io, + placements_cp=placements_cp, + placements_dp_cp=placements_dp_cp, + device_mesh=manager.device_mesh_subgroups, + cp_group=manager.group["cp"], + multiplicities={"noise": multiplicity_val}, + ) + _noise_dt_cache[0] = io_feats.pop("noise").to(dtype=dtype) + + _dist_in_val = [False] + + def _det_create_randn(shape, device_mesh, placements, dtype=torch.float32, scale=1.0): + if _dist_in_val[0]: + from boltz.distributed.utils import create_distributed_randn as _real_create_randn + + return _real_create_randn(shape, device_mesh, placements, dtype=dtype, scale=0.0) + from boltz.testing.utils import pad_to_length as _pad + + _compute_noise_dt_once(device_mesh, dtype) + n = _noise_dt_cache[0] + if n.dtype != dtype: + n = n.to(dtype=dtype) + if len(shape) > 1 and n.shape[1] < shape[1]: + n = _pad(n, dim=1, length=shape[1]) + return n * scale + + monkeypatch.setattr(dist_diffusion_module, "create_distributed_randn", _det_create_randn) + + _orig_dist_val_step = Boltz2Distributed.validation_step + + def _dist_val_step_wrapper(self_model, batch, batch_idx): + _dist_in_val[0] = True + try: + return _orig_dist_val_step(self_model, batch, batch_idx) + finally: + _dist_in_val[0] = False + + monkeypatch.setattr(Boltz2Distributed, "validation_step", _dist_val_step_wrapper) + + # --- Skip RMSD in distributed validation (not needed for LDDT parity) --- + import boltz.distributed.model.validation.validator as _dist_validator_mod + + def _rmsd_noop(*args, **kwargs): + return torch.tensor(0.0), None, None + + monkeypatch.setattr(_dist_validator_mod, "weighted_minimum_rmsd_single", _rmsd_noop) + + # --- Capture trainer metrics --- + _captured_metrics: dict[str, float] = {} + _orig_fit = pl.Trainer.fit + + def _capturing_fit(self, *args, **kwargs): + result = _orig_fit(self, *args, **kwargs) + for k, v in self.callback_metrics.items(): + if isinstance(v, DTensor): + _captured_metrics[k] = v.full_tensor().detach().cpu().item() + elif isinstance(v, torch.Tensor): + _captured_metrics[k] = v.detach().cpu().item() + else: + _captured_metrics[k] = v + return result + + monkeypatch.setattr(pl.Trainer, "fit", _capturing_fit) + + # --- Run distributed training --- + train_module.train(dist_config_path, []) + + # --- Load checkpoints and compare --- + dist_ckpt_path = Path(dist_output_dir) / "last.ckpt" + assert dist_ckpt_path.exists(), f"Rank {rank}: distributed checkpoint not found at {dist_ckpt_path}" + dist_ckpt = torch.load(dist_ckpt_path, map_location="cpu", weights_only=False) + serial_ckpt = torch.load(serial_ckpt_path, map_location="cpu", weights_only=False) + + dist_sd = dist_ckpt["state_dict"] + serial_sd = serial_ckpt["state_dict"] + assert len(dist_sd) > 0, f"Rank {rank}: distributed state_dict is empty" + assert len(serial_sd) > 0, f"Rank {rank}: serial state_dict is empty" + + # The distributed model prefixes keys with ``_serial.``; strip it for comparison. + dist_sd_mapped = {} + for k, v in dist_sd.items(): + canonical_k = k.replace("_serial.", "", 1) if k.startswith("_serial.") else k + dist_sd_mapped[canonical_k] = v + + for k in serial_sd: + assert k in dist_sd_mapped, f"Rank {rank}: key '{k}' missing from distributed checkpoint" + torch.testing.assert_close( + dist_sd_mapped[k], + serial_sd[k], + msg=lambda m: f"Rank {rank}: state_dict mismatch on '{k}': {m}", + ) + + # --- EMA weight parity --- + assert "ema" in dist_ckpt, f"Rank {rank}: distributed checkpoint missing EMA state" + assert "ema" in serial_ckpt, f"Rank {rank}: serial checkpoint missing EMA state" + dist_ema = dist_ckpt["ema"]["ema_weights"] + serial_ema = serial_ckpt["ema"]["ema_weights"] + assert dist_ckpt["ema"]["cur_step"] == serial_ckpt["ema"]["cur_step"], ( + f"Rank {rank}: EMA cur_step mismatch: " + f"dist={dist_ckpt['ema']['cur_step']}, serial={serial_ckpt['ema']['cur_step']}" + ) + + dist_ema_mapped = {} + for k, v in dist_ema.items(): + canonical_k = k.replace("_serial.", "", 1) if k.startswith("_serial.") else k + dist_ema_mapped[canonical_k] = v + + for k in serial_ema: + assert k in dist_ema_mapped, f"Rank {rank}: EMA key '{k}' missing from distributed checkpoint" + torch.testing.assert_close( + dist_ema_mapped[k], + serial_ema[k], + msg=lambda m: f"Rank {rank}: EMA weight mismatch on '{k}': {m}", + ) + + # --- Non-vacuous guard: at least one parameter changed from pretrained init --- + pretrained_sd = torch.load( + pretrained_ckpt_path, + map_location="cpu", + weights_only=False, + ).get("state_dict", {}) + if pretrained_sd: + changed = any(not torch.equal(serial_sd[k], pretrained_sd[k]) for k in serial_sd if k in pretrained_sd) + assert changed, f"Rank {rank}: no parameters changed from pretrained init — test is vacuous" + + # --- Metric parity --- + # Atom-level LDDT metrics (val/lddt_*, val/complex_lddt_*) and the global + # weighted-average val/lddt depend on diffusion-sampled coordinates, which + # differ slightly between serial and distributed forward passes in FP32 due + # to accumulation order in parallel attention. Their exact parity is + # verified separately by test_boltz2_validation_step_parity (FP64). Here + # we use relaxed tolerance (atol=5e-4) for these forward-pass-dependent + # metrics and default tolerance for everything else. Trailing underscore + # omitted so "val/lddt" (global) is also matched. + _forward_dependent_prefixes = ("val/lddt", "val/complex_lddt", "val/clash", "val/pb", "val/rmsd") + _lddt_keys_compared = [] + if serial_metrics: + for k in serial_metrics: + if k in _captured_metrics: + got = torch.tensor(_captured_metrics[k]) + exp = torch.tensor(serial_metrics[k]) + if any(k.startswith(p) for p in _forward_dependent_prefixes): + torch.testing.assert_close( + got, + exp, + atol=5e-4, + rtol=0.02, + msg=lambda m: f"Rank {rank}: metric '{k}' mismatch: {m}", + ) + else: + torch.testing.assert_close( + got, + exp, + msg=lambda m: f"Rank {rank}: metric '{k}' mismatch: {m}", + ) + if "lddt" in k: + _lddt_keys_compared.append(k) + + assert _lddt_keys_compared, ( + f"Rank {rank}: no validation LDDT metrics were compared — test is vacuous. " + f"Serial keys: {sorted(serial_metrics)}, dist keys: {sorted(_captured_metrics)}" + ) + for required_metric in ("val/lddt", "val/disto_lddt", "val/complex_lddt"): + assert ( + required_metric in _captured_metrics + ), f"Rank {rank}: distributed metrics missing '{required_metric}' — available: {sorted(_captured_metrics)}" + + # Verify component-wise grad_norm metrics are present and non-zero + _grad_norm_keys = [ + "train/grad_norm", + "train/grad_norm_msa_module", + "train/grad_norm_pairformer_module", + "train/grad_norm_structure_module", + ] + for gn_key in _grad_norm_keys: + assert ( + gn_key in _captured_metrics + ), f"Rank {rank}: distributed metrics missing '{gn_key}' — available: {sorted(_captured_metrics)}" + assert ( + _captured_metrics[gn_key] > 0 + ), f"Rank {rank}: '{gn_key}' is zero — gradients should be non-zero after training" + + torch.distributed.barrier() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp2-cp2x2"], +) +def test_boltz2_e2e_training_parity( + setup_env, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + tmp_path, +): + """E2E serial-vs-DTensor training parity via ``train()`` entry points. + + Both serial and distributed training go through their respective + ``train()`` functions for 1 epoch (1 batch of 7ylz+8b2e for training, + 7z64+8ayv for validation), then compare checkpoints (state_dict, EMA + weights) and logged metrics — including validation LDDT — at FP32 + default tolerance. Model initialisation is controlled via a pretrained + checkpoint; noise and data RNG are controlled via module-level + monkeypatches. + + Comparison summary (32 metrics, 296 state_dict params, 296 EMA params): + + State dict & EMA (296 params each, default FP32 tolerance): + All params have non-zero magnitude (absmax in [3.9e-3, 5.0e-1]). + Serial and distributed match exactly (bitwise identical). + + Weight update (pretrained -> post-training, lr=1e-4, 1 step): + 273/296 params changed (delta_absmax ~1e-4, consistent with lr). + 23 params unchanged (fourier embeddings, some norms/MLPs — zero + gradient for this mini-batch). 15 "changed" params have delta < 1e-8 + (triangle attention Q/K weights with negligible gradients). + + Training metrics (default tolerance): + train/loss=1.78, train/grad_norm=0.43, train/param_norm=6.37, + train/diffusion_loss=0.53, train/distogram_loss=4.16, + train/grad_norm_{msa_module,pairformer_module,structure_module} + — all match within 1e-7 or exactly. Component-wise and global + grad_norms are non-zero (logged from on_after_backward where + gradients are available). + + Validation metrics — token-level (default tolerance): + val/disto_lddt_{ligand_protein,intra_protein,protein_protein, + intra_ligand}, val/disto_loss — all non-zero, match exactly. + + Validation metrics — atom-level (relaxed: atol=5e-4, rtol=0.02): + val/lddt_{intra_ligand,intra_protein,ligand_protein,protein_protein} + — non-zero values in [7e-4, 0.053], with abs diffs up to 8e-4 due + to FP32 accumulation order differences in distributed attention. + Exact parity verified separately by test_boltz2_validation_step_parity + in FP64. + + Trivially-zero metrics (12 keys, no DNA/RNA in test data): + val/{lddt,disto_lddt}_{dna_protein,rna_protein,dna_ligand, + rna_ligand,intra_dna,intra_rna} — 0==0 on both sides. + """ + + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + dtype = torch.float32 + seed = 42 + multiplicity = 2 + B = 2 + max_tokens = 256 + W = 32 # atoms_per_window_queries + size_cp = grid_group_sizes["cp"][0] * grid_group_sizes["cp"][1] + atom_align = math.lcm(W, size_cp) + max_atoms = ((max_tokens * 10 + atom_align - 1) // atom_align) * atom_align + max_seqs = 16 + scale_glorot = 0.05 + + # --- Merge all 4 samples with train/val split --- + training_data_dir, split_file = _setup_training_data_all_4_e2e( + tmp_path / "training_data", test_cp_training_base_data_dir_boltz2 + ) + + # --- Create pretrained checkpoint with deterministic init --- + seed_by_rank(0, seed=seed) + model_dict = _e2e_model_dict(multiplicity=multiplicity, validate_structure=True) + model_dict.pop("_target_") + model_dict.pop("validators", None) + _val_validators = [RCSBValidator(val_names=["RCSB"], confidence_prediction=False, physicalism_metrics=True)] + pretrained_model = SerialBoltz2(**model_dict, validators=_val_validators) + init_module_params_glorot(pretrained_model, gain=scale_glorot) + pretrained_model.apply(SetModuleInfValues()) + pretrained_model.structure_module.coordinate_augmentation = False + pretrained_model = pretrained_model.to(dtype=dtype) + + pretrained_path = tmp_path / "pretrained.ckpt" + torch.save( + { + "state_dict": pretrained_model.state_dict(), + "pytorch-lightning_version": pl.__version__, + "hyper_parameters": pretrained_model.hparams, + }, + pretrained_path, + ) + + # --- Pre-load and cache individual samples to disk --- + # The serial and distributed ``train()`` create independent data pipelines + # that call ``TrainingDataset.__getitem__`` separately. Despite RNG seeding, + # the featurizer has non-deterministic code paths (random augmentation of + # ref_pos, MSA subsampling). Caching the getitem results to disk and + # replaying them guarantees identical features in both pipelines. + _tmp_mp = pytest.MonkeyPatch() + _apply_e2e_deterministic_getitem(_tmp_mp, base_seed=seed) + _preload_cfg = setup_mock_training_datamodule_config(training_data_dir) + _preload_cfg.batch_size = B + _preload_cfg.samples_per_epoch = B + _preload_cfg.moldir = str(canonical_mols_dir) + _preload_cfg.return_train_symmetries = False + _preload_cfg.msa_sampling_training = False + _preload_cfg.max_tokens = max_tokens + _preload_cfg.max_atoms = max_atoms + _preload_cfg.max_seqs = max_seqs + for _ds in _preload_cfg.datasets: + _ds.filters = None + _ds.split = str(split_file) + _ds.symmetry_correction = False + seed_by_rank(0, seed=seed) + _preload_dm = Boltz2TrainingDataModule(cfg=_preload_cfg) + + _preload_ds = _preload_dm._train_set + _cached_samples = {i: _preload_ds[i] for i in range(B)} + cached_samples_path = tmp_path / "cached_samples.pt" + torch.save(_cached_samples, cached_samples_path) + + _preload_dl = _preload_dm.train_dataloader() + _preload_batch = next(iter(_preload_dl)) + atom_counts_per_token_host = _preload_batch["atom_counts_per_token"].detach().cpu() + atom_pad_mask_host = _preload_batch["atom_pad_mask"].detach().cpu() + _tmp_mp.undo() + + # --- Pre-generate deterministic noise (masked by atom_pad_mask) --- + seed_by_rank(0, seed=seed) + sigmas_global = pretrained_model.structure_module.noise_distribution(B * multiplicity).to(dtype=dtype) + noise_global = torch.empty(B * multiplicity, max_atoms, 3, dtype=dtype) + init_tensors_uniform([noise_global], low=-scale_glorot, high=scale_glorot) + _mask_mul = atom_pad_mask_host[:, :, None].repeat_interleave(multiplicity, 0).to(dtype=dtype) + noise_global = noise_global * _mask_mul + + sigmas_global_host = sigmas_global.detach().cpu() + noise_global_host = noise_global.detach().cpu() + + # --- Write serial config --- + serial_output_dir = tmp_path / "serial_output" + serial_output_dir.mkdir(parents=True, exist_ok=True) + serial_config_path = tmp_path / "serial_config.yaml" + _e2e_ds_overrides = { + "filters": None, + "moldir": None, + "symmetry_correction": False, + "val_group": "RCSB", + "use_train_subset": None, + "override_bfactor": False, + "override_method": None, + } + _write_train_config( + TrainTestConfig( + config_path=serial_config_path, + output_dir=serial_output_dir, + test_data_dir=training_data_dir, + mol_dir=canonical_mols_dir, + mode="serial", + accelerator="gpu", + limit_train_batches=1, + pretrained=str(pretrained_path), + model=_e2e_model_dict(multiplicity=multiplicity, validate_structure=True), + batch_size=B, + samples_per_epoch=B, + max_tokens=max_tokens, + max_atoms=max_atoms, + max_seqs=max_seqs, + return_train_symmetries=False, + split=str(split_file), + pop_target_keys=True, + extra_dataset_overrides=_e2e_ds_overrides, + v2=True, + strict_loading=False, + wandb=None, + save_top_k=0, + disable_checkpoint=False, + ) + ) + + # --- Apply serial monkeypatches --- + serial_mp = pytest.MonkeyPatch() + _apply_cached_getitem(serial_mp, cached_samples_path) + + _orig_serial_boltz2_init = SerialBoltz2.__init__ + + @functools.wraps(_orig_serial_boltz2_init) + def _init_with_validators(self, *args, **kwargs): + if kwargs.get("validate_structure", False) and not kwargs.get("validators"): + kwargs["validators"] = [ + RCSBValidator(val_names=["RCSB"], confidence_prediction=False, physicalism_metrics=True) + ] + _orig_serial_boltz2_init(self, *args, **kwargs) + + serial_mp.setattr(SerialBoltz2, "__init__", _init_with_validators) + + serial_mp.setattr( + SerialAtomDiffusionV2, + "noise_distribution", + lambda self, bs: sigmas_global[:bs].to(device=self.zero.device), + ) + + _serial_in_val = [False] + + def _serial_randn_like(t): + if _serial_in_val[0]: + return torch.zeros_like(t) + return noise_global[: t.shape[0], : t.shape[1]].to(device=t.device, dtype=t.dtype) + + serial_mp.setattr(serial_diffusion_v2_module.torch, "randn_like", _serial_randn_like) + + _orig_serial_randn = serial_diffusion_v2_module.torch.randn + + def _serial_randn(*args, **kwargs): + if _serial_in_val[0]: + kwargs.pop("generator", None) + return torch.zeros(*args, **kwargs) + return _orig_serial_randn(*args, **kwargs) + + serial_mp.setattr(serial_diffusion_v2_module.torch, "randn", _serial_randn) + + _orig_compute_random_augmentation = serial_diffusion_v2_module.compute_random_augmentation + + def _identity_augmentation_during_val(multiplicity, s_trans=1.0, device=None, dtype=torch.float32): + if _serial_in_val[0]: + R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(multiplicity, -1, -1) + tr = torch.zeros(multiplicity, 1, 3, device=device, dtype=dtype) + return R, tr + return _orig_compute_random_augmentation(multiplicity, s_trans=s_trans, device=device, dtype=dtype) + + serial_mp.setattr(serial_diffusion_v2_module, "compute_random_augmentation", _identity_augmentation_during_val) + + _orig_serial_val_step = SerialBoltz2.validation_step + + def _serial_val_step_wrapper(self_model, batch, batch_idx): + _serial_in_val[0] = True + try: + return _orig_serial_val_step(self_model, batch, batch_idx) + finally: + _serial_in_val[0] = False + + serial_mp.setattr(SerialBoltz2, "validation_step", _serial_val_step_wrapper) + + serial_mp.setattr(serial_loss_v2_module, "smooth_lddt_loss", _smooth_lddt_loss_dense_e2e) + serial_mp.setattr(serial_diffusion_v2_module, "smooth_lddt_loss", _smooth_lddt_loss_dense_e2e) + + serial_captured_metrics: dict[str, float] = {} + _orig_fit = pl.Trainer.fit + + def _serial_capturing_fit(self, *args, **kwargs): + result = _orig_fit(self, *args, **kwargs) + for k, v in self.callback_metrics.items(): + if isinstance(v, torch.Tensor): + serial_captured_metrics[k] = v.detach().cpu().item() + else: + serial_captured_metrics[k] = v + return result + + serial_mp.setattr(pl.Trainer, "fit", _serial_capturing_fit) + + # --- Run serial training --- + _serial_train_mod = _load_serial_train_module() + _serial_train_mod.train(str(serial_config_path), []) + + # --- Find serial checkpoint --- + serial_ckpt_files = list(serial_output_dir.rglob("last.ckpt")) + assert ( + len(serial_ckpt_files) == 1 + ), f"Expected exactly 1 last.ckpt in serial output, found {len(serial_ckpt_files)}: {serial_ckpt_files}" + serial_ckpt_path = serial_ckpt_files[0] + + # Verify serial checkpoint has EMA + serial_ckpt = torch.load(serial_ckpt_path, map_location="cpu", weights_only=False) + assert "ema" in serial_ckpt, "Serial checkpoint missing EMA state" + assert "ema_weights" in serial_ckpt["ema"], "Serial EMA missing ema_weights" + + # Non-vacuous guard: serial must have logged at least one LDDT val metric + serial_lddt_keys = [ + k for k in serial_captured_metrics if k.startswith("val/lddt_") or k.startswith("val/disto_lddt_") + ] + assert serial_lddt_keys, ( + f"Serial run produced no validation LDDT metrics — test is vacuous. " + f"Available metrics: {sorted(serial_captured_metrics)}" + ) + + serial_mp.undo() + + # --- Write distributed config --- + dp = grid_group_sizes["dp"] + cp0, cp1 = grid_group_sizes["cp"] + size_cp = cp0 * cp1 + dist_output_dir = tmp_path / "dist_output" + dist_output_dir.mkdir(parents=True, exist_ok=True) + dist_config_path = tmp_path / "dist_config.yaml" + _write_train_config( + TrainTestConfig( + config_path=dist_config_path, + output_dir=dist_output_dir, + test_data_dir=training_data_dir, + mol_dir=canonical_mols_dir, + size_dp=dp, + size_cp=size_cp, + accelerator="gpu", + limit_train_batches=1, + pretrained=str(pretrained_path), + model=_e2e_model_dict(multiplicity=multiplicity, validate_structure=True, distributed=True), + batch_size=1, + samples_per_epoch=dp, + max_tokens=max_tokens, + max_atoms=max_atoms, + max_seqs=max_seqs, + return_train_symmetries=False, + split=str(split_file), + pop_target_keys=True, + extra_dataset_overrides=_e2e_ds_overrides, + ) + ) + + # --- Spawn distributed workers --- + spawn_multiprocessing( + _worker_e2e_training_parity, + world_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + str(dist_config_path), + str(dist_output_dir), + str(serial_ckpt_path), + str(pretrained_path), + serial_captured_metrics, + sigmas_global_host, + noise_global_host, + atom_counts_per_token_host, + str(cached_samples_path), + seed, + ) + + +# --------------------------------------------------------------------------- +# BF16 / activation-checkpoint parity: shared setup +# --------------------------------------------------------------------------- + + +@dataclass +class _Bf16AcTestEnv: + """Return type for :func:`_setup_bf16_ac_test_env`.""" + + serial_config_path: Path + dist_config_path: Path + grid_group_sizes: dict + world_size: int + device_type: str + backend: str + env_per_rank: dict[str, Any] + + +def _setup_bf16_ac_test_env( + setup_env, + test_cp_training_base_data_dir_boltz2: Path, + canonical_mols_dir: Path, + tmp_path: Path, +) -> _Bf16AcTestEnv: + """Shared setup for BF16 + activation-checkpointing parity tests. + + Creates training data, a pretrained checkpoint with all AC flags enabled + (mirroring ``structurev2.yaml``), and writes serial / distributed YAML + configs. Returns paths and grid metadata so each test can attach its own + profiler and spawn workers. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + seed = 42 + multiplicity = 2 + B = 2 + max_tokens = 256 + W = 32 # atoms_per_window_queries + size_cp = grid_group_sizes["cp"][0] * grid_group_sizes["cp"][1] + atom_align = math.lcm(W, size_cp) + max_atoms = ((max_tokens * 10 + atom_align - 1) // atom_align) * atom_align + max_seqs = 16 + scale_glorot = 0.05 + + training_data_dir, split_file = _setup_training_data_all_4_e2e( + tmp_path / "training_data", test_cp_training_base_data_dir_boltz2 + ) + + ac_model_dict = _e2e_model_dict(multiplicity=multiplicity, validate_structure=False) + ac_model_dict["checkpoint_diffusion_conditioning"] = True + ac_model_dict["msa_args"]["activation_checkpointing"] = True + ac_model_dict["pairformer_args"]["activation_checkpointing"] = True + ac_model_dict["score_model_args"]["activation_checkpointing"] = True + + seed_by_rank(0, seed=seed) + model_dict = copy.deepcopy(ac_model_dict) + model_dict.pop("_target_") + model_dict.pop("validators", None) + pretrained_model = SerialBoltz2(**model_dict) + init_module_params_glorot(pretrained_model, gain=scale_glorot) + pretrained_model.apply(SetModuleInfValues()) + pretrained_model.structure_module.coordinate_augmentation = False + pretrained_model = pretrained_model.to(dtype=torch.float32) + + pretrained_path = tmp_path / "pretrained.ckpt" + torch.save( + { + "state_dict": pretrained_model.state_dict(), + "pytorch-lightning_version": pl.__version__, + "hyper_parameters": pretrained_model.hparams, + }, + pretrained_path, + ) + + _e2e_ds_overrides = { + "filters": None, + "moldir": None, + "symmetry_correction": False, + "val_group": "RCSB", + "use_train_subset": None, + "override_bfactor": False, + "override_method": None, + } + + serial_output_dir = tmp_path / "serial_output" + serial_output_dir.mkdir(parents=True, exist_ok=True) + serial_config_path = tmp_path / "serial_config.yaml" + _write_train_config( + TrainTestConfig( + config_path=serial_config_path, + output_dir=serial_output_dir, + test_data_dir=training_data_dir, + mol_dir=canonical_mols_dir, + mode="serial", + accelerator="gpu", + precision="bf16-mixed", + limit_train_batches=1, + limit_val_batches=0, + num_sanity_val_steps=0, + pretrained=str(pretrained_path), + model=copy.deepcopy(ac_model_dict), + batch_size=B, + samples_per_epoch=B, + max_tokens=max_tokens, + max_atoms=max_atoms, + max_seqs=max_seqs, + return_train_symmetries=False, + split=str(split_file), + pop_target_keys=True, + extra_dataset_overrides=_e2e_ds_overrides, + v2=True, + strict_loading=False, + wandb=None, + save_top_k=0, + disable_checkpoint=True, + ) + ) + + dp = grid_group_sizes["dp"] + size_cp = grid_group_sizes["cp"] + if isinstance(size_cp, tuple): + size_cp = size_cp[0] * size_cp[1] + dist_output_dir = tmp_path / "dist_output" + dist_output_dir.mkdir(parents=True, exist_ok=True) + dist_config_path = tmp_path / "dist_config.yaml" + _write_train_config( + TrainTestConfig( + config_path=dist_config_path, + output_dir=dist_output_dir, + test_data_dir=training_data_dir, + mol_dir=canonical_mols_dir, + size_dp=dp, + size_cp=size_cp, + accelerator="gpu", + precision="BF16_MIXED", + limit_train_batches=1, + limit_val_batches=0, + num_sanity_val_steps=0, + pretrained=str(pretrained_path), + model=copy.deepcopy(ac_model_dict), + batch_size=1, + samples_per_epoch=dp, + max_tokens=max_tokens, + max_atoms=max_atoms, + max_seqs=max_seqs, + return_train_symmetries=False, + split=str(split_file), + pop_target_keys=True, + extra_dataset_overrides=_e2e_ds_overrides, + ) + ) + + return _Bf16AcTestEnv( + serial_config_path=serial_config_path, + dist_config_path=dist_config_path, + grid_group_sizes=grid_group_sizes, + world_size=world_size, + device_type=device_type, + backend=backend, + env_per_rank=env_per_rank, + ) + + +# --------------------------------------------------------------------------- +# BF16 dtype parity: serial vs DTensor training +# --------------------------------------------------------------------------- + + +def _worker_bf16_dtype_parity( + rank: int, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_per_rank: dict[str, Any], + dist_config_path: str, + serial_dtype_profile_path: str, +) -> None: + """Multi-rank worker: run distributed train(), compare dtype profiles with serial. + + Only dtype equality is checked — no numerical comparison. The serial + dtype profile (written by the main process) is loaded from disk and + compared against the distributed profile captured via :class:`DtypeProfiler`. + """ + from boltz.testing.utils import DtypeProfiler + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + monkeypatch.setattr(train_module, "_cleanup_distributed", lambda: None) + DistributedManager._state = {} + + # Capture dtype profile via Trainer.fit monkeypatch + _dist_profiler: list[DtypeProfiler | None] = [None] + _orig_fit = pl.Trainer.fit + + def _profiling_fit(trainer_self, model, **kwargs): + _dist_profiler[0] = DtypeProfiler(model) + result = _orig_fit(trainer_self, model, **kwargs) + _dist_profiler[0].collect_grad_dtypes(model) + return result + + monkeypatch.setattr(pl.Trainer, "fit", _profiling_fit) + + train_module.train(dist_config_path, []) + + profiler = _dist_profiler[0] + assert profiler is not None, f"Rank {rank}: DtypeProfiler was never attached" + + profiler.remove_hooks() + + # Load serial dtype profile + serial_profile = torch.load(serial_dtype_profile_path, map_location="cpu", weights_only=False) + serial_fwd = serial_profile["fwd_dtypes"] + serial_grads = serial_profile["param_grad_dtypes"] + + dist_fwd = profiler.fwd_dtypes + dist_params = profiler.param_dtypes + dist_grads = profiler.param_grad_dtypes + + # --- Parameter dtypes: all FP32 under bf16-mixed --- + for name, dtype in dist_params.items(): + assert dtype == torch.float32, f"Rank {rank}: distributed param '{name}' has dtype {dtype}, expected float32" + + # --- Forward activation dtypes: strict equality at common module names --- + common_fwd = sorted(set(serial_fwd) & set(dist_fwd)) + assert len(common_fwd) >= 10, ( + f"Rank {rank}: only {len(common_fwd)} common forward module names " + f"between serial ({len(serial_fwd)}) and distributed ({len(dist_fwd)}). " + f"Expected >= 10 for a non-vacuous comparison." + ) + fwd_mismatches: list[str] = [] + for name in common_fwd: + if serial_fwd[name] != dist_fwd[name]: + fwd_mismatches.append(f" {name}: serial={serial_fwd[name]}, dist={dist_fwd[name]}") + assert not fwd_mismatches, f"Rank {rank}: forward activation dtype mismatches:\n" + "\n".join(fwd_mismatches) + + # --- Non-vacuous: autocast must produce a mix of BF16 and FP32 --- + dist_fwd_dtypes_set = set(dist_fwd.values()) + assert ( + torch.bfloat16 in dist_fwd_dtypes_set + ), f"Rank {rank}: no bfloat16 activations found — autocast may not be active. Unique dtypes: {dist_fwd_dtypes_set}" + assert ( + torch.float32 in dist_fwd_dtypes_set + ), f"Rank {rank}: no float32 activations found — all ops appear autocasted. Unique dtypes: {dist_fwd_dtypes_set}" + + # --- Parameter gradient dtypes: strict equality at common param names --- + common_grads = sorted(set(serial_grads) & set(dist_grads)) + grad_mismatches: list[str] = [] + for name in common_grads: + if serial_grads[name] != dist_grads[name]: + grad_mismatches.append(f" {name}: serial={serial_grads[name]}, dist={dist_grads[name]}") + assert not grad_mismatches, f"Rank {rank}: param gradient dtype mismatches:\n" + "\n".join(grad_mismatches) + + torch.distributed.barrier() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp1-cp2x2"], +) +def test_boltz2_bf16_dtype_parity( + setup_env, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + tmp_path, +): + """Verify that BF16-mixed autocast produces identical dtype profiles in + serial and DTensor training workflows. + + Runs 1 training step under ``bf16-mixed`` precision for both the serial + ``Boltz2`` model (via ``scripts/train/train.py``) and the distributed + ``Boltz2`` model (via ``src/boltz/distributed/train.py``), then compares + forward activation dtypes, parameter dtypes, and parameter gradient + dtypes at every module whose name appears in both models. + + No numerical comparison is performed — only dtype equality. This means + no deterministic-noise / cached-sample monkeypatching is required. + """ + from boltz.testing.utils import DtypeProfiler + + env = _setup_bf16_ac_test_env(setup_env, test_cp_training_base_data_dir_boltz2, canonical_mols_dir, tmp_path) + + # --- Run serial training with dtype profiling --- + serial_mp = pytest.MonkeyPatch() + _serial_profiler: list[DtypeProfiler | None] = [None] + _orig_fit = pl.Trainer.fit + + def _serial_profiling_fit(trainer_self, model, **kwargs): + _serial_profiler[0] = DtypeProfiler(model) + result = _orig_fit(trainer_self, model, **kwargs) + _serial_profiler[0].collect_grad_dtypes(model) + return result + + serial_mp.setattr(pl.Trainer, "fit", _serial_profiling_fit) + + _serial_train_mod = _load_serial_train_module() + _serial_train_mod.train(str(env.serial_config_path), []) + + profiler = _serial_profiler[0] + assert profiler is not None, "Serial DtypeProfiler was never attached" + profiler.remove_hooks() + + # Non-vacuous: serial must have a mix of BF16 and FP32 activations + serial_fwd_dtypes_set = set(profiler.fwd_dtypes.values()) + assert ( + torch.bfloat16 in serial_fwd_dtypes_set + ), f"Serial: no bfloat16 activations — autocast may not be active. Unique dtypes: {serial_fwd_dtypes_set}" + assert ( + torch.float32 in serial_fwd_dtypes_set + ), f"Serial: no float32 activations. Unique dtypes: {serial_fwd_dtypes_set}" + + # All serial params must be FP32 + for name, dtype in profiler.param_dtypes.items(): + assert dtype == torch.float32, f"Serial param '{name}' has dtype {dtype}, expected float32" + + # Save serial profile for workers + serial_dtype_profile_path = tmp_path / "serial_dtype_profile.pt" + torch.save( + { + "fwd_dtypes": profiler.fwd_dtypes, + "param_dtypes": profiler.param_dtypes, + "param_grad_dtypes": profiler.param_grad_dtypes, + }, + serial_dtype_profile_path, + ) + + serial_mp.undo() + + # --- Spawn distributed workers --- + spawn_multiprocessing( + _worker_bf16_dtype_parity, + env.world_size, + env.grid_group_sizes, + env.device_type, + env.backend, + env.env_per_rank, + str(env.dist_config_path), + str(serial_dtype_profile_path), + ) + + +# --------------------------------------------------------------------------- +# Activation checkpoint recomputation parity: serial vs DTensor training +# --------------------------------------------------------------------------- + + +def _worker_actv_ckpt_parity( + rank: int, + grid_group_sizes: dict, + device_type: str, + backend: str, + env_per_rank: dict[str, Any], + dist_config_path: str, + serial_recompute_profile_path: str, +) -> None: + """Per-rank worker: compare checkpoint-recomputed modules with serial. + + Each rank records its own local forward-hook call counts via + :class:`RecomputeProfiler` and independently compares against the serial + reference. Counts are **never** aggregated across ranks. + """ + from boltz.testing.utils import RecomputeProfiler + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + monkeypatch.setattr(train_module, "_cleanup_distributed", lambda: None) + DistributedManager._state = {} + + _dist_profiler: list[RecomputeProfiler | None] = [None] + _orig_fit = pl.Trainer.fit + + def _profiling_fit(trainer_self, model, **kwargs): + _dist_profiler[0] = RecomputeProfiler(model) + return _orig_fit(trainer_self, model, **kwargs) + + monkeypatch.setattr(pl.Trainer, "fit", _profiling_fit) + + train_module.train(dist_config_path, []) + + profiler = _dist_profiler[0] + assert profiler is not None, f"Rank {rank}: RecomputeProfiler was never attached" + profiler.remove_hooks() + + serial_profile = torch.load(serial_recompute_profile_path, map_location="cpu", weights_only=False) + serial_counts: dict[str, int] = serial_profile["fwd_counts"] + dist_counts = profiler.fwd_counts + + common = sorted(set(serial_counts) & set(dist_counts)) + assert len(common) >= 10, ( + f"Rank {rank}: only {len(common)} common module names between serial " + f"({len(serial_counts)}) and distributed ({len(dist_counts)}). Expected >= 10." + ) + + serial_recomp = {n for n in common if serial_counts[n] >= 2} + dist_recomp = {n for n in common if dist_counts[n] >= 2} + + # Non-vacuous: activation checkpointing must be active on this rank + assert len(dist_recomp) >= 5, ( + f"Rank {rank}: only {len(dist_recomp)} recomputed modules in distributed " + f"(expected >= 5). Activation checkpointing may not be active." + ) + assert dist_recomp < set(common), ( + f"Rank {rank}: all {len(common)} common modules are recomputed — " + f"not a strict subset, implying either a counting bug or every module is checkpointed." + ) + + serial_only = sorted(serial_recomp - dist_recomp) + dist_only = sorted(dist_recomp - serial_recomp) + assert ( + not serial_only + ), f"Rank {rank}: modules recomputed in serial but not DTensor ({len(serial_only)}): {serial_only}" + assert not dist_only, f"Rank {rank}: modules recomputed in DTensor but not serial ({len(dist_only)}): {dist_only}" + + torch.distributed.barrier() + + +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cuda-dp1-cp2x2"], +) +def test_boltz_actv_ckpt_parity( + setup_env, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + tmp_path, +): + """Verify serial and DTensor training checkpoint-recompute the same modules. + + Runs 1 training step (forward + backward) under ``bf16-mixed`` precision + with all production activation-checkpointing flags enabled. Forward hooks + count how many times each module's forward is invoked: modules inside a + ``torch.utils.checkpoint.checkpoint`` region are called twice (once in + forward, once during backward recomputation). + + Each distributed rank independently compares its local counts against the + serial reference — counts are **never** aggregated across ranks. + """ + from boltz.testing.utils import RecomputeProfiler + + env = _setup_bf16_ac_test_env(setup_env, test_cp_training_base_data_dir_boltz2, canonical_mols_dir, tmp_path) + + # --- Run serial training with recompute profiling --- + serial_mp = pytest.MonkeyPatch() + _serial_profiler: list[RecomputeProfiler | None] = [None] + _orig_fit = pl.Trainer.fit + + def _serial_profiling_fit(trainer_self, model, **kwargs): + _serial_profiler[0] = RecomputeProfiler(model) + return _orig_fit(trainer_self, model, **kwargs) + + serial_mp.setattr(pl.Trainer, "fit", _serial_profiling_fit) + + _serial_train_mod = _load_serial_train_module() + _serial_train_mod.train(str(env.serial_config_path), []) + + profiler = _serial_profiler[0] + assert profiler is not None, "Serial RecomputeProfiler was never attached" + profiler.remove_hooks() + + # Non-vacuous: activation checkpointing must recompute some modules + serial_recomp = profiler.recomputed_modules + all_serial_names = set(profiler.fwd_counts) + assert len(serial_recomp) >= 5, ( + f"Serial: only {len(serial_recomp)} modules recomputed (expected >= 5). " + f"Activation checkpointing may not be active." + ) + assert ( + serial_recomp < all_serial_names + ), f"Serial: all {len(all_serial_names)} modules are recomputed — not a strict subset, implying a counting bug." + + # Save serial recompute profile for workers + serial_recompute_profile_path = tmp_path / "serial_recompute_profile.pt" + torch.save({"fwd_counts": profiler.fwd_counts}, serial_recompute_profile_path) + + serial_mp.undo() + + # --- Spawn distributed workers (per-rank comparison, no cross-rank aggregation) --- + spawn_multiprocessing( + _worker_actv_ckpt_parity, + env.world_size, + env.grid_group_sizes, + env.device_type, + env.backend, + env.env_per_rank, + str(env.dist_config_path), + str(serial_recompute_profile_path), + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/test_dtensor_cp_dataloader_v2.py b/tests/distributed/test_dtensor_cp_dataloader_v2.py new file mode 100644 index 000000000..b21de514f --- /dev/null +++ b/tests/distributed/test_dtensor_cp_dataloader_v2.py @@ -0,0 +1,1188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# System imports +from pathlib import Path +from typing import Dict, Optional + +# Third party imports +import pytest +import torch +from torch.distributed.tensor import DTensor +from torch.utils.data import DataLoader, DistributedSampler + +from boltz.data.module.inferencev2 import Boltz2InferenceDataModule, PredictionDataset, collate +from boltz.data.module.trainingv2 import ( + Boltz2TrainingDataModule, +) +from boltz.data.module.trainingv2 import ( + collate as collate_training, +) +from boltz.data.types import Manifest +from boltz.distributed.data.module.inferencev2 import ( + Boltz2InferenceDataModuleDTensor as BoltzInferenceDataModuleDTensor, +) +from boltz.distributed.data.module.trainingv2 import ( + Boltz2TrainingDataModule as BoltzTrainingDataModuleDTensor, +) +from boltz.distributed.data.types import PairMaskMode +from boltz.distributed.data.utils import ( + ATOM_FEATURES_V2 as ATOM_FEATURES, +) +from boltz.distributed.data.utils import ( + NON_SHARDED_FEATURES_V2, + map_subgroup_mesh_to_cpu, +) +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.atom_to_token import ( + single_repr_atom_to_token, + single_repr_token_to_atom, +) +from boltz.distributed.testing.utils import setup_mock_training_datamodule_config +from boltz.main import BoltzProcessedInput +from boltz.testing.utils import ( + concat_data, + seed_by_rank, + spawn_multiprocessing, +) + +# TODO support CP constraint, template, affinity, ensemble features for inference +INFERENCE_FEATURES_DIFFERENCE = { + "atom_counts_per_token", + "chiral_atom_index", + "chiral_atom_orientations", + "chiral_reference_mask", + "connected_atom_index", + "connected_chain_index", + "contact_negation_mask", + "contact_pair_index", + "contact_thresholds", + "contact_union_index", + "ensemble_ref_idxs", + "pair_mask", + "planar_bond_index", + "planar_ring_5_index", + "planar_ring_6_index", + "query_to_template", + "rdkit_bounds_angle_mask", + "rdkit_bounds_bond_mask", + "rdkit_bounds_index", + "rdkit_lower_bounds", + "rdkit_upper_bounds", + "record", + "r_set_to_rep_atom", + "stereo_bond_index", + "stereo_bond_orientations", + "stereo_reference_mask", + "symmetric_chain_index", + "template_ca", + "template_cb", + "template_frame_rot", + "template_frame_t", + "template_mask", + "template_mask_cb", + "template_mask_frame", + "template_restype", + "token_to_center_atom", + "token_pair_pad_mask", + "visibility_ids", +} + +TRAINING_FEATURES_DIFFERENCE = { + "atom_counts_per_token", + "ensemble_ref_idxs", + "idx_dataset", + "record", + "token_pair_pad_mask", +} + +TOKEN_PAIR_FEATURES = { + "contact_conditioning", + "contact_threshold", + "disto_target", + "token_bonds", + "token_pair_pad_mask", + "type_bonds", +} + +MSA_FEATURES = { + "deletion_value", + "has_deletion", + "msa", + "msa_paired", +} + +ENSEMBLE_ATOM_FEATURES = { + "coords", +} + +ENSEMBLE_TOKEN_FEATURES = { + "disto_coords_ensemble", +} + + +def _map_padded_frames_idx_to_unpadded( + frames_idx_dtensor: DTensor, + dp_idx_str: int, + dp_idx_end: int, + atom_to_token_dtensor: DTensor, + frame_resolved_mask_dtensor: DTensor, +) -> torch.Tensor: + """Map frames_idx from padded atom indices to unpadded indices using atom_to_token.""" + frames_idx_padded = frames_idx_dtensor.full_tensor()[dp_idx_str:dp_idx_end] + if frames_idx_padded.ndim != 4: + raise ValueError(f"frames_idx_padded must have ndim=4 (B, E, T, 3), got {frames_idx_padded.ndim}") + frame_resolved_mask_padded = frame_resolved_mask_dtensor.full_tensor()[dp_idx_str:dp_idx_end].bool() + if frame_resolved_mask_padded.ndim != 3: + raise ValueError( + f"frame_resolved_mask_padded must have ndim=3 (B, E, T), got {frame_resolved_mask_padded.ndim}" + ) + + atom_to_token_local = atom_to_token_dtensor.to_local() + if atom_to_token_local.ndim != 3: + raise ValueError( + f"atom_to_token must have ndim=3 (B, N_atoms, N_tokens_per_row), got {atom_to_token_local.ndim}" + ) + if atom_to_token_local.shape[0] != frames_idx_padded.shape[0]: + raise ValueError( + "atom_to_token batch size does not match frames_idx_padded: " + f"{atom_to_token_local.shape[0]} vs {frames_idx_padded.shape[0]}" + ) + device_mesh = atom_to_token_dtensor.device_mesh + cp_axis_0_group = device_mesh.get_group(1) + # Local atom mask for this CP row + atom_mask_local = atom_to_token_local.sum(dim=2) > 0 # (B, N_atoms_per_shard) + local_counts = atom_mask_local.sum(dim=1).to(torch.int64) # (B,) + + # Gather counts and masks across CP axis 0 + counts_list = [torch.empty_like(local_counts) for _ in range(cp_axis_0_group.size())] + torch.distributed.all_gather(counts_list, local_counts, group=cp_axis_0_group) + counts_all = torch.stack(counts_list, dim=0) # (cp_rows, B) + + mask_list = [torch.empty_like(atom_mask_local) for _ in range(cp_axis_0_group.size())] + torch.distributed.all_gather(mask_list, atom_mask_local, group=cp_axis_0_group) + atom_mask_all = torch.stack(mask_list, dim=0) # (cp_rows, B, N_atoms_per_shard) + + # Row offsets for global unpadded indices + row_offsets = torch.cumsum(counts_all, dim=0) - counts_all # (cp_rows, B) + + # CollateDTensor remaps frames_idx to the final padded stride, so the + # padded row width is always the local atom dim after collation. + padded_row_width = atom_to_token_local.shape[1] + + batch_size = frames_idx_padded.shape[0] + mapped = torch.empty_like(frames_idx_padded) + for batch_idx in range(batch_size): + frames = frames_idx_padded[batch_idx] + resolved_frames = frame_resolved_mask_padded[batch_idx] + resolved_frames = resolved_frames.unsqueeze(-1).expand_as(frames) + + row_idx = torch.div(frames, padded_row_width, rounding_mode="floor") + local_idx = frames % padded_row_width + + if row_idx.max().item() >= atom_mask_all.shape[0]: + raise ValueError( + f"frames_idx_padded has row index {row_idx.max().item()} out of range for cp rows " + f"{atom_mask_all.shape[0]}" + ) + # Build prefix sums per row for mapping local -> unpadded + prefix = torch.cumsum(atom_mask_all[:, batch_idx].to(torch.int64), dim=1) - 1 + local_mask = atom_mask_all[:, batch_idx] + # Validate no padding indices are referenced for resolved frames. + valid = local_mask[row_idx, local_idx] + if not torch.all(valid[resolved_frames]): + raise ValueError("frames_idx points to padded atom indices in atom_to_token") + unpadded_local = prefix[row_idx, local_idx] + mapped_vals = row_offsets[row_idx, batch_idx] + unpadded_local + mapped[batch_idx] = torch.where(resolved_frames, mapped_vals, torch.zeros_like(mapped_vals)) + return mapped + + +def _compare_frames_idx_valid_tokens( + frames_idx_non_dtensor: torch.Tensor, + frames_idx_dtensor_unpadded: torch.Tensor, + frame_resolved_mask_non_dtensor: torch.Tensor, + frame_resolved_mask_dtensor: torch.Tensor, + non_dtensor_mask: torch.Tensor, + dtensor_mask: torch.Tensor, +) -> None: + """Compare frames_idx using token and frame-resolved masks on each side.""" + if non_dtensor_mask.ndim != 2 or dtensor_mask.ndim != 2: + raise ValueError(f"token_pad_mask must have ndim=2, got {non_dtensor_mask.ndim} and {dtensor_mask.ndim}") + batch_size = frames_idx_dtensor_unpadded.shape[0] + for batch_idx in range(batch_size): + non_mask = non_dtensor_mask[batch_idx].bool() + dt_mask = dtensor_mask[batch_idx].bool() + if non_mask.sum().item() != dt_mask.sum().item(): + raise AssertionError( + f"frames_idx token count mismatch: non-dtensor={non_mask.sum().item()}, " + f"dtensor={dt_mask.sum().item()}" + ) + non_item = frames_idx_non_dtensor[batch_idx] + dt_item = frames_idx_dtensor_unpadded[batch_idx] + non_frame_mask = frame_resolved_mask_non_dtensor[batch_idx].bool() + dt_frame_mask = frame_resolved_mask_dtensor[batch_idx].bool() + if non_item.ndim != 3 or dt_item.ndim != 3: + raise ValueError( + "frames_idx must be ensemble-aware with ndim=3 per batch item; got " + f"{non_item.ndim} and {dt_item.ndim}" + ) + non_mask_ens = non_mask.unsqueeze(0).expand(non_item.shape[0], -1) & non_frame_mask + dt_mask_ens = dt_mask.unsqueeze(0).expand(dt_item.shape[0], -1) & dt_frame_mask + non_tokens = non_item[non_mask_ens] + dt_tokens = dt_item[dt_mask_ens] + torch.testing.assert_close(non_tokens, dt_tokens, atol=0, rtol=0) + + +def compare_r_set_to_rep_atom( + ref_tensor: torch.Tensor, + dtensor_full: torch.Tensor, + atom_pad_mask_dtensor: torch.Tensor, + n_shards: int, +) -> None: + """Compare r_set_to_rep_atom: reference (one-hot) vs DTensor (one-hot diagonal blocks). + + Simplified comparison using: + - .any(dim=-1) to find valid R-set rows + - .argmax(dim=-1) to get LOCAL atom indices + - atom_pad_mask reshaped per-shard to compute atom offsets + - Convert local -> global indices using per-shard offsets + + Parameters + ---------- + ref_tensor : torch.Tensor + Single-device reference tensor (one-hot), shape [B, size_r_set_ref, N_atom_global] + dtensor_full : torch.Tensor + DTensor full_tensor() (one-hot diagonal blocks), shape [B, size_r_set_padded, max_atoms_per_shard] + atom_pad_mask_dtensor : torch.Tensor + DTensor atom_pad_mask.full_tensor(), shape [B, N_atoms_padded]. True/1 = valid atom. + n_shards : int + Number of shards along the R-set sharding axis (cp_axis_0 size). + """ + # Handle both 2D and 3D tensors + if ref_tensor.dim() == 2: + ref_tensor = ref_tensor.unsqueeze(0) + if dtensor_full.dim() == 2: + dtensor_full = dtensor_full.unsqueeze(0) + if atom_pad_mask_dtensor.dim() == 1: + atom_pad_mask_dtensor = atom_pad_mask_dtensor.unsqueeze(0) + + assert ref_tensor.dim() == 3, f"Expected 3D ref tensor [B, size_r_set, N_atom], got {ref_tensor.shape}" + assert ( + dtensor_full.dim() == 3 + ), f"Expected 3D dtensor [B, size_r_set_padded, max_atoms_per_shard], got {dtensor_full.shape}" + assert ( + atom_pad_mask_dtensor.dim() == 2 + ), f"Expected 2D atom_pad_mask [B, N_atoms_padded], got {atom_pad_mask_dtensor.shape}" + + batch_size = ref_tensor.shape[0] + size_r_set_padded = dtensor_full.shape[1] + max_atoms_per_shard = dtensor_full.shape[2] + max_size_r_set_per_shard = size_r_set_padded // n_shards + + for b in range(batch_size): + ref_b = ref_tensor[b] # [size_r_set_ref, N_atom_global] one-hot + dtensor_b = dtensor_full[b] # [size_r_set_padded, max_atoms_per_shard] one-hot + atom_mask_b = atom_pad_mask_dtensor[b].long() # [N_atoms_padded] + + # Find valid R-set rows (non-zero rows) + ref_valid_mask = ref_b.any(dim=-1) # [size_r_set_ref] + dtensor_valid_mask = dtensor_b.any(dim=-1) # [size_r_set_padded] + + size_r_set_valid_ref = ref_valid_mask.sum().item() + size_r_set_valid_dtensor = dtensor_valid_mask.sum().item() + + if size_r_set_valid_ref == 0: + # No valid R-set elements - DTensor should have no valid rows either + assert ( + size_r_set_valid_dtensor == 0 + ), f"DTensor batch {b} has {size_r_set_valid_dtensor} valid rows, expected 0" + continue + + # Verify counts match + if size_r_set_valid_ref != size_r_set_valid_dtensor: + raise AssertionError( + f"r_set_to_rep_atom batch {b}: count mismatch. " + f"Ref has {size_r_set_valid_ref} valid rows, DTensor has {size_r_set_valid_dtensor}" + ) + + # Get global atom indices from reference (straightforward argmax) + ref_global_indices = ref_b[ref_valid_mask].argmax(dim=-1) # [size_r_set_valid] + + # Get LOCAL atom indices from DTensor (argmax gives local index within shard) + dtensor_atom_ids_local = dtensor_b.argmax(dim=-1) # [size_r_set_padded] local indices + + # Compute per-shard atom offset using atom_pad_mask + # atom_pad_mask_per_shard[s] = mask for shard s's atoms + atom_pad_mask_per_shard = atom_mask_b.reshape(n_shards, max_atoms_per_shard) # [n_shards, max_atoms_per_shard] + atom_counts_per_shard = atom_pad_mask_per_shard.sum(dim=-1) # [n_shards] + # atom_offset[s] = cumsum of atoms before shard s + atom_offset_per_shard = atom_counts_per_shard.cumsum(dim=0) - atom_counts_per_shard # [n_shards] + + # Reshape local atom ids by shard: [n_shards, max_size_r_set_per_shard] + dtensor_atom_ids_local_per_shard = dtensor_atom_ids_local.reshape(n_shards, max_size_r_set_per_shard) + + # Add per-shard offset to convert local -> global + # atom_offset_per_shard: [n_shards] -> [n_shards, 1] for broadcasting + dtensor_atom_ids_global_per_shard = dtensor_atom_ids_local_per_shard + atom_offset_per_shard.unsqueeze(-1) + + # Flatten and select valid rows + dtensor_global_indices = dtensor_atom_ids_global_per_shard.flatten()[ + dtensor_valid_mask + ] # [size_r_set_valid_dtensor] + + # Sort both and compare (order may differ due to token-aligned sharding) + ref_sorted = ref_global_indices.sort().values.to(torch.int64) + dtensor_sorted = dtensor_global_indices.sort().values.to(torch.int64) + + if not torch.equal(ref_sorted, dtensor_sorted): + # Find first mismatch for debugging + mismatch_mask = ref_sorted != dtensor_sorted + first_mismatch = mismatch_mask.nonzero(as_tuple=True)[0][0].item() if mismatch_mask.any() else -1 + raise AssertionError( + f"r_set_to_rep_atom batch {b}: global atom indices mismatch.\n" + f"First mismatch at sorted index {first_mismatch}\n" + f"Ref (sorted)[:10]: {ref_sorted[:10].tolist()}\n" + f"DTensor (sorted)[:10]: {dtensor_sorted[:10].tolist()}" + ) + + +def compare_token_to_rep_atom( + ref_tensor: torch.Tensor, + dtensor_full: torch.Tensor, + atom_pad_mask_ref: torch.Tensor, + token_pad_mask_ref: torch.Tensor, + token_pad_mask_dtensor: torch.Tensor, + n_shards: int, + atom_counts_per_token: torch.Tensor, +) -> None: + """Compare serial token_to_rep_atom against DTensor block-diagonal version. + + In the serial pipeline, token_to_rep_atom is (N_tokens, N_atoms) where entry [t, a] + is nonzero iff atom a is the representative atom for token t. In the distributed + pipeline, it is stored as diagonal blocks: shard i holds the slice for its token and + atom ranges, padded to uniform per-shard dimensions. + + V2 adds two padding layers that affect this comparison: + 1) featurizer padding (per-sample token count rounded up to CP divisibility) + 2) collation padding (batched per-shard token count rounded up across samples/ranks) + + This helper handles those layers directly, so callers can pass unmodified serial + tensors. In particular, serial row boundaries are clipped against cumsum bounds and + atom-empty shards are skipped. + + Parameters + ---------- + ref_tensor : torch.Tensor + Serial token_to_rep_atom. Shape: (B, N_tokens_serial, N_atoms_serial). + dtensor_full : torch.Tensor + DTensor token_to_rep_atom after full_tensor(). + Shape: (B, n_shards * N_tps_collated, max_atoms_per_shard_collated). + atom_pad_mask_ref : torch.Tensor + Serial atom padding mask. Shape: (B, N_atoms_serial). + token_pad_mask_ref : torch.Tensor + Serial token padding mask. Shape: (B, N_tokens_serial). + token_pad_mask_dtensor : torch.Tensor + DTensor token padding mask (after full_tensor). + Shape: (B, n_shards * N_tps_collated). + n_shards : int + Number of CP shards along mesh dim 0 (the token-sharding axis). + atom_counts_per_token : torch.Tensor + Serial atom counts per token. Shape: (B, N_tokens_serial). + + """ + batch_size = ref_tensor.shape[0] + for b in range(batch_size): + ref_b = ref_tensor[b] + dtensor_b = dtensor_full[b] + token_mask_ref = token_pad_mask_ref[b].bool() + token_mask_dt = token_pad_mask_dtensor[b].bool() + atom_mask_ref = atom_pad_mask_ref[b].bool() + + N_tps_collated = dtensor_b.shape[0] // n_shards + + N_tps_featurizer = int(token_mask_dt[0:N_tps_collated].sum().item()) + + cumsum = torch.cat( + [torch.tensor([0], device=atom_counts_per_token.device), atom_counts_per_token[b].cumsum(dim=0)] + ) + + for shard_i in range(n_shards): + ref_row_start = shard_i * N_tps_featurizer + ref_row_end = (shard_i + 1) * N_tps_featurizer + ref_row_start_clipped = min(ref_row_start, cumsum.shape[0] - 1) + ref_row_end_clipped = min(ref_row_end, cumsum.shape[0] - 1) + ref_atom_start = int(cumsum[ref_row_start_clipped].item()) + ref_atom_end = int(cumsum[ref_row_end_clipped].item()) + n_atoms_in_shard = ref_atom_end - ref_atom_start + + if n_atoms_in_shard == 0: + continue + + ref_block = ref_b[ref_row_start:ref_row_end, ref_atom_start:ref_atom_end] + + ref_tok_mask = token_mask_ref[ref_row_start:ref_row_end] + ref_atom_mask = atom_mask_ref[ref_atom_start:ref_atom_end] + ref_unpadded = ref_block[ref_tok_mask][:, ref_atom_mask] + + dt_row_start = shard_i * N_tps_collated + dt_row_end = (shard_i + 1) * N_tps_collated + dt_block = dtensor_b[dt_row_start:dt_row_end, :n_atoms_in_shard] + + dt_tok_mask = token_mask_dt[dt_row_start:dt_row_end] + dt_unpadded = dt_block[dt_tok_mask][:, ref_atom_mask] + + torch.testing.assert_close( + dt_unpadded, + ref_unpadded, + atol=0, + rtol=0, + msg=f"token_to_rep_atom mismatch at batch {b}, shard {shard_i}", + ) + + +def _compare_features( + common_keys, + skip_keys, + data_batch_serial, + data_batch_dtensor, + dp_idx_str, + dp_idx_end, + n_shards, + rank, + manager, +): + """Compare serial and DTensor features across all common keys. + + Extracts masks, applies per-feature-type masking, and compares values. + Collects all errors and reports them together. Guards against vacuous + comparisons by requiring at least one feature with non-trivial data. + """ + atom_pad_mask_dtensor_full = data_batch_dtensor["atom_pad_mask"].full_tensor()[dp_idx_str:dp_idx_end].bool() + token_pad_mask_dtensor_full = data_batch_dtensor["token_pad_mask"].full_tensor()[dp_idx_str:dp_idx_end].bool() + msa_mask_dtensor_full = data_batch_dtensor["msa_mask"].full_tensor()[dp_idx_str:dp_idx_end].bool() + atom_pad_mask_serial = data_batch_serial["atom_pad_mask"].bool() + token_pad_mask_serial = data_batch_serial["token_pad_mask"].bool() + msa_mask_serial = data_batch_serial["msa_mask"].bool() + + token_pad_pair_mask_dtensor_full = token_pad_mask_dtensor_full[:, :, None] * token_pad_mask_dtensor_full[:, None, :] + token_pad_pair_mask_serial = token_pad_mask_serial[:, :, None] * token_pad_mask_serial[:, None, :] + + errors = [] + any_nonempty_data = False + + for key in common_keys: + if key in skip_keys: + continue + + feature_serial = data_batch_serial[key] + feature_dtensor = data_batch_dtensor[key] + if not isinstance(feature_serial, torch.Tensor) or not isinstance(feature_dtensor, DTensor): + continue + + if key == "r_set_to_rep_atom": + compare_r_set_to_rep_atom( + ref_tensor=feature_serial, + dtensor_full=feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end], + atom_pad_mask_dtensor=atom_pad_mask_dtensor_full, + n_shards=n_shards, + ) + any_nonempty_data = True + continue + elif key == "token_to_rep_atom": + compare_token_to_rep_atom( + ref_tensor=feature_serial, + dtensor_full=feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end], + atom_pad_mask_ref=atom_pad_mask_serial, + token_pad_mask_ref=token_pad_mask_serial, + token_pad_mask_dtensor=token_pad_mask_dtensor_full, + n_shards=n_shards, + atom_counts_per_token=data_batch_serial["atom_counts_per_token"], + ) + any_nonempty_data = True + continue + + feature_dtensor_full = None + feature_serial_full = None + try: + if key in ENSEMBLE_ATOM_FEATURES: + # (B, E, A, ...) — expand atom mask with ensemble dim + dt_full = feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end] + mask_dt = atom_pad_mask_dtensor_full.unsqueeze(1).expand(-1, dt_full.shape[1], -1) + mask_serial = atom_pad_mask_serial.unsqueeze(1).expand(-1, feature_serial.shape[1], -1) + feature_dtensor_full = dt_full[mask_dt] + feature_serial_full = feature_serial[mask_serial] + elif key in ENSEMBLE_TOKEN_FEATURES: + # (B, E, T, ...) — expand token mask with ensemble dim + dt_full = feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end] + mask_dt = token_pad_mask_dtensor_full.unsqueeze(1).expand(-1, dt_full.shape[1], -1) + mask_serial = token_pad_mask_serial.unsqueeze(1).expand(-1, feature_serial.shape[1], -1) + feature_dtensor_full = dt_full[mask_dt] + feature_serial_full = feature_serial[mask_serial] + elif key in TOKEN_PAIR_FEATURES: + feature_dtensor_full = feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end][ + token_pad_pair_mask_dtensor_full + ] + feature_serial_full = feature_serial[token_pad_pair_mask_serial] + elif key in MSA_FEATURES: + feature_dtensor_full = feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end][msa_mask_dtensor_full] + feature_serial_full = feature_serial[msa_mask_serial] + elif key == "frame_resolved_mask": + dt_full = feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end] + if dt_full.ndim != 3: + raise ValueError( + f"frame_resolved_mask must be ensemble-aware with ndim=3 (B, E, T), got {dt_full.ndim}" + ) + mask_dt = token_pad_mask_dtensor_full.unsqueeze(1).expand(-1, dt_full.shape[1], -1) + mask_serial = token_pad_mask_serial.unsqueeze(1).expand(-1, feature_serial.shape[1], -1) + feature_dtensor_full = dt_full[mask_dt] + feature_serial_full = feature_serial[mask_serial] + elif key == "frames_idx": + frame_resolved_mask_dtensor_full = data_batch_dtensor["frame_resolved_mask"].full_tensor()[ + dp_idx_str:dp_idx_end + ] + frames_idx_dtensor_unpadded = _map_padded_frames_idx_to_unpadded( + feature_dtensor, + dp_idx_str, + dp_idx_end, + data_batch_dtensor["atom_to_token"], + data_batch_dtensor["frame_resolved_mask"], + ) + _compare_frames_idx_valid_tokens( + feature_serial, + frames_idx_dtensor_unpadded, + data_batch_serial["frame_resolved_mask"], + frame_resolved_mask_dtensor_full, + token_pad_mask_serial, + token_pad_mask_dtensor_full, + ) + continue + elif key in ATOM_FEATURES: + feature_dtensor_full = feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end][atom_pad_mask_dtensor_full] + feature_serial_full = feature_serial[atom_pad_mask_serial] + else: + feature_dtensor_full = feature_dtensor.full_tensor()[dp_idx_str:dp_idx_end][token_pad_mask_dtensor_full] + feature_serial_full = feature_serial[token_pad_mask_serial] + + if feature_dtensor_full.numel() > 0: + if feature_dtensor_full.is_floating_point(): + if feature_dtensor_full.any(): + any_nonempty_data = True + else: + if feature_dtensor_full.unique().numel() > 1: + any_nonempty_data = True + + torch.testing.assert_close( + feature_dtensor_full, + feature_serial_full, + atol=0, + rtol=0, + ) + except Exception as e: + dt_shape = feature_dtensor_full.shape if feature_dtensor_full is not None else "N/A" + serial_shape = feature_serial_full.shape if feature_serial_full is not None else "N/A" + errors.append( + f"[{key}] rank {rank} cp_rank {manager.group_rank['cp']}: " + f"dtensor shape {dt_shape} vs serial shape {serial_shape} " + f"| {e}" + ) + + if errors: + raise AssertionError(f"Shape/value mismatches on rank {rank} ({len(errors)} failures):\n" + "\n".join(errors)) + + assert any_nonempty_data, ( + f"All compared features were empty or zero on rank {rank} -- " + f"likely a padding/masking bug producing vacuous comparisons" + ) + + +def parallel_assert_cp_inference_dataloader( + rank: int, + processed, + canonical_mols_dir: Path, + local_batch_size: int, + device_type, + backend, + grid_group_sizes: Dict[str, int], + env_map: Optional[dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + # TODO VI: hardcode device_type and backend for now. Do we need GPU? + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + # create a CPU process groups for comm python objects + DistributedManager.create_group( + "world_cpu", manager.group_ranks["world"], backend="gloo", use_local_synchronization=True + ) + DistributedManager.create_group("cp_cpu", manager.group_ranks["cp"], backend="gloo", use_local_synchronization=True) + + device_mesh = manager.device_mesh_subgroups + cp_device_mesh = map_subgroup_mesh_to_cpu(manager) + assert device_mesh.shape == cp_device_mesh.shape + dp_rank = device_mesh.get_local_rank(0) + + # Serial dataloader + seed_by_rank(0, seed=42) + data_module_serial = Boltz2InferenceDataModule( + manifest=processed.manifest, + target_dir=processed.targets_dir, + msa_dir=processed.msa_dir, + mol_dir=canonical_mols_dir, + num_workers=0, + ) + dataset = PredictionDataset( + manifest=processed.manifest, + target_dir=processed.targets_dir, + msa_dir=processed.msa_dir, + mol_dir=canonical_mols_dir, + ) + sampler = DistributedSampler( + dataset, + num_replicas=device_mesh.shape[0], + rank=device_mesh.get_local_rank(0), + shuffle=False, + drop_last=False, + ) + dataloader_serial = DataLoader( + dataset, + sampler=sampler, + batch_size=local_batch_size, + num_workers=0, + shuffle=False, + collate_fn=collate, + ) + data_batch_list_serial = list(dataloader_serial) + + # DTensor CP dataloader + seed_by_rank(0, seed=42) + data_module_dtensor = BoltzInferenceDataModuleDTensor( + manifest=processed.manifest, + target_dir=processed.targets_dir, + msa_dir=processed.msa_dir, + mol_dir=canonical_mols_dir, + num_workers=0, + device_mesh=device_mesh, + device_mesh_cpu=cp_device_mesh, + local_batch_size=local_batch_size, + pair_mask_mode=PairMaskMode.SEQUENCE_LOCAL_ATTENTION, + ) + data_module_dtensor.setup("predict") + dataloader_dtensor = data_module_dtensor.predict_dataloader() + data_batch_list_dtensor = list(dataloader_dtensor) + + skip_keys = { + "atom_pad_mask", + "token_pad_mask", + "msa_mask", + "token_pair_mask", + "atom_to_token", + } | NON_SHARDED_FEATURES_V2 + + n_shards = device_mesh.get_group("cp_axis_0").size() + + for data_batch_serial, data_batch_dtensor in zip(data_batch_list_serial, data_batch_list_dtensor): + serial_keys = set(data_batch_serial.keys()) + dtensor_keys = set(data_batch_dtensor.keys()) + assert (serial_keys - INFERENCE_FEATURES_DIFFERENCE) == (dtensor_keys - INFERENCE_FEATURES_DIFFERENCE), ( + f"Feature keys are different: {serial_keys - INFERENCE_FEATURES_DIFFERENCE} " + f"!= {dtensor_keys - INFERENCE_FEATURES_DIFFERENCE}" + ) + common_keys = sorted(serial_keys & dtensor_keys) + + data_batch_serial = data_module_serial.transfer_batch_to_device(data_batch_serial, manager.device, 0) + data_batch_dtensor = data_module_dtensor.transfer_batch_to_device(data_batch_dtensor, manager.device, 0) + + dp_idx_str = dp_rank * local_batch_size + dp_idx_end = dp_idx_str + local_batch_size + + _compare_features( + common_keys=common_keys, + skip_keys=skip_keys, + data_batch_serial=data_batch_serial, + data_batch_dtensor=data_batch_dtensor, + dp_idx_str=dp_idx_str, + dp_idx_end=dp_idx_end, + n_shards=n_shards, + rank=rank, + manager=manager, + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}", +) +def test_dtensor_cp_inference_dataloader( + setup_env: tuple[dict, int, str, str, str, dict[str, str]], + test_cp_training_base_data_dir_boltz2: Path, + canonical_mols_dir: Path, + tmp_path: Path, + local_batch_size: int = 2, +): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + # Create combined dataset with multiple samples to support local_batch_size > 1 + names = ["7ylz", "7z64", "8ayv", "8b2e"] + combined_dataset_path = concat_data( + tmp_path / "processed_combined", + *[test_cp_training_base_data_dir_boltz2 / f"processed_{name}" for name in names], + ) + + # Check if we have enough samples for the requested batch size + total_samples = len(names) + if total_samples < local_batch_size * grid_group_sizes["dp"]: + pytest.skip( + f"Not enough samples ({total_samples}) for local_batch_size * dp_size ({local_batch_size * grid_group_sizes['dp']})" + ) + + # Create a new BoltzProcessedInput with the combined dataset + combined_processed_handle = BoltzProcessedInput( + manifest=Manifest.load(combined_dataset_path / "manifest.json"), + targets_dir=combined_dataset_path / "structures", + msa_dir=combined_dataset_path / "msa", + template_dir=None, + extra_mols_dir=None, + ) + + spawn_multiprocessing( + parallel_assert_cp_inference_dataloader, + world_size, + combined_processed_handle, + canonical_mols_dir, + local_batch_size, + device_type, + backend, + grid_group_sizes, + env_per_rank, + ) + + +def parallel_assert_atom_to_token_sharding( + rank: int, + processed, + canonical_mols_dir: Path, + local_batch_size: int, + device_type, + backend, + grid_group_sizes: Dict[str, int], + env_map: Optional[dict[str, str]] = None, +): + """Test that sharded atom_to_token operations produce the same results as non-sharded after removing padding.""" + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + # Initialize distributed environment + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + DistributedManager.create_group( + "world_cpu", manager.group_ranks["world"], backend="gloo", use_local_synchronization=True + ) + DistributedManager.create_group("cp_cpu", manager.group_ranks["cp"], backend="gloo", use_local_synchronization=True) + + device_mesh = manager.device_mesh_subgroups + cp_device_mesh = map_subgroup_mesh_to_cpu(manager) + dp_rank = device_mesh.get_local_rank(0) + cp_rank = manager.group_rank["cp"] + + # Set up data modules + seed_by_rank(0, seed=42) + data_module_serial = Boltz2InferenceDataModule( + manifest=processed.manifest, + target_dir=processed.targets_dir, + msa_dir=processed.msa_dir, + mol_dir=canonical_mols_dir, + num_workers=0, + ) + dataset = PredictionDataset( + manifest=processed.manifest, + target_dir=processed.targets_dir, + msa_dir=processed.msa_dir, + mol_dir=canonical_mols_dir, + ) + + data_module_dtensor = BoltzInferenceDataModuleDTensor( + manifest=processed.manifest, + target_dir=processed.targets_dir, + msa_dir=processed.msa_dir, + mol_dir=canonical_mols_dir, + num_workers=0, + device_mesh=device_mesh, + device_mesh_cpu=cp_device_mesh, + local_batch_size=local_batch_size, + ) + data_module_dtensor.setup("predict") + + # Get data loaders + sampler = DistributedSampler( + dataset, + num_replicas=device_mesh.shape[0], + rank=device_mesh.get_local_rank(0), + shuffle=False, + drop_last=False, + ) + dataloader_serial = DataLoader( + dataset, + sampler=sampler, + batch_size=local_batch_size, + num_workers=0, + shuffle=False, + collate_fn=collate, + ) + dataloader_dtensor = data_module_dtensor.predict_dataloader() + + # Process batches + for data_batch_serial, data_batch_dtensor in zip(dataloader_serial, dataloader_dtensor): + data_batch_serial = data_module_serial.transfer_batch_to_device(data_batch_serial, manager.device, 0) + data_batch_dtensor = data_module_dtensor.transfer_batch_to_device(data_batch_dtensor, manager.device, 0) + + # Get masks for removing padding + dp_idx_str = dp_rank * local_batch_size + dp_idx_end = dp_idx_str + local_batch_size + + atom_pad_mask_dtensor_full = data_batch_dtensor["atom_pad_mask"].full_tensor()[dp_idx_str:dp_idx_end].bool() + token_pad_mask_dtensor_full = data_batch_dtensor["token_pad_mask"].full_tensor()[dp_idx_str:dp_idx_end].bool() + atom_pad_mask_serial = data_batch_serial["atom_pad_mask"].bool() + token_pad_mask_serial = data_batch_serial["token_pad_mask"].bool() + + # Get atom_to_token matrices + atom_to_token_serial = data_batch_serial["atom_to_token"] + atom_to_token_dtensor = data_batch_dtensor["atom_to_token"] + + # Test 1: Single representation token-to-atom + token_single_repr_serial = data_batch_serial["residue_index"].float() + token_single_repr_dtensor = data_batch_dtensor["residue_index"].float() + + atom_single_repr_serial = torch.einsum( + "bij,bj...->bi...", atom_to_token_serial.float(), token_single_repr_serial + ) + atom_single_repr_dtensor = single_repr_token_to_atom(token_single_repr_dtensor, atom_to_token_dtensor) + + # Remove padding and compare + for local_sample_idx in range(local_batch_size): + atom_serial_sample = atom_single_repr_serial[local_sample_idx] + atom_serial_sample = atom_serial_sample[atom_pad_mask_serial[local_sample_idx]] + + atom_dtensor_full = atom_single_repr_dtensor.full_tensor()[dp_idx_str:dp_idx_end] + atom_dtensor_sample = atom_dtensor_full[local_sample_idx].float() + atom_dtensor_sample = atom_dtensor_sample[atom_pad_mask_dtensor_full[local_sample_idx]] + + torch.testing.assert_close( + atom_dtensor_sample, + atom_serial_sample, + atol=1e-6, + rtol=1e-6, + msg=f"Single representation token-to-atom mismatch on rank {cp_rank}, sample {local_sample_idx}", + ) + + # Test 2: Single representation atom-to-token + atom_single_repr_serial = data_batch_serial["ref_charge"].float() + atom_single_repr_dtensor = data_batch_dtensor["ref_charge"] + + # Non-sharded operation (with normalization) + atom_to_token_normalized = atom_to_token_serial.float() / (atom_to_token_serial.sum(dim=1, keepdim=True) + 1e-6) + token_single_repr_serial = torch.einsum("bji,bj...->bi...", atom_to_token_normalized, atom_single_repr_serial) + token_single_repr_dtensor = single_repr_atom_to_token(atom_single_repr_dtensor.float(), atom_to_token_dtensor) + + # Remove padding and compare + for local_sample_idx in range(local_batch_size): + token_serial_sample = token_single_repr_serial[local_sample_idx] + token_serial_sample = token_serial_sample[token_pad_mask_serial[local_sample_idx]] + + token_dtensor_full = token_single_repr_dtensor.full_tensor()[dp_idx_str:dp_idx_end] + token_dtensor_sample = token_dtensor_full[local_sample_idx].float() + token_dtensor_sample = token_dtensor_sample[token_pad_mask_dtensor_full[local_sample_idx]] + + torch.testing.assert_close( + token_dtensor_sample, + token_serial_sample, + atol=1e-6, + rtol=1e-6, + msg=f"Single representation atom-to-token mismatch on rank {cp_rank}, sample {local_sample_idx}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [((1, (2, 2)), True, "cuda", "ENV")], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_atom_to_token_sharding_consistency( + setup_env: tuple[dict, int, str, str, str, dict[str, str]], + test_cp_training_base_data_dir_boltz2: Path, + canonical_mols_dir: Path, + tmp_path: Path, + local_batch_size: int = 2, +): + """Test that sharded atom_to_token operations produce the same results as non-sharded after removing padding. + + This test verifies the sharding strategy described in atom_to_token.py where: + 1. Padding atoms/tokens are added to make dimensions divisible by cp_size + 2. The atom_to_token matrix becomes block diagonal with only diagonal blocks non-zero + 3. Row-wise broadcasting enables local computation on all ranks + 4. After removing padding, sharded and non-sharded results should be identical + + The test covers all three atom_to_token operations: + - Single representation token-to-atom: residue_index -> atom representation (torch.einsum) + - Single representation atom-to-token: ref_charge -> token representation (torch.einsum with normalization) + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + # Create combined dataset with multiple samples to support local_batch_size > 1 + names = ["7ylz", "7z64", "8ayv", "8b2e"] + combined_dataset_path = concat_data( + tmp_path / "processed_combined", + *[test_cp_training_base_data_dir_boltz2 / f"processed_{name}" for name in names], + ) + + # Check if we have enough samples for the requested batch size + total_samples = len(names) + if total_samples < local_batch_size: + pytest.skip( + f"Not enough samples ({total_samples}) for local_batch_size * dp_size ({local_batch_size * grid_group_sizes['dp']})" + ) + + # Create a new BoltzProcessedInput with the combined dataset + combined_processed_handle = BoltzProcessedInput( + manifest=Manifest.load(combined_dataset_path / "manifest.json"), + targets_dir=combined_dataset_path / "structures", + msa_dir=combined_dataset_path / "msa", + template_dir=None, + extra_mols_dir=None, + ) + + spawn_multiprocessing( + parallel_assert_atom_to_token_sharding, + world_size, + combined_processed_handle, + canonical_mols_dir, + local_batch_size, + device_type, + backend, + grid_group_sizes, + env_per_rank, + ) + + +def parallel_assert_cp_training_dataloader( + rank: int, + training_data_dir: Path, + canonical_mols_dir: Path, + device_type, + backend, + grid_group_sizes: Dict[str, int], + env_map: Optional[dict[str, str]] = None, + dataloader_kind: str = "train", +): + local_batch_size = 1 + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + DistributedManager.create_group( + "world_cpu", manager.group_ranks["world"], backend="gloo", use_local_synchronization=True + ) + DistributedManager.create_group("cp_cpu", manager.group_ranks["cp"], backend="gloo", use_local_synchronization=True) + + device_mesh = manager.device_mesh_subgroups + cp_device_mesh = map_subgroup_mesh_to_cpu(manager) + dp_rank = device_mesh.get_local_rank(0) + + # Serial reference dataloader + cfg_serial = setup_mock_training_datamodule_config(training_data_dir) + cfg_serial.batch_size = local_batch_size + cfg_serial.moldir = str(canonical_mols_dir) + if dataloader_kind == "val": + cfg_serial.overfit = 4 + else: + cfg_serial.samples_per_epoch = 2 * local_batch_size * device_mesh.shape[0] + seed_by_rank(0, seed=42) + data_module_serial = Boltz2TrainingDataModule(cfg=cfg_serial) + + if dataloader_kind == "val": + serial_set = data_module_serial._val_set + serial_batch_size = cfg_serial.val_batch_size + else: + serial_set = data_module_serial._train_set + serial_batch_size = cfg_serial.batch_size + + sampler_serial = DistributedSampler( + serial_set, + num_replicas=device_mesh.shape[0], + rank=device_mesh.get_local_rank(0), + shuffle=False, + drop_last=False, + ) + dataloader_serial = DataLoader( + serial_set, + sampler=sampler_serial, + batch_size=serial_batch_size, + num_workers=0, + pin_memory=False, + collate_fn=collate_training, + ) + data_batch_list_serial = list(dataloader_serial) + + # DTensor distributed dataloader + cfg_dtensor = setup_mock_training_datamodule_config(training_data_dir) + cfg_dtensor.batch_size = local_batch_size + cfg_dtensor.moldir = str(canonical_mols_dir) + if dataloader_kind == "val": + cfg_dtensor.overfit = 4 + else: + cfg_dtensor.samples_per_epoch = cfg_serial.samples_per_epoch + seed_by_rank(0, seed=42) + data_module_dtensor = BoltzTrainingDataModuleDTensor( + cfg=cfg_dtensor, + device_mesh=device_mesh, + device_mesh_cpu=cp_device_mesh, + ) + if dataloader_kind == "val": + dataloader_dtensor = data_module_dtensor.val_dataloader() + else: + dataloader_dtensor = data_module_dtensor.train_dataloader() + data_batch_list_dtensor = list(dataloader_dtensor) + + skip_keys = ( + { + "atom_pad_mask", + "token_pad_mask", + "msa_mask", + "atom_to_token", + "pair_mask", + } + | TRAINING_FEATURES_DIFFERENCE + | NON_SHARDED_FEATURES_V2 + ) + + n_shards = device_mesh.shape[1] + + for data_batch_serial, data_batch_dtensor in zip(data_batch_list_serial, data_batch_list_dtensor): + serial_keys = set(data_batch_serial.keys()) + dtensor_keys = set(data_batch_dtensor.keys()) + serial_keys_filtered = serial_keys - TRAINING_FEATURES_DIFFERENCE + dtensor_keys_filtered = dtensor_keys - TRAINING_FEATURES_DIFFERENCE + if serial_keys_filtered != dtensor_keys_filtered: + missing_in_dtensor = sorted(serial_keys_filtered - dtensor_keys_filtered) + extra_in_dtensor = sorted(dtensor_keys_filtered - serial_keys_filtered) + raise AssertionError( + "Feature keys are different between serial and DTensor training pipelines. " + f"rank={rank}, cp_rank={manager.group_rank['cp']}, " + f"serial_pdb_id={data_batch_serial.get('pdb_id')}, " + f"dtensor_pdb_id={data_batch_dtensor.get('pdb_id')}, " + f"missing_in_dtensor={missing_in_dtensor}, " + f"extra_in_dtensor={extra_in_dtensor}" + ) + common_keys = sorted(serial_keys & dtensor_keys) + + data_batch_serial = data_module_serial.transfer_batch_to_device(data_batch_serial, manager.device, 0) + data_batch_dtensor = data_module_dtensor.transfer_batch_to_device(data_batch_dtensor, manager.device, 0) + + dp_idx_str = dp_rank * local_batch_size + dp_idx_end = dp_idx_str + local_batch_size + + _compare_features( + common_keys=common_keys, + skip_keys=skip_keys, + data_batch_serial=data_batch_serial, + data_batch_dtensor=data_batch_dtensor, + dp_idx_str=dp_idx_str, + dp_idx_end=dp_idx_end, + n_shards=n_shards, + rank=rank, + manager=manager, + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (3, 3)), True, "cpu", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}", +) +@pytest.mark.parametrize("dataloader_kind", ["train", "val"]) +def test_dtensor_cp_training_dataloader( + setup_env: tuple[dict, int, str, str, str, dict[str, str]], + test_cp_training_data_dir_boltz2: Path, + canonical_mols_dir: Path, + dataloader_kind: str, +): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + spawn_multiprocessing( + parallel_assert_cp_training_dataloader, + world_size, + test_cp_training_data_dir_boltz2, + canonical_mols_dir, + device_type, + backend, + grid_group_sizes, + env_per_rank, + dataloader_kind, + ) diff --git a/tests/distributed/test_dtensor_layernorm.py b/tests/distributed/test_dtensor_layernorm.py new file mode 100755 index 000000000..5d617ca1d --- /dev/null +++ b/tests/distributed/test_dtensor_layernorm.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import itertools +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_module, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.layernorm import LayerNormParamsReplicated +from boltz.testing.utils import assert_all_identical, seed_by_rank + + +def compute_global_expectation(batch_size, seq_len, features: int | list[int], device): + """Compute expected results using global tensors. + + Args: + batch_size: Batch size + seq_len: Sequence length + features: Feature dimensions + device: Device to place tensors on + + Returns: + tuple: (x_global, layernorm_state_dict, y_global_expected, + x_global_grad, weight_global_grad, bias_global_grad, dy_global) + """ + # Create global tensors with deterministic values for reproducibility + if isinstance(features, int): + features = [features] + x_global = torch.randn(batch_size, seq_len, *features, device=device, requires_grad=True) + layernorm = torch.nn.LayerNorm(features, device=device) + state_dict_layernorm = layernorm.state_dict() + + # Clone inputs for distribution + x_global_clone = x_global.detach().clone() + + # Compute on global tensors using standard layernorm operation + y_global_expected = layernorm(x_global) + + # Create gradients for backward pass + dy_global = torch.rand_like(y_global_expected) + + # Backward pass on global tensors + y_global_expected.backward(dy_global) + + return ( + x_global_clone, + state_dict_layernorm, + y_global_expected.detach().clone(), + x_global.grad.detach().clone(), + layernorm.weight.grad.detach().clone(), + layernorm.bias.grad.detach().clone(), + dy_global.detach().clone(), + ) + + +def parallel_assert_dtensor_layernorm( + rank: int, + batch_size: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + """Test distributed layernorm operation in a parallel environment. + + This test validates that the LayerNormParamsReplicated produces identical results to + standard nn.functional.layer_norm operations with global tensors. It verifies: + + 1. Forward pass produces the same results as global tensor computation + 2. Backward pass correctly propagates gradients through the distributed operation + 3. Results and gradients match the equivalent global tensor operations + + Args: + rank: The process rank in the distributed environment + grid_group_sizes: Dictionary mapping group names to their sizes for distributed setup + device_type: Device to run the test on ("cpu" or "cuda") + backend: The distributed backend to use (e.g., "gloo", "nccl") + env_map: Optional dictionary of environment variables to set before initialization + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + seq_len_per_rank = 4 + seq_len_global = size_ring * seq_len_per_rank + features = [3, 2] + + # Set random seed based on rank for reproducibility + seed_by_rank(0) + + # Compute global expectations + ( + x_global, + state_dict_layernorm, + y_global_expected, + x_global_grad, + weight_global_grad, + bias_global_grad, + dy_global, + ) = compute_global_expectation(batch_size, seq_len_global, features, manager.device) + + # Create distributed tensors + # Shard the sequence dimension (dim=1) for input tensor + # this emulates the sharded single representation in the Boltz model + input_placements = [Shard(dim=0), Shard(dim=1), Replicate()] + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad = True + + layernorm_local = torch.nn.LayerNorm(features, device=manager.device) + layernorm_local.load_state_dict(state_dict_layernorm) + layer = LayerNormParamsReplicated(layernorm_local, manager.device_mesh_subgroups) + + layernorm_local_copy = torch.nn.LayerNorm(features, device=manager.device) + layernorm_local_copy.load_state_dict(state_dict_layernorm) + layer_dtensor_native = distribute_module(layernorm_local_copy, manager.device_mesh_subgroups) + x_dtensor_native = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensors using LayerNormParamsReplicated + y_dtensor_result = layer(x_dtensor) + y_dtensor_native = layer_dtensor_native(x_dtensor_native) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global, manager.device_mesh_subgroups, input_placements) + dy_dtensor_native = distribute_tensor(dy_global.detach().clone(), manager.device_mesh_subgroups, input_placements) + + # Perform backward pass + y_dtensor_result.backward(dy_dtensor) + y_dtensor_native.backward(dy_dtensor_native) + + # Create distributed tensors from global gradients for comparison + x_grad_dtensor_expected = distribute_tensor(x_global_grad, manager.device_mesh_subgroups, input_placements) + y_dtensor_expected = distribute_tensor(y_global_expected, manager.device_mesh_subgroups, input_placements) + weight_grad_dtensor_expected = distribute_tensor( + weight_global_grad, manager.device_mesh_subgroups, layer.weight.placements + ) + bias_grad_dtensor_expected = distribute_tensor( + bias_global_grad, manager.device_mesh_subgroups, layer.bias.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close(y_dtensor_expected, y_dtensor_result) + torch.testing.assert_close(x_grad_dtensor_expected, x_dtensor.grad) + torch.testing.assert_close(weight_grad_dtensor_expected, layer.weight.grad) + torch.testing.assert_close(bias_grad_dtensor_expected, layer.bias.grad) + + # Compare results with native DTensor implementation + assert y_dtensor_result.shape == y_dtensor_native.shape + assert y_dtensor_result.stride() == y_dtensor_native.stride() + assert x_dtensor.grad.shape == x_grad_dtensor_expected.shape + assert x_dtensor.grad.stride() == x_grad_dtensor_expected.stride() + assert layer.weight.grad.shape == weight_grad_dtensor_expected.shape + assert layer.weight.grad.stride() == weight_grad_dtensor_expected.stride() + assert layer.bias.grad.shape == bias_grad_dtensor_expected.shape + assert layer.bias.grad.stride() == bias_grad_dtensor_expected.stride() + + torch.testing.assert_close(y_dtensor_native, y_dtensor_result) + torch.testing.assert_close(x_dtensor_native.grad, x_dtensor.grad) + torch.testing.assert_close(layer_dtensor_native.weight.grad, layer.weight.grad) + torch.testing.assert_close(layer_dtensor_native.bias.grad, layer.bias.grad) + + # Collect results as global tensors and compare with original global tensors + y_global_result = y_dtensor_result.full_tensor() + x_grad_global_result = x_dtensor.grad.full_tensor() + weight_grad_global_result = layer.weight.grad.full_tensor() + bias_grad_global_result = layer.bias.grad.full_tensor() + + # Assert output and input gradients match the global computation + torch.testing.assert_close(y_global_result, y_global_expected) + torch.testing.assert_close(x_grad_global_result, x_global_grad) + torch.testing.assert_close(weight_grad_global_result, weight_global_grad) + torch.testing.assert_close(bias_grad_global_result, bias_global_grad) + + # assert the parameter gradients are identical across all ranks + assert_all_identical(weight_grad_global_result, manager.group["cp"]) + assert_all_identical(bias_grad_global_result, manager.group["cp"]) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_layernorm(setup_env: tuple[dict, int, str, str, str, dict[str, str]]): + """Test distributed layernorm operation across multiple processes. + + This parametrized test launches multiple processes to test the LayerNormParamsReplicated + with different configurations. It verifies that layernorm operations work correctly + in a distributed setting across various: + + - Data parallel (dp) and compute parallel (cp) group sizes + - Device types (CPU/CUDA) + - Initialization methods + + The test ensures operations on distributed tensors produce results identical + to equivalent operations on global tensors, validating the correctness of the + LayerNormParamsReplicated implementation. + + Args: + setup_env: Fixture providing the distributed environment configuration + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + batch_size = 2 * grid_group_sizes["dp"] + + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=parallel_assert_dtensor_layernorm, + args=( + batch_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ), + nprocs=world_size, + join=True, + ) + + +def parallel_assert_dtensor_layernorm_raise_uneven_sharding( + rank: int, + batch_size: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + seq_len_per_rank = 4 + seq_len_global = size_ring * seq_len_per_rank + features = [3, 2] + + x_global = torch.ones(batch_size, seq_len_global, *features, device=manager.device) + + # Create distributed tensors + # Shard the sequence dimension (dim=1) for input tensor + # this emulates the sharded single representation in the Boltz model + input_placements = [Shard(dim=0), Shard(dim=1), Replicate()] + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad = True + + layernorm_local = torch.nn.LayerNorm(features, device=manager.device) + layer = LayerNormParamsReplicated(layernorm_local, manager.device_mesh_subgroups) + + # should raise here + with pytest.raises( + ValueError, + match="Uneven sharding tensor dimension 0 of size 3 along device mesh dimension 0 of size 2 is not supported", + ): + _ = layer(x_dtensor) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(2, (1, 1))], [True], ["cpu"], ["ENV"]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_dtensor_layernorm_raise_on_uneven_sharding(setup_env: tuple[dict, int, str, str, str, dict[str, str]]): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + batch_size = 3 + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=parallel_assert_dtensor_layernorm_raise_uneven_sharding, + args=( + batch_size, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ), + nprocs=world_size, + join=True, + ) diff --git a/tests/distributed/test_dtensor_linear.py b/tests/distributed/test_dtensor_linear.py new file mode 100644 index 000000000..d44d4685c --- /dev/null +++ b/tests/distributed/test_dtensor_linear.py @@ -0,0 +1,429 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +from math import isqrt +from typing import Dict, Optional + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_module, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.linear import LinearParamsReplicated +from boltz.testing.utils import assert_all_identical, seed_by_rank + + +def compute_global_expectation(batch_size, seq_len, in_features, out_features, device): + """Compute expected results using global tensors. + + Args: + batch_size: Batch size + seq_len: Sequence length + in_features: Input feature dimension + out_features: Output feature dimension + device: Device to place tensors on + + Returns: + tuple: (x_global, weight_global, bias_global, y_global_expected, + x_global_grad, weight_global_grad, bias_global_grad, dy_global) + """ + # Create global tensors with deterministic values for reproducibility + x_global = torch.randn(batch_size, seq_len, in_features, device=device, requires_grad=True) + linear = torch.nn.Linear(in_features, out_features, device=device) + state_dict_linear = linear.state_dict() + + # Clone inputs for distribution + x_global_clone = x_global.detach().clone() + + # Compute on global tensors using standard linear operation + y_global_expected = linear(x_global) + + # Create gradients for backward pass + dy_global = torch.rand_like(y_global_expected) + + # Backward pass on global tensors + y_global_expected.backward(dy_global) + + return ( + x_global_clone, + state_dict_linear, + y_global_expected.detach().clone(), + x_global.grad.detach().clone(), + linear.weight.grad.detach().clone(), + linear.bias.grad.detach().clone(), + dy_global.detach().clone(), + ) + + +def parallel_assert_dtensor_linear( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + compare_to_native_dtensor: bool, + env_map: Optional[Dict[str, str]] = None, +): + """Test distributed linear operation in a parallel environment. + + This test validates that the LinearParamsReplicatedImpl produces identical results to + standard nn.functional.linear operations with global tensors. It verifies: + + 1. Forward pass produces the same results as global tensor computation + 2. Backward pass correctly propagates gradients through the distributed operation + 3. Results and gradients match the equivalent global tensor operations + + Args: + rank: The process rank in the distributed environment + grid_group_sizes: Dictionary mapping group names to their sizes for distributed setup + device_type: Device to run the test on ("cpu" or "cuda") + backend: The distributed backend to use (e.g., "gloo", "nccl") + env_map: Optional dictionary of environment variables to set before initialization + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + try: + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + # Set test parameters + batch_size = 2 + seq_len_per_rank = 4 + seq_len_global = size_ring * seq_len_per_rank + in_features = 3 + out_features = 5 + + # Set random seed based on rank for reproducibility + seed_by_rank(0) + + # Compute global expectations + ( + x_global, + state_dict_linear, + y_global_expected, + x_global_grad, + weight_global_grad, + bias_global_grad, + dy_global, + ) = compute_global_expectation(batch_size, seq_len_global, in_features, out_features, manager.device) + + # Create distributed tensors + # Shard the sequence dimension (dim=1) for input tensor + # this emulates the sharded single representation in the Boltz model + input_placements = [Shard(dim=0), Shard(dim=1), Replicate()] + x_dtensor = distribute_tensor(x_global, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad = True + + linear_local = torch.nn.Linear(in_features, out_features, device=manager.device) + linear_local.load_state_dict(state_dict_linear) + layer = LinearParamsReplicated(linear_local, manager.device_mesh_subgroups) + + linear_local_copy = torch.nn.Linear(in_features, out_features, device=manager.device) + linear_local_copy.load_state_dict(state_dict_linear) + layer_dtensor_native = distribute_module( + linear_local_copy, + manager.device_mesh_subgroups, + output_fn=lambda module, outputs, device_mesh: outputs.redistribute(device_mesh, input_placements), + ) + x_dtensor_native = x_dtensor.detach().clone().requires_grad_(True) + + # Compute on distributed tensors using LinearParamsReplicatedImpl + y_dtensor_result = layer(x_dtensor) + + # Distribute the upstream adjoint for backward pass + dy_dtensor = distribute_tensor(dy_global, manager.device_mesh_subgroups, input_placements) + + # Perform backward pass + y_dtensor_result.backward(dy_dtensor) + + y_dtensor_native = None + if compare_to_native_dtensor: + y_dtensor_native = layer_dtensor_native(x_dtensor_native) + dy_dtensor_native = distribute_tensor( + dy_global.detach().clone(), manager.device_mesh_subgroups, input_placements + ) + y_dtensor_native.backward(dy_dtensor_native) + + # Create distributed tensors from global gradients for comparison + x_grad_dtensor_expected = distribute_tensor(x_global_grad, manager.device_mesh_subgroups, input_placements) + y_dtensor_expected = distribute_tensor(y_global_expected, manager.device_mesh_subgroups, input_placements) + weight_grad_dtensor_expected = distribute_tensor( + weight_global_grad, manager.device_mesh_subgroups, layer.weight.data.placements + ) + bias_grad_dtensor_expected = distribute_tensor( + bias_global_grad, manager.device_mesh_subgroups, layer.bias.data.placements + ) + + # Compare results with expected local shards + torch.testing.assert_close(y_dtensor_expected, y_dtensor_result) + torch.testing.assert_close(x_grad_dtensor_expected, x_dtensor.grad) + torch.testing.assert_close(weight_grad_dtensor_expected, layer.weight.grad) + torch.testing.assert_close(bias_grad_dtensor_expected, layer.bias.grad) + + if compare_to_native_dtensor: + # Compare results with native DTensor implementation + assert y_dtensor_result.shape == y_dtensor_native.shape + assert y_dtensor_result.stride() == y_dtensor_native.stride() + assert x_dtensor.grad.shape == x_grad_dtensor_expected.shape + assert x_dtensor.grad.stride() == x_grad_dtensor_expected.stride() + assert layer.weight.grad.shape == weight_grad_dtensor_expected.shape + assert layer.weight.grad.stride() == weight_grad_dtensor_expected.stride() + assert layer.bias.grad.shape == bias_grad_dtensor_expected.shape + assert layer.bias.grad.stride() == bias_grad_dtensor_expected.stride() + + torch.testing.assert_close(y_dtensor_native, y_dtensor_result) + torch.testing.assert_close(x_dtensor_native.grad, x_dtensor.grad) + torch.testing.assert_close(layer_dtensor_native.weight.grad, layer.weight.grad) + torch.testing.assert_close(layer_dtensor_native.bias.grad, layer.bias.grad) + + # Collect results as global tensors and compare with original global tensors + y_global_result = y_dtensor_result.full_tensor() + x_grad_global_result = x_dtensor.grad.full_tensor() + weight_grad_global_result = layer.weight.grad.full_tensor() + bias_grad_global_result = layer.bias.grad.full_tensor() + + # Assert output and input gradients match the global computation + torch.testing.assert_close(y_global_result, y_global_expected) + torch.testing.assert_close(x_grad_global_result, x_global_grad) + torch.testing.assert_close(weight_grad_global_result, weight_global_grad) + torch.testing.assert_close(bias_grad_global_result, bias_global_grad) + + # assert the gradients are identical across all ranks + assert_all_identical(weight_grad_global_result, manager.group["cp"]) + assert_all_identical(bias_grad_global_result, manager.group["cp"]) + finally: + DistributedManager.cleanup() + monkeypatch.undo() + + +def parallel_assert_dtensor_linear_bf16_gradient_promotion( + rank: int, + grid_group_sizes, + device_type: str, + backend: str, + upstream_grad_dtype: torch.dtype = torch.float32, + avg_over_replicate_param_grad: bool = True, + env_map=None, +): + """Regression for bf16-mixed backward in ``LinearParamsReplicated``. + + Reproduces the mixed-dtype path used by bf16-mixed training: + - bf16 activations/input shards + - fp32 replicated parameters + - upstream gradient dtype is controlled by ``upstream_grad_dtype``: + * fp32: realistic when loss is computed in fp32 — autograd casts the + gradient to bf16 (matching the forward output) before custom backward + * bf16: the common case where the upstream layer also runs in bf16 + + Then verifies gradients match explicit promote-types reference math. + """ + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + size_cp = len(manager.group_ranks["cp"]) + size_ring = isqrt(size_cp) + if size_ring * size_ring != size_cp: + raise ValueError(f"cp group size {size_cp} is not a square int") + + batch_size = 2 + seq_len_per_rank = 4 + seq_len_global = size_ring * seq_len_per_rank + in_features = 8 + out_features = 6 + + seed_by_rank(0) + input_placements = [Shard(dim=0), Shard(dim=1), Replicate()] + + # Simulate bf16 activations arriving at this layer. + x_global_fp32 = torch.randn(batch_size, seq_len_global, in_features, device=manager.device, dtype=torch.float32) + x_global_bf16 = x_global_fp32.to(torch.bfloat16) + + # Parameters stay fp32 in mixed precision training. + linear_ref = torch.nn.Linear(in_features, out_features, device=manager.device, dtype=torch.float32) + state_dict_ref = {k: v.detach().clone() for k, v in linear_ref.state_dict().items()} + weight_ref = state_dict_ref["weight"] + bias_ref = state_dict_ref["bias"] + + linear_local = torch.nn.Linear(in_features, out_features, device=manager.device, dtype=torch.float32) + linear_local.load_state_dict(state_dict_ref) + layer = LinearParamsReplicated( + linear_local, + manager.device_mesh_subgroups, + avg_over_replicate_param_grad=avg_over_replicate_param_grad, + ) + + x_dtensor = distribute_tensor(x_global_bf16, manager.device_mesh_subgroups, input_placements) + x_dtensor.requires_grad_(True) + + dy_upstream = torch.randn( + batch_size, seq_len_global, out_features, device=manager.device, dtype=upstream_grad_dtype + ) + + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=True): + y_dtensor = layer(x_dtensor) + y_expected = torch.nn.functional.linear(x_global_bf16, weight_ref, bias_ref) + assert y_dtensor.dtype == torch.bfloat16, f"Expected bf16 output under autocast, got {y_dtensor.dtype}" + torch.testing.assert_close(y_dtensor.full_tensor(), y_expected) + + dy_dtensor = distribute_tensor(dy_upstream, manager.device_mesh_subgroups, input_placements) + y_dtensor.backward(dy_dtensor) + + assert layer.weight.grad is not None, "Weight gradient should not be None" + assert layer.bias.grad is not None, "Bias gradient should not be None" + assert x_dtensor.grad is not None, "Input gradient should not be None" + + weight_grad_result = layer.weight.grad.full_tensor() + bias_grad_result = layer.bias.grad.full_tensor() + x_grad_result = x_dtensor.grad.full_tensor() + + assert weight_grad_result.dtype == torch.float32, f"Weight grad should be fp32, got {weight_grad_result.dtype}" + assert bias_grad_result.dtype == torch.float32, f"Bias grad should be fp32, got {bias_grad_result.dtype}" + assert x_grad_result.dtype == torch.bfloat16, f"Input grad should be bf16, got {x_grad_result.dtype}" + + dy_effective = dy_upstream.to(dtype=y_dtensor.dtype) + + weight_grad_expected = torch.einsum("...i,...o->io", dy_effective.float(), x_global_bf16.float()) + bias_grad_expected = dy_effective.float().sum(dim=(0, 1)) + bf16_atol, bf16_rtol = torch.testing._comparison.default_tolerances(torch.bfloat16) + torch.testing.assert_close(weight_grad_result, weight_grad_expected, atol=2 * bf16_atol, rtol=2 * bf16_rtol) + torch.testing.assert_close(bias_grad_result, bias_grad_expected, atol=2 * bf16_atol, rtol=2 * bf16_rtol) + + x_grad_expected = torch.einsum("...i,io->...o", dy_effective, weight_ref.to(dy_effective.dtype)) + torch.testing.assert_close(x_grad_result, x_grad_expected) + + assert_all_identical(weight_grad_result, manager.group["cp"]) + assert_all_identical(bias_grad_result, manager.group["cp"]) + finally: + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize("avg_over_replicate", [True, False], ids=["avg_reduce", "no_reduce"]) +@pytest.mark.parametrize("upstream_grad_dtype", [torch.float32, torch.bfloat16], ids=["dy_fp32", "dy_bf16"]) +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cpu", "ENV"), + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp1-cp2x2", "cuda-dp1-cp2x2"], +) +def test_dtensor_linear_bf16_gradient_promotion(setup_env, upstream_grad_dtype, avg_over_replicate): + """Goals: bf16-mixed backward produces fp32 param grads from bf16 activations. + + - bf16 activations + fp32 parameters → bf16 output under autocast + - Weight/bias grads are fp32 (promoted), input grad stays bf16 + - Grads match explicit reference math within bf16 tolerance + - Cross-rank gradient consistency + - Tested with both fp32 and bf16 upstream adjoints + - Tested with and without avg-over-replicate gradient synchronisation + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=parallel_assert_dtensor_linear_bf16_gradient_promotion, + args=( + grid_group_sizes, + device_type, + backend, + upstream_grad_dtype, + avg_over_replicate, + env_per_rank, + ), + nprocs=world_size, + join=True, + ) + + +@pytest.mark.parametrize("compare_to_native_dtensor", [True, False]) +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cuda", "ENV"), + ((1, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp:{x[0][0]}, cp:{x[0][1]}, specify_method:{x[1]}, device_type:{x[2]}, method_init:{x[3]}", +) +def test_dtensor_linear(setup_env, compare_to_native_dtensor): + """Goals: LinearParamsReplicated forward/backward matches global-tensor reference. + + - Forward output matches serial F.linear on global tensors + - Backward produces matching input, weight, and bias gradients + - Cross-rank gradient consistency for replicated parameters + """ + if compare_to_native_dtensor: + pytest.skip( + "Native PyTorch DTensor bugs introduced in the NGC 25.10 container upgrade (from 25.02) " + "cause this test to fail." + ) + + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=parallel_assert_dtensor_linear, + args=( + grid_group_sizes, + device_type, + backend, + compare_to_native_dtensor, + env_per_rank, + ), + nprocs=world_size, + join=True, + ) diff --git a/tests/distributed/test_dtensor_metadata_tools.py b/tests/distributed/test_dtensor_metadata_tools.py new file mode 100755 index 000000000..df0fe1999 --- /dev/null +++ b/tests/distributed/test_dtensor_metadata_tools.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests the simple function raise_if_incorrect_dtensor_metadata_args + +Verification requirements + + V1a: TypeError not raised when object is a DTensor + V1b: TypeError raised when object is not a DTensor + V2a: ValueError not raised when shape matches expected_shape + V2b: ValueError raised when shape doesn't match expected_shape + + V3a: ValueError raised when device_mesh (d) match expected_device_mesh + V3b: ValueError raised when device_mesh (d) match expected_device_mesh + + V4a: ValueError raised when placements don't match expected_placements + V4a: ValueError raised when placements don't match expected_placements + + V5a: ValueError not raised when check_for_partial_placements=True and Partial placement absent + V5b: ValueError raised when check_for_partial_placements=True and Partial placement present + +Implementation status + V1a/b: done + V2a/b: done + V3a/b: done + V4a/b: done + V5a/b: done +""" + +import itertools +from copy import deepcopy +from typing import Any, Dict, OrderedDict + +import pytest +import torch +from torch import Tensor +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor +from torch.distributed.tensor._utils import compute_global_tensor_info + +from boltz.distributed.manager import DistributedManager, _GridGroupSizesType +from boltz.distributed.model.layers.dtensor_metadata_tools import raise_if_incorrect_dtensor_metadata_args +from boltz.testing.utils import ( + seed_by_rank, + skip_if_cuda_not_avail_or_device_count_less_than_word_size, + spawn_multiprocessing, +) + +SEED = 42 + + +def parallel_assert_raise_if_incorrect_dtensor_metadata_args( + rank: int, + input_example: OrderedDict[str, Tensor], + fn_kwargs: dict, # noqa + grid_group_sizes: _GridGroupSizesType, + device_type: str, + backend: str, + env_per_rank: Dict[str, str], +): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + dist_manager = DistributedManager() + + # Non-boillerplate + placements_for_single_rep_nonparam = (Shard(0), Shard(1), Replicate()) + input_meta: OrderedDict[str, Dict[str, Any]] = OrderedDict( + [ + ("x", dict(placements=placements_for_single_rep_nonparam, requires_grad=False)), # noqa C408 + ("y", dict(placements=placements_for_single_rep_nonparam, requires_grad=False)), # noqa C408 + ] + ) + + # ------------------------------------------------------------- + # Move inputs and ref outputs to device + # -------------------------------------------------------------- + input_example_device = OrderedDict( + [ + (k, v.detach().to(dist_manager.device, copy=True) if isinstance(v, Tensor) else deepcopy(v)) + for (k, v) in input_example.items() + ] + ) + # ----------------------------------------------------- + # Create input DTensors + # - after move to device + # ---------------------------------------------------- + input_example_as_dtensor = OrderedDict() + for name, meta in input_meta.items(): + input_example_as_dtensor[name] = distribute_tensor( + input_example_device[name], dist_manager.device_mesh_subgroups, meta["placements"] + ).requires_grad_(meta["requires_grad"]) + + # -------------------------------------- + # Verifications + # -------------------------------------- + + # V1a: TypeError not raised if inputs is a DTensor + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + ) + # V1b: TypeError raised if inputs is not a DTensor + with pytest.raises(TypeError, match="should have type"): + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_device["x"], + dtensor_name="x", + ) + + # V2a: ValueError not raised when shapes match + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + expected_shape=input_example_device["y"].shape, + ) + # V2b: ValueError raised when shape doesn't match + with pytest.raises(ValueError, match="should have shape"): + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + expected_shape=( + input_example_device["x"].shape[0] + 1, + input_example_device["x"].shape[1], + input_example_device["x"].shape[2], + ), + ) + # V3a: ValueError not raised when device_mesh does match + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + expected_device_mesh=input_example_as_dtensor["x"].device_mesh, + ) + # V3b: ValueError raised when device_mesh doesn't match + wrong_mesh = DeviceMesh(device_type, torch.arange(2)) + with pytest.raises(ValueError, match="should have device mesh"): + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + expected_device_mesh=wrong_mesh, + ) + + # V4a: ValueError not raised when placements do match + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + expected_placements=input_example_as_dtensor["x"].placements, + ) + # V4b: ValueError raised when placements don't match + wrong_placements = (Replicate(), Replicate()) + with pytest.raises(ValueError, match="placements"): + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + expected_placements=wrong_placements, + ) + + # V5a: ValueError not raised placements do not contains Partial + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=input_example_as_dtensor["x"], + dtensor_name="x", + check_for_partial_placements=True, + ) + # --------------------------------------------------------------------- + # V5b: ValueError raised when Partial placement exists and check_for_partial_placements=True + # - Use-case: + # - Some custom autograd function f + # f.apply(a) -> (x,y) + # - Some custom autograd function g + # g.apply(x, y) -> z + # - g.forward(x, y) might check that x and y have the same placements + # but x.placement or y.placement might contain Partial + # --------------------------------------------------------------------- + # emulate the forward pass of f + shape_x_global, stride_x_global = map( + tuple, + compute_global_tensor_info( + input_example_device["x"], dist_manager.device_mesh_subgroups, (Shard(0), Shard(1), Partial()) + ), + ) + x_with_partial = DTensor.from_local( + input_example_device["x"], + device_mesh=dist_manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Partial()), + shape=shape_x_global, + stride=stride_x_global, + ) + + shape_y_global, stride_y_global = map( + tuple, + compute_global_tensor_info( + input_example_device["y"], dist_manager.device_mesh_subgroups, (Shard(0), Shard(1), Partial()) + ), + ) + y_with_partial = DTensor.from_local( + input_example_device["y"], + device_mesh=dist_manager.device_mesh_subgroups, + placements=(Shard(0), Shard(1), Partial()), + shape=shape_y_global, + stride=stride_y_global, + ) + # emulate one version of the metadata check for g.forward() + with pytest.raises(ValueError, match="placement of type"): + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=x_with_partial, + dtensor_name="x_with_partial", + ) + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=y_with_partial, + dtensor_name="y_with_partial", + expected_placements=x_with_partial.placements, + check_for_partial_placements=True, + ) + + # emulate a second version of the metadata check for g.forward() + with pytest.raises(ValueError, match="placement of type"): + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=x_with_partial, + dtensor_name="x_with_partial", + check_for_partial_placements=True, + ) + raise_if_incorrect_dtensor_metadata_args( + dtensor_instance=y_with_partial, + dtensor_name="y_with_partial", + expected_placements=x_with_partial.placements, + ) + + # cleanup + DistributedManager.cleanup() + monkeypatch.undo() + + +def get_example_input( + len_dim_0: int, + len_dim_1: int, + len_dim_2: int, + seed: int = SEED, +): + """Generate example input""" + seed_by_rank(seed) # specified seed for each rank + x = torch.randn(len_dim_0, len_dim_1, len_dim_2, requires_grad=False) + y = torch.randn(len_dim_0, len_dim_1, len_dim_2, requires_grad=False) + input_example = OrderedDict([("x", x), ("y", y)]) + return input_example + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(1, (2, 2))], [True], ["cpu"], ["ENV"]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]}, method_init={x[3]}", +) +def test_raise_if_incorrect_dtensor_metadata_args( + setup_env: dict[str, int], + len_dim_0: int = 2, + len_dim_1: int = 3, + len_dim_2: int = 5, + seed: int = SEED, +): + """Test raise_if_incorrect_dtensor_metadata_args.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + skip_if_cuda_not_avail_or_device_count_less_than_word_size(device_type, world_size) + + # Get example input and reference output + input_example = get_example_input( + len_dim_0=len_dim_0, + len_dim_1=len_dim_1, + len_dim_2=len_dim_2, + seed=seed, + ) + + # Launch parallel test across all processes + spawn_multiprocessing( + parallel_assert_raise_if_incorrect_dtensor_metadata_args, + world_size, + input_example, + {}, + grid_group_sizes, + device_type, + backend, + env_per_rank, + ) diff --git a/tests/distributed/test_dtensor_parallel_assert_factored_lddt_loss.py b/tests/distributed/test_dtensor_parallel_assert_factored_lddt_loss.py new file mode 100644 index 000000000..e841d0eac --- /dev/null +++ b/tests/distributed/test_dtensor_parallel_assert_factored_lddt_loss.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from __future__ import annotations + +import pytest +import torch +from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.layers.atom_to_token import reconstruct_atom_to_token_global +from boltz.distributed.model.loss.validation import factored_lddt_loss as dtensor_factored_lddt_loss +from boltz.model.loss.validation import factored_lddt_loss as serial_factored_lddt_loss +from boltz.testing.utils import distribute_atom_features, get_feature_placements, random_features, spawn_multiprocessing + +# Get feature placements for the subset of features needed by factored_lddt_loss +_atom_keys = {"atom_pad_mask", "atom_to_token", "atom_counts_per_token"} +_token_keys = {"mol_type", "asym_id"} +_placements = get_feature_placements(atom_keys=_atom_keys, token_keys=_token_keys) +_placements_atom_features = _placements["atom_features"] +_placements_cp_atom_features = _placements["cp_atom_features"] +_placements_token_features = _placements["token_features"] + +# Pred/true coords placements: [B*mult, n_atoms, 3] +_placements_pred_coords = {"pred_coords": (Shard(0), Shard(1), Replicate())} +_placements_true_coords = {"true_coords": (Shard(0), Shard(1), Replicate())} +_placements_cp_pred_coords = {"pred_coords": (Shard(0), Replicate())} +_placements_cp_true_coords = {"true_coords": (Shard(0), Replicate())} + + +def parallel_assert_factored_lddt_loss(rank, payload): + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + feats_global_host, + pred_coords_global_host, + true_coords_global_host, + ref_lddt, + ref_total, + ) = payload + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device = manager.device + dtype = torch.float32 + + size_batch = feats_global_host["atom_pad_mask"].shape[0] + rank_dp = manager.group_rank["dp"] + num_dp_ranks = grid_group_sizes["dp"] + local_batch_size = size_batch // num_dp_ranks + local_start = rank_dp * local_batch_size + local_end = local_start + local_batch_size + + def _all_gather_single_repr(single_dtensor): + single_dtensor = single_dtensor.redistribute( + placements=[Shard(0), Replicate(), Replicate()], + ) + return single_dtensor.to_local() + + inputs_atom = { + k: v.to(dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in _placements_cp_atom_features + } + pred_coords_unflat = pred_coords_global_host.unflatten(0, (size_batch, multiplicity)) + true_coords_unflat = true_coords_global_host.unflatten(0, (size_batch, multiplicity)) + for i_mul in range(multiplicity): + inputs_atom[f"pred_coords_{i_mul}"] = pred_coords_unflat[:, i_mul].to(dtype=dtype) + inputs_atom[f"true_coords_{i_mul}"] = true_coords_unflat[:, i_mul].to(dtype=dtype) + + placements_cp = dict(_placements_cp_atom_features) + placements_dp_cp = dict(_placements_atom_features) + for i_mul in range(multiplicity): + placements_cp[f"pred_coords_{i_mul}"] = _placements_cp_pred_coords["pred_coords"] + placements_cp[f"true_coords_{i_mul}"] = _placements_cp_true_coords["true_coords"] + placements_dp_cp[f"pred_coords_{i_mul}"] = _placements_pred_coords["pred_coords"] + placements_dp_cp[f"true_coords_{i_mul}"] = _placements_true_coords["true_coords"] + + feats_atom = distribute_atom_features( + inputs_atom, + placements_cp, + placements_dp_cp, + manager.device_mesh_subgroups, + manager.group["cp"], + multiplicities={"pred_coords": multiplicity, "true_coords": multiplicity}, + ) + + pred_coords_dtensor = feats_atom["pred_coords"] + true_coords_dtensor = feats_atom["true_coords"] + atom_pad_mask_dtensor = feats_atom["atom_pad_mask"] + atom_to_token_dtensor = feats_atom["atom_to_token"] + + mol_type_dtensor = distribute_tensor( + feats_global_host["mol_type"].to(device=device, dtype=torch.int64), + device_mesh=manager.device_mesh_subgroups, + placements=_placements_token_features["mol_type"], + ) + asym_id_dtensor = distribute_tensor( + feats_global_host["asym_id"].to(device=device, dtype=torch.int64), + device_mesh=manager.device_mesh_subgroups, + placements=_placements_token_features["asym_id"], + ) + + pred_coords_local = _all_gather_single_repr(pred_coords_dtensor) + true_coords_local = _all_gather_single_repr(true_coords_dtensor) + atom_mask_base_local = _all_gather_single_repr(atom_pad_mask_dtensor) + mol_type_local = _all_gather_single_repr(mol_type_dtensor) + asym_id_local = _all_gather_single_repr(asym_id_dtensor) + atom_to_token_local = reconstruct_atom_to_token_global(atom_to_token_dtensor) + + feats_local = { + "atom_to_token": atom_to_token_local, + "mol_type": mol_type_local, + "asym_id": asym_id_local, + } + + lddt_dict, total_dict = dtensor_factored_lddt_loss( + true_atom_coords=true_coords_local, + pred_atom_coords=pred_coords_local, + feats=feats_local, + atom_mask=atom_mask_base_local, + multiplicity=multiplicity, + ) + + ref_slice = slice(local_start * multiplicity, local_end * multiplicity) + for key in ref_lddt: + torch.testing.assert_close( + lddt_dict[key].cpu(), + ref_lddt[key][ref_slice], + ) + torch.testing.assert_close( + total_dict[key].cpu(), + ref_total[key][ref_slice], + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (1, 1)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), +) +def test_dtensor_parallel_asserrt_factored_lddt_loss(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type != "cuda": + pytest.skip("cdist_lddt requires CUDA") + + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() < {world_size}") + + torch.manual_seed(0) + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + size_batch = grid_group_sizes["dp"] + cp_shape = grid_group_sizes["cp"] + size_cp = cp_shape if isinstance(cp_shape, int) else int(torch.tensor(cp_shape).prod().item()) + n_tokens = size_cp * 20 + n_atoms = n_tokens * 20 + multiplicity = 2 + + feats_global_host = random_features( + size_batch=size_batch, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=1, + atom_counts_per_token_range=(1, 20), + device=torch.device("cpu"), + float_value_range=(0.0, 1.0), + selected_keys=["atom_to_token", "atom_pad_mask", "atom_counts_per_token", "mol_type", "asym_id"], + rng=rng, + ) + pred_coords_global_host = torch.randn(size_batch * multiplicity, n_atoms, 3, dtype=torch.float32) + true_coords_global_host = torch.randn(size_batch * multiplicity, n_atoms, 3, dtype=torch.float32) + + atom_mask_mult = feats_global_host["atom_pad_mask"].to(torch.float32).repeat_interleave(multiplicity, dim=0) + feats_serial = { + "atom_to_token": feats_global_host["atom_to_token"].to(dtype=torch.float32), + "mol_type": feats_global_host["mol_type"], + "asym_id": feats_global_host["asym_id"], + } + + ref_lddt, ref_total = serial_factored_lddt_loss( + true_atom_coords=true_coords_global_host, + pred_atom_coords=pred_coords_global_host, + feats=feats_serial, + atom_mask=atom_mask_mult, + multiplicity=multiplicity, + ) + + spawn_multiprocessing( + parallel_assert_factored_lddt_loss, + world_size, + ( + grid_group_sizes, + device_type, + backend, + env_per_rank, + multiplicity, + feats_global_host, + pred_coords_global_host, + true_coords_global_host, + ref_lddt, + ref_total, + ), + ) diff --git a/tests/distributed/test_dtensor_predict.py b/tests/distributed/test_dtensor_predict.py new file mode 100644 index 000000000..31566ab3b --- /dev/null +++ b/tests/distributed/test_dtensor_predict.py @@ -0,0 +1,1290 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""End-to-end integration test for Boltz-2 distributed inference via run_predict. + +Adapted from the Boltz-1x-CP ``tests_v1/distributed/model/test_dtensor_predict.py``. + +Invokes the full ``run_predict`` entrypoint with a real Boltz-2 checkpoint and +preprocessed data, then evaluates predicted structure output against golden +reference coordinates using lDDT and RMSD metrics. +""" + +import json +import math +import os +import subprocess +import tempfile +from pathlib import Path +from typing import Any + +import biotite.structure.io.pdbx as pdbx +import numpy as np +import pytest +import torch +from scipy import stats + +from boltz.distributed.data.types import PairMaskMode +from boltz.distributed.model.layers.triangular_attention import cueq_is_installed, trifast_is_installed +from boltz.distributed.model.modules.utils import Precision, SDPAWithBiasBackend, TriAttnBackend +from boltz.distributed.predict import run_predict +from boltz.testing.utils import ( + compute_pairwise_lddt_rmsd_matrices, + concat_data, + energy_distance_from_matrices, + intra_rowwise_best, + matched_mean_metric, + spawn_multiprocessing, +) + +""" +Intra-golden-value lDDT and RMSD matrices (serial Boltz-2, 5 diffusion samples each). + +Because diffusion sampling is stochastic, comparing a single distributed prediction +against a single serial prediction (e.g. model_0 vs model_0) is unreliable: even +serial samples disagree with each other at lDDT ~0.75 for 7z64. Instead we use two +distributional tests that treat all N diffusion samples as draws from each code path: + +1. **Energy Distance** -- measures how far apart the DTensor and serial sample + distributions are. Computed as E_dist = 2*E[d(D,S)] - E[d(D,D)] - E[d(S,S)] + where d is (1 - lDDT) or RMSD. Cross-distribution uses all NxM pairs; intra- + distribution uses upper-triangle pairs. Near-zero means the two distributions + are indistinguishable. + +2. **Hungarian-matched mean** -- optimal 1-to-1 assignment between DTensor and + serial samples (via ``scipy.optimize.linear_sum_assignment``). The mean metric + of matched pairs is compared against the serial-serial baseline (mean of row-wise + best excluding diagonal). A small gap confirms the DTensor samples are as close + to serial samples as serial samples are to each other. + +=== 7ylz (atoms=4949) === +lDDT matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 1.0000 0.8958 0.8938 0.8913 0.8985 + m1 0.8969 1.0000 0.8916 0.8907 0.8882 + m2 0.8916 0.8882 1.0000 0.8926 0.8930 + m3 0.8934 0.8914 0.8965 1.0000 0.8953 + m4 0.8974 0.8862 0.8938 0.8928 1.0000 +RMSD matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 0.0000 0.9640 1.0000 1.0008 0.9815 + m1 0.9640 0.0000 1.0994 1.0474 1.1162 + m2 1.0000 1.0994 0.0000 1.0240 0.9596 + m3 1.0008 1.0474 1.0240 0.0000 1.0014 + m4 0.9815 1.1162 0.9596 1.0014 0.0000 + +=== 7z64 (atoms=2265) === +lDDT matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 1.0000 0.7603 0.7605 0.7681 0.7579 + m1 0.7605 1.0000 0.7342 0.7564 0.7389 + m2 0.7577 0.7327 1.0000 0.7103 0.7500 + m3 0.7598 0.7513 0.7041 1.0000 0.7479 + m4 0.7557 0.7376 0.7476 0.7514 1.0000 +RMSD matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 0.0000 2.1296 1.9554 2.1175 2.2710 + m1 2.1296 0.0000 2.3211 2.0270 2.2595 + m2 1.9554 2.3211 0.0000 2.5420 2.0863 + m3 2.1175 2.0270 2.5420 0.0000 2.2688 + m4 2.2710 2.2595 2.0863 2.2688 0.0000 + +=== 8ayv (atoms=2396) === +lDDT matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 1.0000 0.9798 0.9824 0.9712 0.9729 + m1 0.9800 1.0000 0.9848 0.9761 0.9795 + m2 0.9824 0.9844 1.0000 0.9749 0.9816 + m3 0.9718 0.9765 0.9757 1.0000 0.9764 + m4 0.9720 0.9781 0.9809 0.9746 1.0000 +RMSD matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 0.0000 0.2793 0.2539 0.3235 0.4098 + m1 0.2793 0.0000 0.3147 0.3194 0.2813 + m2 0.2539 0.3147 0.0000 0.3289 0.3896 + m3 0.3235 0.3194 0.3289 0.0000 0.4042 + m4 0.4098 0.2813 0.3896 0.4042 0.0000 + +=== 8b2e (atoms=1062) === +lDDT matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 1.0000 0.9637 0.9573 0.9612 0.9503 + m1 0.9657 1.0000 0.9661 0.9698 0.9625 + m2 0.9570 0.9635 1.0000 0.9683 0.9648 + m3 0.9627 0.9694 0.9698 1.0000 0.9617 + m4 0.9513 0.9616 0.9664 0.9614 1.0000 +RMSD matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 0.0000 0.3433 0.4236 0.4008 0.4802 + m1 0.3433 0.0000 0.3459 0.3457 0.3600 + m2 0.4236 0.3459 0.0000 0.2976 0.3316 + m3 0.4008 0.3457 0.2976 0.0000 0.3323 + m4 0.4802 0.3600 0.3316 0.3323 0.0000 + +=== prot_custom_msa (atoms=899) === +lDDT matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 1.0000 0.6048 0.7097 0.8564 0.6571 + m1 0.5970 1.0000 0.6722 0.5981 0.5972 + m2 0.7100 0.6883 1.0000 0.6872 0.6075 + m3 0.8628 0.6128 0.6918 1.0000 0.6632 + m4 0.6512 0.6097 0.5984 0.6537 1.0000 +RMSD matrix (row=i, col=j): + m0 m1 m2 m3 m4 + m0 0.0000 18.1409 3.5726 5.2528 11.7845 + m1 18.1409 0.0000 3.5930 16.2112 21.6515 + m2 3.5726 3.5930 0.0000 16.4714 16.1519 + m3 5.2528 16.2112 16.4714 0.0000 4.0229 + m4 11.7845 21.6515 16.1519 4.0229 0.0000 +""" + +ENERGY_DIST_LDDT_TOL = 0.03 +ENERGY_DIST_RMSD_TOL = 0.2 +MATCHED_LDDT_DIFF_TOL = 0.03 +MATCHED_RMSD_DIFF_TOL = 0.2 + + +def _get_structural_tolerances(name_sample: str) -> dict: + """Return per-sample structural tolerance dict. + + Most samples have tight structural convergence across diffusion samples + (RMSD < 1 Å, lDDT > 0.95), so the default global constants suffice. + prot_custom_msa is an exception — see matrices above — where the single- + chain protein with limited MSA depth/diversity produces wildly different + conformations across diffusion samples (serial-vs-serial RMSD 1.3–22.3 Å, + lDDT 0.60–0.88 at n=20). The serial-vs-DTensor discrepancy is within + this inherent noise (lDDT mean 0.744 vs 0.752, RMSD mean 13.98 vs 13.64), + so wider thresholds are appropriate. + + n=5 serial-vs-DTensor observed (H200, BF16_MIXED, CUEQ, dp=1 cp=2×2): + energy_dist_lddt=0.084, energy_dist_rmsd=0.706 + matched_lddt_diff=0.059, matched_rmsd_diff=0.308 + Thresholds set at ~2× observed values. + """ + sample_id = name_sample.replace("processed_", "") + if sample_id in _CUSTOM_MSA_SAMPLES: + return { + "energy_dist_lddt": 0.17, + "energy_dist_rmsd": 1.5, + "matched_lddt_diff": 0.12, + "matched_rmsd_diff": 0.7, + } + return { + "energy_dist_lddt": ENERGY_DIST_LDDT_TOL, + "energy_dist_rmsd": ENERGY_DIST_RMSD_TOL, + "matched_lddt_diff": MATCHED_LDDT_DIFF_TOL, + "matched_rmsd_diff": MATCHED_RMSD_DIFF_TOL, + } + + +CONFIDENCE_SCALAR_KEYS = [ + # All 9 metrics exist in both Boltz-1x and Boltz-2, but Boltz-1x-CP only + # golden-compares confidence_score (at 5%). Boltz-2 compares all of them. + "confidence_score", # Boltz-1x-CP also golden-compares this (5% rtol) + "ptm", # Boltz-2 golden comparison only + "iptm", # Boltz-2 golden comparison only + "ligand_iptm", # Boltz-2 golden comparison only + "protein_iptm", # Boltz-2 golden comparison only + "complex_plddt", # Boltz-2 golden comparison only + "complex_iplddt", # Boltz-2 golden comparison only + "complex_pde", # Boltz-2 golden comparison only + "complex_ipde", # Boltz-2 golden comparison only +] + +# --------------------------------------------------------------------------- +# Golden comparison tolerances for confidence metrics. +# +# The golden values come from serial inference (1 GPU), while the distributed +# test uses context parallelism (4 GPUs). Confidence metric agreement depends +# on structure prediction quality and algorithmic sensitivity. +# +# Three tolerance tiers (all at ~1.5x observed max from 8-GPU H100 runs +# across BF16/BF16_MIXED/TF32/FP32 and multiple attention backends): +# +# MSA (7ylz, 8ayv): tight. BF16_MIXED is the dominant noise source +# (score rel_diff up to 0.028 vs <0.01 for other precisions). +# +# LIGAND (8b2e): medium. Non-polymer frame reassignment in +# compute_frame_pred amplifies coordinate differences into iPTM/iPDE +# deviations. ~15% iPTM and ~1.5 A iPDE observed (BF16_MIXED worst). +# +# NO_MSA (7z64): wide. Poor structure predictions cause volatile +# confidence, especially for per-chain-pair iPTM (up to 38% rel_diff). +# +# Within each tier, four regimes: +# 1. confidence_score: composite metric (0.8*plddt + 0.2*iptm), most stable. +# 2. Probability-bounded scalars: ptm, iptm, plddt, etc. ([0, 1]). +# 3. Per-chain / per-chain-pair metrics: smaller denominators, more variance. +# 4. Distance-error metrics (Angstroms): absolute tolerance. +# --------------------------------------------------------------------------- +_DISTANCE_METRICS = {"complex_pde", "complex_ipde"} + +# Override diffusion samples via env var for deeper statistical analysis: +# BOLTZ_PREDICT_DIFFUSION_SAMPLES=20 pytest -v -s ... +_DIFFUSION_SAMPLES = int(os.environ.get("BOLTZ_PREDICT_DIFFUSION_SAMPLES", "5")) + +_PRE_GENERATED_GOLDEN_SAMPLES = 5 +_serial_golden_cache: dict[str, Path] = {} + + +def _run_serial_predict( + data_dir: Path, + checkpoint: Path, + cache_dir: Path, + diffusion_samples: int, +) -> Path: + """Run serial (1-GPU) Boltz-2 predict and return the predictions directory. + + Uses ``boltz predict --input_format preprocessed`` via subprocess for complete + process isolation (no CUDA context or torch.distributed leakage into the test + process). Results are cached per protein so the serial run happens at most once + per session regardless of how many precision/backend parametrizations follow. + + Returns the predictions directory (containing per-protein subdirs) suitable as + a drop-in replacement for ``get_inference_golden_value_dir_v2``. + """ + cache_key = data_dir.name + if cache_key in _serial_golden_cache: + return _serial_golden_cache[cache_key] + + out_dir = Path(tempfile.mkdtemp(prefix=f"serial_golden_{cache_key}_")) + print(f"\n [serial golden] Generating {diffusion_samples} serial samples for {cache_key} -> {out_dir}") + + cmd = [ + "boltz", + "predict", + str(data_dir), + "--out_dir", + str(out_dir), + "--cache", + str(cache_dir), + "--checkpoint", + str(checkpoint), + "--diffusion_samples", + str(diffusion_samples), + "--seed", + "42", + "--input_format", + "preprocessed", + "--write_full_pae", + "--devices", + "1", + "--accelerator", + "gpu", + "--recycling_steps", + "10", + "--sampling_steps", + "200", + "--model", + "boltz2", + "--max_msa_seqs", + "2048", + "--override", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"Serial predict failed (exit {result.returncode}) for {cache_key}.\n" + f"Command: {' '.join(cmd)}\n" + f"--- stdout ---\n{result.stdout[-2000:] if result.stdout else '(empty)'}\n" + f"--- stderr ---\n{result.stderr[-2000:] if result.stderr else '(empty)'}" + ) + + golden_dir = out_dir / f"boltz_results_{cache_key}" / "predictions" + if not golden_dir.exists(): + raise RuntimeError( + f"Serial predict exited 0 but {golden_dir} does not exist.\n" + f"Contents of {out_dir}: {list(out_dir.rglob('*'))}\n" + f"--- stdout ---\n{result.stdout[-2000:] if result.stdout else '(empty)'}\n" + f"--- stderr ---\n{result.stderr[-2000:] if result.stderr else '(empty)'}" + ) + _serial_golden_cache[cache_key] = golden_dir + return golden_dir + + +_NO_MSA_SAMPLES = {"7z64"} + +# Samples with a pre-existing custom MSA (e.g. from a local a3m file) rather +# than a server-generated MSA. Limited MSA depth leads to higher variance +# between serial and distributed runs than the fully-MSA'd preprocessed +# samples but lower than the no-MSA tier. +_CUSTOM_MSA_SAMPLES = {"prot_custom_msa"} + +# Samples with non-polymer (ligand) chains where compute_frame_pred's +# distance-based frame reassignment amplifies coordinate differences. +_LIGAND_COMPLEX_SAMPLES = {"8b2e"} + + +def _get_confidence_tolerances(name_sample: str) -> dict: + """Return tolerance dict for a given sample based on prediction quality. + + Tolerances derived from n=20 serial-vs-distributed comparisons on 8-GPU + H100 across BF16/BF16_MIXED/FP32 precisions with ~2× headroom over the + worst observed rel_diff. Serial golden always runs at BF16_MIXED, so BF16 + tests reflect a cross-precision comparison (expected to be noisier). + """ + sample_id = name_sample.replace("processed_", "") + if sample_id in _NO_MSA_SAMPLES: + # n=20 max observed: score=0.008, prob=0.054, chain=0.061, ipde=0.34 Å + # chain_rtol raised to 0.12 after observing 0.1016 on pair_chains_iptm[(0,5)] + # in BF16 cross-precision runs (serial BF16_MIXED vs DTensor BF16). + return { + "score_rtol": 0.02, + "prob_rtol": 0.08, + "chain_rtol": 0.12, + "dist_atol": 0.5, + } + if sample_id in _CUSTOM_MSA_SAMPLES: + # Custom MSA from pre-existing a3m has limited depth/diversity compared + # to server-generated MSAs. This single-chain protein (prot_custom_msa, + # 899 atoms) produces enormous structural variance across diffusion + # samples — serial-vs-serial RMSD spans 1.3–22.3 Å — so the systematic + # serial-vs-DTensor shift in confidence is small relative to inherent + # diffusion noise. + # + # n=20 serial-only inherent noise (H200, BF16_MIXED, seed=42): + # Metric Mean Std Min Max CoV% + # confidence_score 0.7051 0.0157 0.6703 0.7266 2.22 + # ptm 0.6785 0.0247 0.6329 0.7217 3.64 + # complex_plddt 0.7117 0.0148 0.6796 0.7342 2.08 + # complex_pde 1.0478 0.0740 0.8835 1.1748 7.06 + # + # n=20 serial-vs-DTensor (dp=1, cp=2×2, CUEQ, BF16_MIXED): + # Metric SerMean DTMean RelDiff Headroom + # confidence_score 0.7051 0.6799 3.58% 2.0× + # ptm 0.6785 0.7151 5.38% 1.7× + # complex_plddt 0.7117 0.6711 5.72% 1.6× + # complex_pde 1.0478 0.9438 0.104Å 1.4× + # + # n=20 structural noise comparison: + # Source lDDT mean RMSD mean + # Serial-vs-Serial (inherent) 0.7515 13.64 Å + # DTensor-vs-DTensor (inherent) 0.7385 14.14 Å + # Serial-vs-DTensor (matched) 0.7440 13.98 Å + # + # The serial-vs-DTensor structural discrepancy is within the inherent + # serial diffusion noise, confirming the confidence tolerance below is + # driven by floating-point accumulation order (BF16 CP), not structural + # quality degradation. Thresholds set at ~1.4–2× observed values. + return { + "score_rtol": 0.07, + "prob_rtol": 0.09, + "chain_rtol": 0.10, + "dist_atol": 0.15, + } + if sample_id in _LIGAND_COMPLEX_SAMPLES: + # n=20 max observed: score=0.015, prob=0.094, chain=0.128, ipde=0.81 Å + # High variance from ligand frame reassignment in compute_frame_pred. + return { + "score_rtol": 0.03, + "prob_rtol": 0.15, + "chain_rtol": 0.21, + "dist_atol": 2.0, + } + # MSA tier — n=20 max observed: score=0.002, prob=0.004, chain=0.004, ipde=0.019 Å + return { + "score_rtol": 0.01, + "prob_rtol": 0.02, + "chain_rtol": 0.02, + "dist_atol": 0.05, + } + + +def cif_to_tensor(cif_file: Path) -> torch.Tensor: + """Parse a CIF file and return atom coordinates as a Tensor.""" + data_cif = pdbx.CIFFile.read(cif_file) + atom_array = pdbx.get_structure(data_cif, model=1, include_bonds=True) + return torch.tensor(atom_array.coord) + + +def _load_confidence_jsons(directory: Path, name_sample: str) -> list[dict]: + """Load all confidence JSON files for a sample, sorted by model index.""" + jsons = sorted(directory.glob(f"confidence_{name_sample}_model_*.json")) + results = [] + for jf in jsons: + with jf.open() as f: + results.append(json.load(f)) + return results + + +def _assert_confidence_values_sane(conf_data: dict, source: str): + """Assert that confidence values are finite and in expected ranges.""" + errors = [] + for key in CONFIDENCE_SCALAR_KEYS: + if key not in conf_data: + errors.append(f" {key}: MISSING") + continue + val = conf_data[key] + if not math.isfinite(val): + errors.append(f" {key}: {val} (not finite)") + elif key in ("ptm", "iptm", "complex_plddt", "complex_iplddt", "confidence_score") and not (0 <= val <= 1): + errors.append(f" {key}: {val} (outside [0, 1])") + + for nested_key in ("chains_ptm", "pair_chains_iptm"): + if nested_key not in conf_data: + errors.append(f" {nested_key}: MISSING") + elif not isinstance(conf_data[nested_key], dict): + errors.append(f" {nested_key}: expected dict, got {type(conf_data[nested_key]).__name__}") + + if errors: + raise AssertionError(f"Confidence sanity check failed ({source}):\n" + "\n".join(errors)) + + +def _compare_confidence_golden_vs_distributed( + golden_jsons: list[dict], + dist_jsons: list[dict], + name_sample: str, +): + """Compare mean confidence metrics between golden (serial) and distributed. + + Golden values come from serial inference on a single GPU; distributed + values come from context-parallel inference on multiple GPUs. Tolerances + are stratified by sample prediction quality (see ``_get_confidence_tolerances``): + samples with MSA have tight tolerances; samples without MSA (e.g. 7z64) have + wider tolerances because poor structure predictions cause volatile confidence. + """ + tols = _get_confidence_tolerances(name_sample) + errors = [] + diffs: list[str] = [] + for key in CONFIDENCE_SCALAR_KEYS: + golden_vals = [j[key] for j in golden_jsons if key in j] + dist_vals = [j[key] for j in dist_jsons if key in j] + if not golden_vals or not dist_vals: + continue + golden_mean = sum(golden_vals) / len(golden_vals) + dist_mean = sum(dist_vals) / len(dist_vals) + abs_diff = abs(dist_mean - golden_mean) + + if key in _DISTANCE_METRICS: + tag = "FAIL" if abs_diff > tols["dist_atol"] else "ok" + diffs.append( + f" [{tag}] {key}: golden={golden_mean:.6f}, dist={dist_mean:.6f}, " + f"abs_diff={abs_diff:.4f} Å (threshold={tols['dist_atol']} Å)" + ) + if abs_diff > tols["dist_atol"]: + errors.append( + f" {key}: golden_mean={golden_mean:.6f}, dist_mean={dist_mean:.6f}, " + f"abs_diff={abs_diff:.4f} Å > {tols['dist_atol']} Å" + ) + elif abs(golden_mean) < 1e-8: + tag = "FAIL" if abs_diff > 0.01 else "ok" + diffs.append( + f" [{tag}] {key}: golden={golden_mean:.6f}, dist={dist_mean:.6f}, " + f"abs_diff={abs_diff:.6f} (threshold=0.01)" + ) + if abs_diff > 0.01: + errors.append( + f" {key}: golden_mean={golden_mean:.6f}, dist_mean={dist_mean:.6f}, abs_diff={abs_diff:.6f} > 0.01" + ) + else: + rtol = tols["score_rtol"] if key == "confidence_score" else tols["prob_rtol"] + rel_diff = abs_diff / abs(golden_mean) + tag = "FAIL" if rel_diff > rtol else "ok" + diffs.append( + f" [{tag}] {key}: golden={golden_mean:.6f}, dist={dist_mean:.6f}, " + f"rel_diff={rel_diff:.4f} (threshold={rtol})" + ) + if rel_diff > rtol: + errors.append( + f" {key}: golden_mean={golden_mean:.6f}, dist_mean={dist_mean:.6f}, " + f"rel_diff={rel_diff:.4f} > {rtol}" + ) + for nested_key in ("chains_ptm", "pair_chains_iptm"): + _compare_nested_confidence(golden_jsons, dist_jsons, nested_key, name_sample, errors, diffs) + + sample_id = name_sample.replace("processed_", "") + _tier_map = [ + (_NO_MSA_SAMPLES, "NO_MSA"), + (_CUSTOM_MSA_SAMPLES, "CUSTOM_MSA"), + (_LIGAND_COMPLEX_SAMPLES, "LIGAND"), + ] + tier = next((t for s, t in _tier_map if sample_id in s), "MSA") + # pytest -s to see confidence metric diffs for tolerance tuning + n_samples = max(len(golden_jsons), len(dist_jsons)) + print(f"\n=== Confidence diffs for {name_sample} (tier={tier}, n={n_samples}) ===") + + print(f" Per-sample values ({n_samples} diffusion samples):") + for key in CONFIDENCE_SCALAR_KEYS: + golden_vals = [j.get(key) for j in golden_jsons] + dist_vals = [j.get(key) for j in dist_jsons] + g_str = ", ".join(f"{v:.4f}" if v is not None else "N/A" for v in golden_vals) + d_str = ", ".join(f"{v:.4f}" if v is not None else "N/A" for v in dist_vals) + print(f" {key}: golden=[{g_str}] dist=[{d_str}]") + + # Statistical summary: Welch's t-test per metric + print(" Statistical analysis (Welch's t-test, two-sided):") + for key in CONFIDENCE_SCALAR_KEYS: + golden_vals = [j[key] for j in golden_jsons if key in j] + dist_vals = [j[key] for j in dist_jsons if key in j] + if len(golden_vals) < 2 or len(dist_vals) < 2: + continue + g_arr, d_arr = np.array(golden_vals), np.array(dist_vals) + if np.std(g_arr) < 1e-12 and np.std(d_arr) < 1e-12: + continue + t_stat, p_val = stats.ttest_ind(g_arr, d_arr, equal_var=False) + shift = np.mean(d_arr) - np.mean(g_arr) + print( + f" {key}: shift={shift:+.6f}, " + f"golden_std={np.std(g_arr):.4f}, dist_std={np.std(d_arr):.4f}, " + f"t={t_stat:.2f}, p={p_val:.4f}" + ) + + print(" Mean comparison:") + for d in diffs: + print(d) + + if errors: + raise AssertionError( + f"Confidence metric mismatch for {name_sample} " + f"(golden={len(golden_jsons)} samples, dist={len(dist_jsons)} samples):\n" + "\n".join(errors) + ) + + +def _compare_nested_confidence( + golden_jsons: list[dict], + dist_jsons: list[dict], + nested_key: str, + name_sample: str, + errors: list[str], + diffs: list[str] | None = None, +) -> None: + """Compare per-chain / per-chain-pair confidence dicts between golden and distributed. + + For ``chains_ptm``: ``{chain_id: float}`` + For ``pair_chains_iptm``: ``{chain_id1: {chain_id2: float}}`` + + Flattens all (chain_key, sample_index) pairs, computes per-chain-key + mean across diffusion samples, and applies per-sample ``chain_rtol``. + """ + tols = _get_confidence_tolerances(name_sample) + chain_rtol = tols["chain_rtol"] + + golden_has = [nested_key in j and isinstance(j[nested_key], dict) for j in golden_jsons] + dist_has = [nested_key in j and isinstance(j[nested_key], dict) for j in dist_jsons] + if not all(golden_has) or not all(dist_has): + return + + if nested_key == "chains_ptm": + golden_flat = _flatten_chains_ptm(golden_jsons) + dist_flat = _flatten_chains_ptm(dist_jsons) + else: + golden_flat = _flatten_pair_chains_iptm(golden_jsons) + dist_flat = _flatten_pair_chains_iptm(dist_jsons) + + shared_keys = set(golden_flat.keys()) & set(dist_flat.keys()) + for chain_key in sorted(shared_keys): + golden_mean = sum(golden_flat[chain_key]) / len(golden_flat[chain_key]) + dist_mean = sum(dist_flat[chain_key]) / len(dist_flat[chain_key]) + abs_diff = abs(dist_mean - golden_mean) + if abs(golden_mean) < 1e-8: + tag = "FAIL" if abs_diff > 0.01 else "ok" + if diffs is not None: + diffs.append( + f" [{tag}] {nested_key}[{chain_key}]: golden={golden_mean:.6f}, " + f"dist={dist_mean:.6f}, abs_diff={abs_diff:.6f} (threshold=0.01)" + ) + if abs_diff > 0.01: + errors.append( + f" {nested_key}[{chain_key}]: golden_mean={golden_mean:.6f}, " + f"dist_mean={dist_mean:.6f}, abs_diff={abs_diff:.6f} > 0.01" + ) + else: + rel_diff = abs_diff / abs(golden_mean) + tag = "FAIL" if rel_diff > chain_rtol else "ok" + if diffs is not None: + diffs.append( + f" [{tag}] {nested_key}[{chain_key}]: golden={golden_mean:.6f}, " + f"dist={dist_mean:.6f}, rel_diff={rel_diff:.4f} (threshold={chain_rtol})" + ) + if rel_diff > chain_rtol: + errors.append( + f" {nested_key}[{chain_key}]: golden_mean={golden_mean:.6f}, " + f"dist_mean={dist_mean:.6f}, rel_diff={rel_diff:.4f} > {chain_rtol}" + ) + + +def _flatten_chains_ptm(jsons: list[dict]) -> dict[str, list[float]]: + """Collect per-chain PTM values across diffusion samples.""" + result: dict[str, list[float]] = {} + for j in jsons: + for chain_id, val in j.get("chains_ptm", {}).items(): + result.setdefault(str(chain_id), []).append(float(val)) + return result + + +def _flatten_pair_chains_iptm(jsons: list[dict]) -> dict[str, list[float]]: + """Collect per-chain-pair iPTM values across diffusion samples.""" + result: dict[str, list[float]] = {} + for j in jsons: + for id1, inner in j.get("pair_chains_iptm", {}).items(): + for id2, val in inner.items(): + key = f"({id1},{id2})" + result.setdefault(key, []).append(float(val)) + return result + + +def parallel_assert_run_predict_v2( + rank: int, + env_per_rank: dict[str, Any], + kwargs_run_predict: dict[str, Any], + dir_expected_serial: Path, +): + """Worker: run distributed predict, evaluate lDDT/RMSD on CP rank 0.""" + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + run_predict(**kwargs_run_predict) + + size_cp = kwargs_run_predict["size_cp"] + rank_cp = rank % size_cp + out_dir = Path(kwargs_run_predict["out_dir"]) + data_stem = Path(kwargs_run_predict["data"]).stem + results_dir = out_dir / f"boltz_results_{data_stem}" + n_diffusion_samples = kwargs_run_predict["diffusion_samples"] + + if rank_cp != 0: + return + + result_cif_files: dict[str, list[Path]] = {} + for cif in results_dir.rglob("*.cif"): + name_sample = cif.parent.name + result_cif_files.setdefault(name_sample, []).append(cif) + + assert len(result_cif_files) > 0, f"No CIF output files found in {results_dir}" + + for name_sample, cif_files in result_cif_files.items(): + assert ( + len(cif_files) == n_diffusion_samples + ), f"Expected {n_diffusion_samples} CIF files for {name_sample}, found {len(cif_files)} in {results_dir}" + + dist_coords = [cif_to_tensor(f) for f in sorted(cif_files)] + golden_dir = dir_expected_serial / name_sample + golden_cif_files = sorted(golden_dir.glob(f"{name_sample}_model_*.cif")) + assert len(golden_cif_files) > 0, f"No golden CIF files found in {golden_dir}" + golden_coords = [cif_to_tensor(f) for f in golden_cif_files] + + cross_lddt, cross_rmsd = compute_pairwise_lddt_rmsd_matrices(dist_coords, golden_coords) + dd_lddt, dd_rmsd = compute_pairwise_lddt_rmsd_matrices(dist_coords, dist_coords) + ss_lddt, ss_rmsd = compute_pairwise_lddt_rmsd_matrices(golden_coords, golden_coords) + + e_dist_lddt = energy_distance_from_matrices(cross_lddt, dd_lddt, ss_lddt, maximize=True) + e_dist_rmsd = energy_distance_from_matrices(cross_rmsd, dd_rmsd, ss_rmsd, maximize=False) + + matched_lddt = matched_mean_metric(cross_lddt, maximize=True) + matched_rmsd = matched_mean_metric(cross_rmsd, maximize=False) + baseline_lddt = intra_rowwise_best(ss_lddt, maximize=True) + baseline_rmsd = intra_rowwise_best(ss_rmsd, maximize=False) + + stols = _get_structural_tolerances(name_sample) + struct_errors = [] + if e_dist_lddt > stols["energy_dist_lddt"]: + struct_errors.append(f"Energy distance (lDDT) {e_dist_lddt:.6f} > {stols['energy_dist_lddt']}") + if e_dist_rmsd > stols["energy_dist_rmsd"]: + struct_errors.append(f"Energy distance (RMSD) {e_dist_rmsd:.6f} > {stols['energy_dist_rmsd']}") + if baseline_lddt - matched_lddt > stols["matched_lddt_diff"]: + struct_errors.append( + f"Matched lDDT {matched_lddt:.4f} below baseline {baseline_lddt:.4f} " + f"by {baseline_lddt - matched_lddt:.4f} > {stols['matched_lddt_diff']}" + ) + if matched_rmsd - baseline_rmsd > stols["matched_rmsd_diff"]: + struct_errors.append( + f"Matched RMSD {matched_rmsd:.4f} above baseline {baseline_rmsd:.4f} " + f"by {matched_rmsd - baseline_rmsd:.4f} > {stols['matched_rmsd_diff']}" + ) + + # --- Confidence output checks (run before raising struct errors so + # diagnostics are always printed with pytest -s) --- + if kwargs_run_predict.get("confidence_prediction", True): + struct_dir = cif_files[0].parent + + # 1. Confidence summary JSON files — existence and count + dist_jsons_data = _load_confidence_jsons(struct_dir, name_sample) + assert len(dist_jsons_data) == n_diffusion_samples, ( + f"Expected {n_diffusion_samples} confidence JSON files for {name_sample}, " + f"found {len(dist_jsons_data)} in {struct_dir}" + ) + + # 2. Sanity: every JSON has all expected keys with finite, in-range values + for i, conf_data in enumerate(dist_jsons_data): + _assert_confidence_values_sane(conf_data, f"{name_sample} dist model_{i}") + + # 3. Compare against golden serial confidence + golden_sample_dir = dir_expected_serial / name_sample + golden_jsons_data = _load_confidence_jsons(golden_sample_dir, name_sample) + assert len(golden_jsons_data) > 0, ( + f"No golden confidence JSON files found for {name_sample} in {golden_sample_dir}. " + f"Golden files must exist for serial-vs-distributed comparison." + ) + for i, conf_data in enumerate(golden_jsons_data): + _assert_confidence_values_sane(conf_data, f"{name_sample} golden model_{i}") + _compare_confidence_golden_vs_distributed(golden_jsons_data, dist_jsons_data, name_sample) + + if struct_errors: + raise AssertionError( + f"Distributional comparison failed for {name_sample}:\n" + + "\n".join(struct_errors) + + f"\nCheck CIF: {cif_files[0]}" + ) + + # 4. pLDDT npz files + plddt_files = sorted(struct_dir.glob(f"plddt_{name_sample}_model_*.npz")) + assert len(plddt_files) == n_diffusion_samples, ( + f"Expected {n_diffusion_samples} plddt files for {name_sample}, " + f"found {len(plddt_files)} in {struct_dir}" + ) + for pf in plddt_files: + plddt = np.load(pf)["plddt"] + assert plddt.ndim == 1, f"plddt should be 1D, got shape {plddt.shape} in {pf}" + assert np.all(np.isfinite(plddt)), f"plddt contains non-finite values in {pf}" + + # 5. PDE npz files + pde_files = sorted(struct_dir.glob(f"pde_{name_sample}_model_*.npz")) + assert ( + len(pde_files) == n_diffusion_samples + ), f"Expected {n_diffusion_samples} pde files for {name_sample}, found {len(pde_files)} in {struct_dir}" + for df in pde_files: + pde = np.load(df)["pde"] + assert pde.ndim == 2, f"pde should be 2D, got shape {pde.shape} in {df}" + assert pde.shape[0] == pde.shape[1], f"pde should be square, got {pde.shape} in {df}" + assert np.all(np.isfinite(pde)), f"pde contains non-finite values in {df}" + + # 6. PAE npz files (when write_full_pae is enabled) + if kwargs_run_predict.get("write_full_pae", False): + pae_files = sorted(struct_dir.glob(f"pae_{name_sample}_model_*.npz")) + assert len(pae_files) == n_diffusion_samples, ( + f"Expected {n_diffusion_samples} pae files for {name_sample}, " + f"found {len(pae_files)} in {struct_dir}" + ) + for af in pae_files: + pae = np.load(af)["pae"] + assert pae.ndim == 2, f"pae should be 2D, got shape {pae.shape} in {af}" + assert pae.shape[0] == pae.shape[1], f"pae should be square, got {pae.shape} in {af}" + assert np.all(np.isfinite(pae)), f"pae contains non-finite values in {af}" + + +@pytest.mark.predict +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=lambda val: f"{val[2]}-dp:{val[0][0]}-cp{'x'.join(map(str, val[0][1]))}", +) +@pytest.mark.parametrize( + "triattn_backend", + [TriAttnBackend.CUEQ, TriAttnBackend.REFERENCE, TriAttnBackend.TRIFAST], + ids=lambda b: b.value, +) +@pytest.mark.parametrize( + "sdpa_backends", + [ + (SDPAWithBiasBackend.REFERENCE, SDPAWithBiasBackend.REFERENCE), + (SDPAWithBiasBackend.TORCH_FLEX_ATTN, SDPAWithBiasBackend.TORCH_FLEX_ATTN), + (SDPAWithBiasBackend.REFERENCE, SDPAWithBiasBackend.TORCH_SDPA_EFFICIENT_ATTENTION), + ], + ids=lambda pair: f"{pair[0].value}-{pair[1].value}", +) +@pytest.mark.parametrize( + "precision", + [Precision.BF16, Precision.BF16_MIXED, Precision.TF32, Precision.FP32], + ids=lambda p: p.value, +) +def test_boltz2_run_predict( + setup_env, + tmp_path, + get_preprocessed_boltz2, + canonical_mols_dir, + get_model_ckpt_v2, + get_inference_golden_value_dir_v2, + triattn_backend, + sdpa_backends, + precision, +): + """Full run_predict end-to-end: verify predicted structures via lDDT/RMSD. + + Uses real preprocessed data, a real Boltz-2 checkpoint, and golden reference + structures. Evaluates predicted CIF output against golden values using + weighted lDDT and RMSD after rigid alignment. + """ + sdpa_with_bias_backend, sdpa_with_bias_shardwise_backend = sdpa_backends + + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + if triattn_backend != TriAttnBackend.CUEQ and "8b2e" not in str(get_preprocessed_boltz2): + pytest.skip(f"{triattn_backend.value} backend only tested with 8b2e sample") + + if triattn_backend == TriAttnBackend.CUEQ and not cueq_is_installed: + pytest.skip("cuequivariance_torch is not installed") + if triattn_backend == TriAttnBackend.TRIFAST and not trifast_is_installed: + pytest.skip("trifast is not installed") + + is_flex_flex = ( + sdpa_with_bias_backend == SDPAWithBiasBackend.TORCH_FLEX_ATTN + and sdpa_with_bias_shardwise_backend == SDPAWithBiasBackend.TORCH_FLEX_ATTN + ) + if not is_flex_flex and "8b2e" not in str(get_preprocessed_boltz2): + pytest.skip("Non-(flex,flex) SDPA combos only tested with 8b2e sample") + + sample_name = get_preprocessed_boltz2.name + if precision in (Precision.TF32, Precision.FP32) and sample_name not in ( + "processed_7ylz", + "processed_8b2e", + "processed_7z64", + ): + pytest.skip(f"{precision.value} precision only tested with 7ylz, 8b2e, and 7z64 samples") + + result_dir = tmp_path / "result" + kwargs_run_predict = { + "data": str(get_preprocessed_boltz2), + "out_dir": str(result_dir), + "mol_dir": str(canonical_mols_dir), + "checkpoint": str(get_model_ckpt_v2), + "size_dp": grid_group_sizes["dp"], + "size_cp": math.prod(grid_group_sizes["cp"]), + "accelerator": "gpu", + "recycling_steps": 10, + "sampling_steps": 200, + "diffusion_samples": _DIFFUSION_SAMPLES, + "max_msa_seqs": 2048, + "msa_pad_to_max_seqs": True, + "seed": 42, + "timeout_nccl": 30, + "timeout_gloo": 30, + "precision": precision, + "pair_mask_mode": PairMaskMode.NONE, + "atoms_per_window_queries_keys": (32, 128), + "use_templates": False, + "confidence_prediction": True, + "write_full_pae": True, + "triattn_backend": triattn_backend, + "sdpa_with_bias_backend": sdpa_with_bias_backend, + "sdpa_with_bias_shardwise_backend": sdpa_with_bias_shardwise_backend, + } + if _DIFFUSION_SAMPLES > _PRE_GENERATED_GOLDEN_SAMPLES: + golden_dir = _run_serial_predict( + get_preprocessed_boltz2, + get_model_ckpt_v2, + canonical_mols_dir.parent, + _DIFFUSION_SAMPLES, + ) + else: + golden_dir = get_inference_golden_value_dir_v2 + + spawn_multiprocessing( + parallel_assert_run_predict_v2, + world_size, + env_per_rank, + kwargs_run_predict, + golden_dir, + ) + + +@pytest.mark.predict +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=lambda val: f"{val[2]}-dp:{val[0][0]}-cp{'x'.join(map(str, val[0][1]))}", +) +def test_boltz2_run_predict_dp2( + setup_env, + tmp_path, + test_cp_training_base_data_dir_boltz2, + canonical_mols_dir, + get_model_ckpt_v2, + get_inference_golden_value_dir_v2, +): + """End-to-end dp=2 run_predict: CUEQ + flex-flex + BF16_MIXED on all 4 samples.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + if not cueq_is_installed: + pytest.skip("cuequivariance_torch is not installed") + + names = ["7ylz", "7z64", "8ayv", "8b2e"] + data_dir = concat_data( + Path(tmp_path) / "processed_collection", + *[test_cp_training_base_data_dir_boltz2 / f"processed_{name}" for name in names], + ) + if grid_group_sizes["dp"] > len(names): + pytest.skip(f"dp ({grid_group_sizes['dp']}) exceeds number of samples ({len(names)})") + + result_dir = tmp_path / "result" + kwargs_run_predict = { + "data": str(data_dir), + "out_dir": str(result_dir), + "mol_dir": str(canonical_mols_dir), + "checkpoint": str(get_model_ckpt_v2), + "size_dp": grid_group_sizes["dp"], + "size_cp": math.prod(grid_group_sizes["cp"]), + "accelerator": "gpu", + "recycling_steps": 10, + "sampling_steps": 200, + "diffusion_samples": _DIFFUSION_SAMPLES, + "max_msa_seqs": 2048, + "msa_pad_to_max_seqs": True, + "seed": 42, + "timeout_nccl": 30, + "timeout_gloo": 30, + "precision": Precision.BF16_MIXED, + "pair_mask_mode": PairMaskMode.NONE, + "atoms_per_window_queries_keys": (32, 128), + "use_templates": False, + "confidence_prediction": True, + "write_full_pae": True, + "triattn_backend": TriAttnBackend.CUEQ, + "sdpa_with_bias_backend": SDPAWithBiasBackend.TORCH_FLEX_ATTN, + "sdpa_with_bias_shardwise_backend": SDPAWithBiasBackend.TORCH_FLEX_ATTN, + } + if _DIFFUSION_SAMPLES > _PRE_GENERATED_GOLDEN_SAMPLES: + golden_dir = _run_serial_predict( + data_dir, + get_model_ckpt_v2, + canonical_mols_dir.parent, + _DIFFUSION_SAMPLES, + ) + else: + golden_dir = get_inference_golden_value_dir_v2 + + spawn_multiprocessing( + parallel_assert_run_predict_v2, + world_size, + env_per_rank, + kwargs_run_predict, + golden_dir, + ) + + +SM100F_WARNING_SUBSTR = "Can't use SM100f kernel because q.shape[3] is not a multiple of 8" + + +def parallel_assert_sm100f_warning( + rank: int, + env_per_rank: dict[str, Any], + kwargs_run_predict: dict[str, Any], + expect_warning: bool, +): + """Worker: run distributed predict and assert SM100f warning presence/absence.""" + import warnings as _warnings + + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + monkeypatch.setenv(var_name, f"{rank}" if value == "" else value) + + with _warnings.catch_warnings(record=True) as caught: + _warnings.simplefilter("always") + run_predict(**kwargs_run_predict) + + sm100f_msgs = [w for w in caught if SM100F_WARNING_SUBSTR in str(w.message)] + if expect_warning: + assert sm100f_msgs, f"Rank {rank}: expected SM100f warning but none was emitted" + else: + assert not sm100f_msgs, f"Rank {rank}: SM100f warning(s) emitted ({len(sm100f_msgs)}): " + "; ".join( + str(w.message) for w in sm100f_msgs + ) + + +@pytest.mark.predict +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=lambda val: f"{val[2]}-dp:{val[0][0]}-cp{'x'.join(map(str, val[0][1]))}", +) +@pytest.mark.parametrize( + "precision", + [Precision.BF16, Precision.BF16_MIXED], + ids=lambda p: p.value, +) +@pytest.mark.parametrize( + "auto_pad_tokens_for_sm100f", + [True, False], + ids=lambda v: f"auto_pad={v}", +) +def test_boltz2_run_predict_auto_pad_for_sm100f( + setup_env, + tmp_path, + canonical_mols_dir, + get_model_ckpt_v2, + test_cp_training_base_data_dir_boltz2, + precision, + auto_pad_tokens_for_sm100f, +): + """Verify SM100f auto-padding: no cuEq warning with auto_pad=True, warning with False. + + Uses processed_8b2e (145 tokens), which is NOT a multiple of sqrt(size_cp)*8 = 16, + so the test is non-vacuous. With auto_pad=True, tokens are padded to 160 -> each + shard gets 80 (divisible by 8). With auto_pad=False, tokens are padded to 146 -> + each shard gets 73 (not divisible by 8), triggering the SM100f warning. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + if not cueq_is_installed: + pytest.skip("cuequivariance_torch is not installed") + + device_cc = torch.cuda.get_device_capability() + if device_cc not in ((10, 0), (10, 3)): + pytest.skip(f"GPU compute capability {device_cc} is not SM100/SM103") + + sample_dir = test_cp_training_base_data_dir_boltz2 / "processed_8b2e" + + result_dir = tmp_path / "result" + kwargs_run_predict = { + "data": str(sample_dir), + "out_dir": str(result_dir), + "mol_dir": str(canonical_mols_dir), + "checkpoint": str(get_model_ckpt_v2), + "size_dp": grid_group_sizes["dp"], + "size_cp": math.prod(grid_group_sizes["cp"]), + "accelerator": "gpu", + "recycling_steps": 1, + "sampling_steps": 2, + "diffusion_samples": 1, + "max_msa_seqs": 16, + "msa_pad_to_max_seqs": True, + "seed": 42, + "timeout_nccl": 30, + "timeout_gloo": 30, + "precision": precision, + "pair_mask_mode": PairMaskMode.NONE, + "atoms_per_window_queries_keys": (32, 128), + "use_templates": False, + "triattn_backend": TriAttnBackend.CUEQ, + "sdpa_with_bias_backend": SDPAWithBiasBackend.TORCH_FLEX_ATTN, + "sdpa_with_bias_shardwise_backend": SDPAWithBiasBackend.TORCH_FLEX_ATTN, + "auto_pad_tokens_for_sm100f": auto_pad_tokens_for_sm100f, + } + expect_warning = not auto_pad_tokens_for_sm100f + spawn_multiprocessing( + parallel_assert_sm100f_warning, + world_size, + env_per_rank, + kwargs_run_predict, + expect_warning, + ) + + +_yaml_serial_golden_cache: dict[str, Path] = {} + + +def _run_serial_predict_yaml( + yaml_path: Path, + checkpoint: Path, + cache_dir: Path, + diffusion_samples: int, +) -> Path: + """Run serial (1-GPU) Boltz-2 predict on a YAML config file. + + Uses ``boltz predict --input_format config_files`` via subprocess for + complete process isolation. Results are cached per YAML stem under + ``infer_cache/inference_yaml_examples/``— both on-disk (persists across + sessions) and in-memory (avoids redundant filesystem checks within the + same session). Serial inference is skipped when the golden dir already + contains CIF files. + + Returns the predictions directory (containing per-protein subdirs). + """ + cache_key = yaml_path.stem + if cache_key in _yaml_serial_golden_cache: + return _yaml_serial_golden_cache[cache_key] + + out_dir = Path("infer_cache/inference_yaml_examples") / cache_key + golden_dir = out_dir / f"boltz_results_{cache_key}" / "predictions" + if golden_dir.exists() and list(golden_dir.rglob("*.cif")): + print(f"\n [serial golden yaml] Reusing cached golden values for {cache_key} at {golden_dir}") + _yaml_serial_golden_cache[cache_key] = golden_dir + return golden_dir + + out_dir.mkdir(parents=True, exist_ok=True) + print(f"\n [serial golden yaml] Generating {diffusion_samples} serial samples for {cache_key} -> {out_dir}") + + cmd = [ + "boltz", + "predict", + str(yaml_path), + "--out_dir", + str(out_dir), + "--cache", + str(cache_dir), + "--checkpoint", + str(checkpoint), + "--diffusion_samples", + str(diffusion_samples), + "--seed", + "42", + "--input_format", + "config_files", + "--write_full_pae", + "--devices", + "1", + "--accelerator", + "gpu", + "--recycling_steps", + "10", + "--sampling_steps", + "200", + "--model", + "boltz2", + "--max_msa_seqs", + "2048", + "--override", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"Serial predict (yaml) failed (exit {result.returncode}) for {cache_key}.\n" + f"Command: {' '.join(cmd)}\n" + f"--- stdout ---\n{result.stdout[-2000:] if result.stdout else '(empty)'}\n" + f"--- stderr ---\n{result.stderr[-2000:] if result.stderr else '(empty)'}" + ) + + golden_dir = out_dir / f"boltz_results_{cache_key}" / "predictions" + if not golden_dir.exists(): + raise RuntimeError( + f"Serial predict (yaml) exited 0 but {golden_dir} does not exist.\n" + f"Contents of {out_dir}: {list(out_dir.rglob('*'))}\n" + f"--- stdout ---\n{result.stdout[-2000:] if result.stdout else '(empty)'}\n" + f"--- stderr ---\n{result.stderr[-2000:] if result.stderr else '(empty)'}" + ) + _yaml_serial_golden_cache[cache_key] = golden_dir + return golden_dir + + +@pytest.mark.predict +@pytest.mark.slow +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=["setup_env"], + ids=lambda val: f"{val[2]}-dp:{val[0][0]}-cp{'x'.join(map(str, val[0][1]))}", +) +@pytest.mark.parametrize( + "sdpa_backends", + [ + (SDPAWithBiasBackend.TORCH_FLEX_ATTN, SDPAWithBiasBackend.TORCH_FLEX_ATTN), + ], + ids=lambda pair: f"{pair[0].value}-{pair[1].value}", +) +def test_boltz2_run_predict_yaml( + setup_env, + tmp_path, + canonical_mols_dir, + get_model_ckpt_v2, + get_inference_golden_value_dir_v2, + sdpa_backends, +): + """End-to-end run_predict with input_format='config_files' on a YAML example. + + Uses prot_custom_msa.yaml (which embeds a custom MSA path, avoiding the + MSA server dependency). When ``_DIFFUSION_SAMPLES`` is at most + ``_PRE_GENERATED_GOLDEN_SAMPLES`` (default 5), the pre-generated golden + archive is reused; otherwise serial inference is run on-the-fly. + """ + from pathlib import Path as _Path + + EXAMPLE_PROT_CUSTOM_MSA_YAML = _Path(__file__).resolve().parents[2] / "examples" / "prot_custom_msa.yaml" + + sdpa_with_bias_backend, sdpa_with_bias_shardwise_backend = sdpa_backends + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + if not cueq_is_installed: + pytest.skip("cuequivariance_torch is not installed") + + if _DIFFUSION_SAMPLES > _PRE_GENERATED_GOLDEN_SAMPLES: + golden_dir = _run_serial_predict_yaml( + EXAMPLE_PROT_CUSTOM_MSA_YAML, + get_model_ckpt_v2, + canonical_mols_dir.parent, + _DIFFUSION_SAMPLES, + ) + else: + golden_dir = get_inference_golden_value_dir_v2 + + result_dir = tmp_path / "result" + kwargs_run_predict = { + "data": str(EXAMPLE_PROT_CUSTOM_MSA_YAML), + "out_dir": str(result_dir), + "mol_dir": str(canonical_mols_dir), + "checkpoint": str(get_model_ckpt_v2), + "size_dp": grid_group_sizes["dp"], + "size_cp": math.prod(grid_group_sizes["cp"]), + "input_format": "config_files", + "accelerator": "gpu", + "recycling_steps": 10, + "sampling_steps": 200, + "diffusion_samples": _DIFFUSION_SAMPLES, + "max_msa_seqs": 2048, + "msa_pad_to_max_seqs": True, + "seed": 42, + "timeout_nccl": 30, + "timeout_gloo": 30, + "precision": Precision.BF16_MIXED, + "pair_mask_mode": PairMaskMode.NONE, + "atoms_per_window_queries_keys": (32, 128), + "use_templates": False, + "confidence_prediction": True, + "write_full_pae": True, + "triattn_backend": TriAttnBackend.CUEQ, + "sdpa_with_bias_backend": sdpa_with_bias_backend, + "sdpa_with_bias_shardwise_backend": sdpa_with_bias_shardwise_backend, + "override": True, + } + + spawn_multiprocessing( + parallel_assert_run_predict_v2, + world_size, + env_per_rank, + kwargs_run_predict, + golden_dir, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/distributed/test_dtensor_stop_and_go.py b/tests/distributed/test_dtensor_stop_and_go.py new file mode 100644 index 000000000..9fec89787 --- /dev/null +++ b/tests/distributed/test_dtensor_stop_and_go.py @@ -0,0 +1,1061 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Distributed DTensor CP stop/go (checkpoint resume) parity tests. + +The primary stop/go test routes through the real ``train()`` entrypoint, +exercising the full save/resume cycle including +:class:`BoltzContextParallelStrategy` checkpoint conversion, checkpoint +callback defaults, and the ``cfg.resume`` auto-resume path. + +Cross-mode tests (serial <-> distributed) use direct ``Trainer.fit()`` calls +because the serial leg cannot use ``train()`` (which always creates a +``BoltzContextParallelStrategy``). +""" + +from pathlib import Path +from typing import Any + +import pytest +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf +from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict +from torch.distributed.tensor import DTensor, distribute_tensor + +import boltz.distributed.train as train_module +from boltz.distributed.lightning_strategy import BoltzContextParallelStrategy +from boltz.distributed.manager import DistributedManager +from boltz.distributed.model.models.boltz2 import Boltz2 as Boltz2Distributed +from boltz.distributed.model.models.boltz2 import _PlaceholderModule +from boltz.distributed.model.modules.utils import has_dtensors +from boltz.model.models.boltz2 import Boltz2 as SerialBoltz2 +from boltz.testing.utils import spawn_multiprocessing + +from .dtensor_train_harness import ( + DistogramTrainDataModule, + TinyDistogramCPModel, + TinyDistogramCPModelWithEMA, + TinyDistogramSerialModel, + create_initial_serial_state_dict, + create_train_dataloader, +) +from .model.models.test_dtensor_boltz2 import _prepare_serial_model + + +def _to_local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to local tensor for comparison.""" + return t.to_local() if isinstance(t, DTensor) else t + + +def _assert_optimizer_states_match( + opt_a: torch.optim.Optimizer, + opt_b: torch.optim.Optimizer, + label: str, +) -> None: + """Assert that two optimizers have matching state (exp_avg, exp_avg_sq, step). + + This compares the per-parameter state buffers that Adam maintains. + If these diverge after resume, future gradient updates will be wrong. + """ + state_a = opt_a.state_dict()["state"] + state_b = opt_b.state_dict()["state"] + assert state_a.keys() == state_b.keys(), f"{label}: optimizer state key mismatch" + for param_idx in state_a: + for buf_key in state_a[param_idx]: + val_a = state_a[param_idx][buf_key] + val_b = state_b[param_idx][buf_key] + if isinstance(val_a, torch.Tensor): + torch.testing.assert_close( + _to_local(val_a).cpu(), + _to_local(val_b).cpu(), + msg=lambda msg, k=buf_key, p=param_idx: ( + f"{label}: optimizer state mismatch for param {p}, key '{k}'\n{msg}" + ), + ) + else: + assert val_a == val_b, f"{label}: optimizer scalar mismatch param {param_idx}, key '{buf_key}'" + + +# --------------------------------------------------------------------------- +# Distogram stop/go via train.py entrypoint +# --------------------------------------------------------------------------- + + +def _write_distogram_config( + *, + config_path: Path, + output_dir: Path, + size_dp: int, + size_cp: int, + accelerator: str = "cpu", + max_epochs: int = 2, + limit_train_batches: int = 2, + resume: str | None = None, + weights_seed: int = 37, + data_seed: int = 53, + token_z: int = 16, + num_bins: int = 8, + num_distograms: int = 2, + num_conformers: int = 2, + seq_len: int = 12, + num_samples: int = 2, + learning_rate: float = 1e-2, + ema_decay: float = 0.999, +) -> None: + """Write a YAML config for a distogram stop/go ``train()`` run.""" + config: dict[str, Any] = { + "data": { + "seq_len": seq_len, + "token_z": token_z, + "num_bins": num_bins, + "num_conformers": num_conformers, + "num_samples": num_samples, + "seed": data_seed, + }, + "model": { + "token_z": token_z, + "num_bins": num_bins, + "num_distograms": num_distograms, + "num_conformers": num_conformers, + "weights_seed": weights_seed, + "learning_rate": learning_rate, + "ema_decay": ema_decay, + }, + "output": str(output_dir), + "trainer": { + "accelerator": accelerator, + "devices": 1, + "max_epochs": max_epochs, + "limit_train_batches": limit_train_batches, + "enable_progress_bar": False, + "enable_model_summary": False, + "num_sanity_val_steps": 0, + }, + "parallel_size": {"size_dp": size_dp, "size_cp": size_cp}, + "precision": "FP32", + "find_unused_parameters": False, + "save_top_k": -1, + "disable_checkpoint": False, + "debug": False, + "validation_only": False, + "seed": 11, + "checkpoint": { + "monitor": None, + "save_last": True, + "every_n_epochs": 1, + }, + } + if resume is not None: + config["resume"] = resume + config_path.parent.mkdir(parents=True, exist_ok=True) + OmegaConf.save(OmegaConf.create(config), config_path) + + +def _instantiate_distogram_config(config_dict: Any) -> dict[str, Any]: + """Replace ``hydra.utils.instantiate`` for distogram stop/go tests. + + Returns the config as a plain dict; ``_create_distogram_distributed_model`` + handles model creation from config + dist_manager. + """ + cfg = OmegaConf.to_container(config_dict, resolve=True) + assert isinstance(cfg, dict) + return cfg + + +def _create_distogram_distributed_model( + cfg: Any, + dist_manager: DistributedManager, +) -> TinyDistogramCPModelWithEMA: + """Monkeypatched ``_create_distributed_model`` for distogram tests.""" + model_cfg = cfg.model + serial_state = create_initial_serial_state_dict( + token_z=model_cfg["token_z"], + num_bins=model_cfg["num_bins"], + num_distograms=model_cfg["num_distograms"], + seed=model_cfg["weights_seed"], + ) + return TinyDistogramCPModelWithEMA( + dist_manager=dist_manager, + token_z=model_cfg["token_z"], + num_bins=model_cfg["num_bins"], + num_distograms=model_cfg["num_distograms"], + num_conformers=model_cfg["num_conformers"], + serial_state_dict=serial_state, + learning_rate=model_cfg["learning_rate"], + ema_decay=model_cfg.get("ema_decay", 0.999), + ) + + +def _create_distogram_distributed_data_module( + data_config: Any, + dist_manager: DistributedManager, +) -> DistogramTrainDataModule: + """Monkeypatched ``_create_distributed_data_module`` for distogram tests.""" + return DistogramTrainDataModule( + seq_len=data_config["seq_len"], + token_z=data_config["token_z"], + num_bins=data_config["num_bins"], + num_conformers=data_config["num_conformers"], + num_samples=data_config["num_samples"], + seed=data_config["seed"], + dp_size=dist_manager.group["dp"].size(), + ) + + +def _assert_checkpoint_optimizer_states_match( + opt_a: dict[str, Any], + opt_b: dict[str, Any], + label: str, +) -> None: + """Assert optimizer state dicts from two checkpoints match. + + Handles both FQN (string) and legacy integer keys transparently. + """ + state_a = opt_a["state"] + state_b = opt_b["state"] + assert state_a.keys() == state_b.keys(), f"{label}: optimizer state key mismatch" + for param_key in state_a: + for buf_key in state_a[param_key]: + val_a = state_a[param_key][buf_key] + val_b = state_b[param_key][buf_key] + if isinstance(val_a, torch.Tensor): + torch.testing.assert_close( + val_a, + val_b, + msg=lambda msg, k=buf_key, p=param_key: ( + f"{label}: optimizer state mismatch for param {p}, key '{k}'\n{msg}" + ), + ) + else: + assert val_a == val_b, f"{label}: optimizer scalar mismatch param {param_key}, key '{buf_key}'" + + +def _parallel_assert_dtensor_stop_and_go_ema(rank: int, payload: tuple[Any, ...]) -> None: + """Verify stop/go parity through the real ``train()`` entrypoint. + + Runs three ``train()`` calls: + 1. Continuous baseline — 2 epochs, checkpointing enabled + 2. Stop/go stage 1 — 1 epoch, checkpointing enabled + 3. Stop/go stage 2 — resume from checkpoint, complete to epoch 2 + + Compares the final ``last.ckpt`` files for exact parity: model weights, + optimizer state, EMA shadow weights, epoch, and global step. This + validates the full save→resume cycle through ``train.py``, including + strategy checkpoint conversion and ``cfg.resume``. + """ + ( + env_per_rank, + continuous_config_path, + stage1_config_path, + stage2_config_path, + continuous_dir, + stopgo_dir, + ) = payload + continuous_dir = Path(continuous_dir) + stopgo_dir = Path(stopgo_dir) + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + monkeypatch.setattr(train_module.hydra.utils, "instantiate", _instantiate_distogram_config) + monkeypatch.setattr(train_module, "_create_distributed_model", _create_distogram_distributed_model) + monkeypatch.setattr(train_module, "_create_distributed_data_module", _create_distogram_distributed_data_module) + # Suppress per-call cleanup so process groups survive across the 3 + # sequential train() calls; the test's own finally block cleans up. + monkeypatch.setattr(train_module, "_cleanup_distributed", lambda: None) + DistributedManager._state = {} + + try: + # ---- Continuous baseline: 2 epochs in one run. ---- + # _create_dist_manager initializes process groups on the first call; + # subsequent calls reuse the existing DistributedManager singleton. + train_module.train(str(continuous_config_path), []) + + # ---- Stop/go stage 1: 1 epoch, checkpoint produced. ---- + train_module.train(str(stage1_config_path), []) + ckpt_path = stopgo_dir / "last.ckpt" + assert ckpt_path.exists(), f"Rank {rank}: stage 1 checkpoint not found at {ckpt_path}" + + # Sanity: stage-1 (1 epoch) must differ from the continuous run + # (2 epochs). Guards against a vacuous test where both checkpoints + # are identical before resume even happens. + continuous_ckpt_early = torch.load(continuous_dir / "last.ckpt", map_location="cpu", weights_only=False) + stage1_ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + assert ( + continuous_ckpt_early["epoch"] != stage1_ckpt["epoch"] + or continuous_ckpt_early["global_step"] != stage1_ckpt["global_step"] + ), ( + "Stage-1 checkpoint should differ from the 2-epoch continuous checkpoint " + f"(both have epoch={stage1_ckpt['epoch']}, step={stage1_ckpt['global_step']})" + ) + weights_differ = any( + not torch.equal(continuous_ckpt_early["state_dict"][k], stage1_ckpt["state_dict"][k]) + for k in stage1_ckpt["state_dict"] + ) + assert weights_differ, "Stage-1 weights should differ from 2-epoch continuous weights" + + # ---- Stop/go stage 2: resume from checkpoint to epoch 2. ---- + train_module.train(str(stage2_config_path), []) + + # ---- Compare checkpoint files for parity. ---- + continuous_ckpt = torch.load(continuous_dir / "last.ckpt", map_location="cpu", weights_only=False) + stopgo_ckpt = torch.load(stopgo_dir / "last.ckpt", map_location="cpu", weights_only=False) + + # Checkpoints should contain only plain tensors (strategy strips DTensors). + assert not has_dtensors(continuous_ckpt["state_dict"]), "Continuous checkpoint has DTensors" + assert not has_dtensors(stopgo_ckpt["state_dict"]), "Stop/go checkpoint has DTensors" + + # 1) Model weights must match. + assert continuous_ckpt["state_dict"].keys() == stopgo_ckpt["state_dict"].keys(), "state_dict key mismatch" + for key in continuous_ckpt["state_dict"]: + torch.testing.assert_close( + continuous_ckpt["state_dict"][key], + stopgo_ckpt["state_dict"][key], + msg=lambda msg, k=key: f"Stop/go weight mismatch for {k} on rank {rank}\n{msg}", + ) + + # 2) Epoch and global step must match. + assert ( + continuous_ckpt["epoch"] == stopgo_ckpt["epoch"] + ), f"Epoch mismatch: continuous={continuous_ckpt['epoch']}, stopgo={stopgo_ckpt['epoch']}" + assert ( + continuous_ckpt["global_step"] == stopgo_ckpt["global_step"] + ), f"Step mismatch: continuous={continuous_ckpt['global_step']}, stopgo={stopgo_ckpt['global_step']}" + + # 3) Optimizer state (Adam exp_avg / exp_avg_sq / step) must match. + _assert_checkpoint_optimizer_states_match( + continuous_ckpt["optimizer_states"][0], + stopgo_ckpt["optimizer_states"][0], + label=f"rank {rank}", + ) + + # 3b) Optimizer state keys must be FQN strings (not legacy integers). + opt_state_keys = list(continuous_ckpt["optimizer_states"][0]["state"].keys()) + assert opt_state_keys, f"Rank {rank}: optimizer state is empty" + assert all(isinstance(k, str) for k in opt_state_keys), ( + f"Rank {rank}: optimizer state keys should be FQN strings, " + f"got {[type(k).__name__ for k in opt_state_keys[:3]]}" + ) + + # 4) EMA shadow weights and step counter must match. + assert "ema" in continuous_ckpt, "Continuous checkpoint missing EMA state" + assert "ema" in stopgo_ckpt, "Stop/go checkpoint missing EMA state" + assert continuous_ckpt["ema"]["cur_step"] == stopgo_ckpt["ema"]["cur_step"], ( + f"EMA step mismatch: continuous={continuous_ckpt['ema']['cur_step']}, " + f"stopgo={stopgo_ckpt['ema']['cur_step']}" + ) + for key in continuous_ckpt["ema"]["ema_weights"]: + torch.testing.assert_close( + continuous_ckpt["ema"]["ema_weights"][key], + stopgo_ckpt["ema"]["ema_weights"][key], + msg=lambda msg, k=key: f"EMA weight mismatch for {k} on rank {rank}\n{msg}", + ) + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (3, 3)), True, "cpu", "ENV"), + ((2, (1, 1)), True, "cuda", "ENV"), + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp2-cp3x3", "cuda-dp2-cp1x1", "cuda-dp1-cp2x2"], +) +def test_stop_and_go_via_train_entrypoint(setup_env, tmp_path): + """Goals: checkpoint resume parity through the real ``train()`` entrypoint. + + - Continuous 2-epoch run matches stop-at-epoch-1 + resume-to-epoch-2 + - Model weights, optimizer state, EMA shadow weights all match exactly + - Epoch and global_step counters match + - Validates ``BoltzContextParallelStrategy`` checkpoint conversion roundtrip + - Validates ``cfg.resume`` auto-resume path in ``train.py`` + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + + ema_decay = 0.999 + size_dp = int(grid_group_sizes["dp"]) + cp_group = grid_group_sizes["cp"] + size_cp = int(cp_group[0] * cp_group[1]) if isinstance(cp_group, tuple) else int(cp_group) + accelerator = "gpu" if device_type == "cuda" else "cpu" + + continuous_dir = tmp_path / "continuous" + stopgo_dir = tmp_path / "stopgo" + + common_kwargs: dict[str, Any] = { + "size_dp": size_dp, + "size_cp": size_cp, + "accelerator": accelerator, + "ema_decay": ema_decay, + } + + # Continuous baseline: 2 epochs. + continuous_config = continuous_dir / "config.yaml" + _write_distogram_config(config_path=continuous_config, output_dir=continuous_dir, max_epochs=2, **common_kwargs) + + # Stage 1: 1 epoch with checkpoint. + stage1_config = stopgo_dir / "config_stage1.yaml" + _write_distogram_config(config_path=stage1_config, output_dir=stopgo_dir, max_epochs=1, **common_kwargs) + + # Stage 2: resume to epoch 2. + stage2_config = stopgo_dir / "config_stage2.yaml" + _write_distogram_config( + config_path=stage2_config, + output_dir=stopgo_dir, + max_epochs=2, + resume=str(stopgo_dir / "last.ckpt"), + **common_kwargs, + ) + + payload = ( + env_per_rank, + str(continuous_config), + str(stage1_config), + str(stage2_config), + str(continuous_dir), + str(stopgo_dir), + ) + spawn_multiprocessing(_parallel_assert_dtensor_stop_and_go_ema, world_size, payload) + + +# --------------------------------------------------------------------------- +# Cross-mode stop/go: serial <-> distributed checkpoint interop (both dirs) +# --------------------------------------------------------------------------- + + +def _parallel_assert_cross_mode_stop_and_go(rank: int, payload: tuple[Any, ...]) -> None: + """Both cross-mode directions in one worker. + + Direction 1 (serial → distributed): + 1-epoch serial train → checkpoint → resume as distributed for epoch 2. + Compared against a 2-epoch continuous distributed baseline. + + Direction 2 (distributed → serial): + 1-epoch distributed train → checkpoint → resume as serial for epoch 2. + Compared against a 2-epoch continuous serial baseline (rank 0 only). + """ + grid_group_sizes, device_type, backend, env_per_rank, output_dir = payload + output_dir = Path(output_dir) + + token_z, num_bins, num_distograms, num_conformers, seq_len = 16, 8, 2, 2, 12 + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + dp_size = manager.group["dp"].size() + + serial_state = create_initial_serial_state_dict( + token_z=token_z, + num_bins=num_bins, + num_distograms=num_distograms, + seed=41, + ) + + model_kwargs: dict[str, Any] = { + "token_z": token_z, + "num_bins": num_bins, + "num_distograms": num_distograms, + "num_conformers": num_conformers, + "serial_state_dict": serial_state, + "learning_rate": 1e-2, + } + + def _make_dl(): + return create_train_dataloader( + seq_len=seq_len, + token_z=token_z, + num_bins=num_bins, + num_conformers=num_conformers, + num_samples=2, + seed=53, + dp_size=dp_size, + ) + + def _trainer(root, *, strategy=None, ckpt_cb=None, epochs=2): + kw: dict[str, Any] = { + "default_root_dir": str(root), + "accelerator": "cpu" if device_type == "cpu" else "gpu", + "devices": 1, + "max_epochs": epochs, + "limit_train_batches": 2, + "logger": False, + "enable_progress_bar": False, + "enable_model_summary": False, + } + if strategy is not None: + kw["strategy"] = strategy + kw["use_distributed_sampler"] = False + if ckpt_cb: + kw["callbacks"] = [ckpt_cb] + kw["enable_checkpointing"] = True + else: + kw["callbacks"] = [] + kw["enable_checkpointing"] = False + return pl.Trainer(**kw) + + def _ckpt_cb(dirpath): + return pl.callbacks.ModelCheckpoint( + dirpath=str(dirpath), + filename="epoch-{epoch:02d}", + every_n_epochs=1, + save_top_k=-1, + save_last=True, + ) + + def _check_parity(state_cont, state_res, trainer_cont, trainer_res, model_cont, model_res, label): + for key in state_cont: + torch.testing.assert_close( + state_res[key], state_cont[key], msg=lambda msg, k=key: f"{label} mismatch for {k}\n{msg}" + ) + assert trainer_cont.current_epoch == trainer_res.current_epoch + assert trainer_cont.global_step == trainer_res.global_step + _assert_optimizer_states_match(trainer_cont.optimizers[0], trainer_res.optimizers[0], label=label) + e2_c = {k: v for k, v in model_cont._loss_log.items() if k[0] == 1} + e2_r = {k: v for k, v in model_res._loss_log.items() if k[0] == 1} + assert e2_c.keys() == e2_r.keys() + for k in sorted(e2_c): + assert e2_c[k] == pytest.approx(e2_r[k]), f"{label}: loss mismatch at {k}" + + # ================================================================ + # Direction 1: serial → distributed + # ================================================================ + s2d = output_dir / "s2d" + + # 2-epoch continuous distributed baseline. + s2d_cont_model = TinyDistogramCPModel(dist_manager=manager, **model_kwargs) + s2d_cont_strat = BoltzContextParallelStrategy(dist_manager=manager) + s2d_cont_tr = _trainer(s2d / "cont", strategy=s2d_cont_strat) + s2d_cont_tr.fit(s2d_cont_model, train_dataloaders=_make_dl()) + state_s2d_cont = s2d_cont_strat.lightning_module_state_dict() + + # Stage 1: 1-epoch serial (rank 0 only). + s2d_ckpt = s2d / "serial" / "last.ckpt" + if manager.rank == 0: + s2d_s1 = TinyDistogramSerialModel(**model_kwargs) + s2d_s1_tr = _trainer(s2d / "serial", ckpt_cb=_ckpt_cb(s2d / "serial"), epochs=1) + s2d_s1_tr.fit(s2d_s1, train_dataloaders=_make_dl()) + assert s2d_ckpt.exists() + # Guard: serial stage-1 must differ from initial (training was not a no-op). + for key, init_val in serial_state.items(): + prefixed = f"distogram_module.{key}" + assert not torch.equal( + s2d_s1.state_dict()[prefixed].cpu(), init_val + ), f"Serial stage-1 weight '{key}' unchanged after 1 epoch" + # Verify serial checkpoint uses integer keys (standard Lightning). + s2d_s1_ckpt_data = torch.load(s2d_ckpt, map_location="cpu", weights_only=False) + s2d_s1_opt_keys = list(s2d_s1_ckpt_data["optimizer_states"][0]["state"].keys()) + assert s2d_s1_opt_keys, "Serial checkpoint optimizer state is empty" + assert all(isinstance(k, int) for k in s2d_s1_opt_keys), ( + f"Serial checkpoint should have integer optimizer keys, " + f"got types {[type(k).__name__ for k in s2d_s1_opt_keys[:3]]}" + ) + torch.distributed.barrier() + + # Stage 2: resume as distributed (loads serial int-key checkpoint via + # the legacy path in load_optimizer_state_dict). + s2d_s2 = TinyDistogramCPModel(dist_manager=manager, **model_kwargs) + s2d_s2_strat = BoltzContextParallelStrategy(dist_manager=manager) + s2d_s2_tr = _trainer(s2d / "resume", strategy=s2d_s2_strat) + s2d_s2_tr.fit(s2d_s2, train_dataloaders=_make_dl(), ckpt_path=str(s2d_ckpt)) + state_s2d_res = s2d_s2_strat.lightning_module_state_dict() + + _check_parity( + state_s2d_cont, + state_s2d_res, + s2d_cont_tr, + s2d_s2_tr, + s2d_cont_model, + s2d_s2, + f"rank {rank} serial→distributed", + ) + + # ================================================================ + # Direction 2: distributed → serial + # ================================================================ + d2s = output_dir / "d2s" + + # 2-epoch continuous serial baseline (rank 0 only). + if manager.rank == 0: + d2s_cont_model = TinyDistogramSerialModel(**model_kwargs) + d2s_cont_tr = _trainer(d2s / "cont", epochs=2) + d2s_cont_tr.fit(d2s_cont_model, train_dataloaders=_make_dl()) + state_d2s_cont = {k: v.detach().cpu().clone() for k, v in d2s_cont_model.state_dict().items()} + torch.distributed.barrier() + + # Stage 1: 1-epoch distributed with checkpoint. + d2s_s1 = TinyDistogramCPModel(dist_manager=manager, **model_kwargs) + d2s_s1_strat = BoltzContextParallelStrategy(dist_manager=manager) + d2s_s1_tr = _trainer(d2s / "dist", strategy=d2s_s1_strat, ckpt_cb=_ckpt_cb(d2s / "dist"), epochs=1) + d2s_s1_tr.fit(d2s_s1, train_dataloaders=_make_dl()) + # Guard: distributed stage-1 must differ from initial (training was not a no-op). + d2s_s1_sd = d2s_s1_strat.lightning_module_state_dict() + for key, init_val in serial_state.items(): + prefixed = f"distogram_module.{key}" + assert not torch.equal( + d2s_s1_sd[prefixed].cpu(), init_val + ), f"Distributed stage-1 weight '{key}' unchanged after 1 epoch" + # Barrier after all-ranks distributed training, before rank-0-only + # assertions. Placing it here avoids deadlock: if rank 0's assertions + # fail below, other ranks have already passed this sync point and + # exit cleanly instead of hanging in an NCCL wait. + torch.distributed.barrier() + + d2s_ckpt = d2s / "dist" / "last.ckpt" + if manager.rank == 0: + assert d2s_ckpt.exists() + ckpt = torch.load(d2s_ckpt, map_location="cpu", weights_only=False) + assert not has_dtensors(ckpt["state_dict"]), "Distributed ckpt should be plain tensors" + # Verify distributed checkpoint uses FQN string keys. + d2s_opt_keys = list(ckpt["optimizer_states"][0]["state"].keys()) + assert d2s_opt_keys, "Distributed checkpoint optimizer state is empty" + assert all(isinstance(k, str) for k in d2s_opt_keys), ( + f"Distributed checkpoint should have FQN string optimizer keys, " + f"got types {[type(k).__name__ for k in d2s_opt_keys[:3]]}" + ) + # Verify FQN keys correspond to actual model parameter names. + expected_param_names = [n for n, _ in d2s_s1.named_parameters()] + assert sorted(d2s_opt_keys) == sorted(expected_param_names), ( + f"FQN optimizer keys don't match model parameters.\n" + f" Optimizer keys: {sorted(d2s_opt_keys)}\n" + f" Model params: {sorted(expected_param_names)}" + ) + # Verify param_groups also use FQN keys (not integers). + pg_params = ckpt["optimizer_states"][0]["param_groups"][0]["params"] + assert all(isinstance(p, str) for p in pg_params), ( + f"Distributed checkpoint param_groups should use FQN strings, " + f"got types {[type(p).__name__ for p in pg_params[:3]]}" + ) + + # Stage 2: resume as serial (rank 0 only). + # The serial model uses Lightning's default load path, which calls + # optimizer.load_state_dict() — this handles FQN keys transparently + # via positional mapping in param_groups. + if manager.rank == 0: + d2s_s2 = TinyDistogramSerialModel(**model_kwargs) + d2s_s2_tr = _trainer(d2s / "resume", epochs=2) + d2s_s2_tr.fit(d2s_s2, train_dataloaders=_make_dl(), ckpt_path=str(d2s_ckpt)) + state_d2s_res = {k: v.detach().cpu().clone() for k, v in d2s_s2.state_dict().items()} + _check_parity( + state_d2s_cont, + state_d2s_res, + d2s_cont_tr, + d2s_s2_tr, + d2s_cont_model, + d2s_s2, + "rank 0 distributed→serial", + ) + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +# --------------------------------------------------------------------------- +# Optimizer parameter ordering: serial vs distributed +# --------------------------------------------------------------------------- + + +def _parallel_assert_optimizer_param_ordering(rank: int, payload: tuple[Any, ...]) -> None: + """Assert that optimizer parameter ordering matches between serial and distributed models. + + PyTorch optimizer state_dict uses integer keys derived from the iteration + order of ``model.parameters()`` (which mirrors ``model.named_parameters()``). + If serial and distributed models yield parameters in different orders, cross- + topology checkpoint resume silently applies optimizer state (exp_avg, etc.) + to the wrong parameters. This test catches that. + + Also verifies that ``get_optimizer_state_dict`` produces FQN keys matching + ``named_parameters()`` — the mechanism used by ``BoltzContextParallelStrategy`` + for portable, name-keyed optimizer checkpoints. + """ + grid_group_sizes, device_type, backend, env_per_rank = payload + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + serial_state = create_initial_serial_state_dict(token_z=16, num_bins=8, num_distograms=2, seed=41) + model_kwargs = { + "token_z": 16, + "num_bins": 8, + "num_distograms": 2, + "num_conformers": 2, + "serial_state_dict": serial_state, + "learning_rate": 1e-2, + } + + cp_model = TinyDistogramCPModel(dist_manager=manager, **model_kwargs) + serial_model = TinyDistogramSerialModel(**model_kwargs) + + # The optimizer receives parameters in iteration order of model.parameters(). + # named_parameters() yields (name, param) in the same order, so comparing + # the name lists verifies that integer optimizer state keys will align. + cp_names = [name for name, _ in cp_model.named_parameters()] + serial_names = [name for name, _ in serial_model.named_parameters()] + + assert cp_names == serial_names, ( + f"Parameter ordering mismatch between serial and distributed models.\n" + f" Serial: {serial_names}\n" + f" Distributed: {cp_names}\n" + f"Optimizer state checkpoint resume will silently apply state to wrong parameters." + ) + + # Also verify that the optimizers themselves see the same number of + # parameters in their param_groups (guards against extra params from + # distributed wrappers). + cp_opt = cp_model.configure_optimizers() + serial_opt = serial_model.configure_optimizers() + cp_param_count = sum(len(g["params"]) for g in cp_opt.param_groups) + serial_param_count = sum(len(g["params"]) for g in serial_opt.param_groups) + assert cp_param_count == serial_param_count, ( + f"Optimizer param count mismatch: serial has {serial_param_count} params, " + f"distributed has {cp_param_count} params" + ) + + # Verify get_optimizer_state_dict produces FQN keys matching + # named_parameters(). This is the save-side mechanism used by + # BoltzContextParallelStrategy.optimizer_state(). + + # Run one optimizer step to populate state buffers (exp_avg, etc.). + # Use a dummy gradient to avoid going through training_step which + # requires Trainer integration. + for p in cp_model.parameters(): + p.grad = torch.randn_like(p.to_local() if isinstance(p, DTensor) else p) + if isinstance(p, DTensor): + p.grad = distribute_tensor(p.grad, device_mesh=p.device_mesh, placements=p.placements) + cp_opt.step() + + fqn_sd = get_optimizer_state_dict(cp_model, cp_opt) + fqn_state_keys = sorted(fqn_sd["state"].keys()) + fqn_pg_params = sorted(fqn_sd["param_groups"][0]["params"]) + + assert all(isinstance(k, str) for k in fqn_state_keys), ( + f"get_optimizer_state_dict should return FQN string keys, " + f"got types {[type(k).__name__ for k in fqn_state_keys[:3]]}" + ) + assert fqn_state_keys == sorted(cp_names), ( + f"FQN optimizer state keys don't match named_parameters().\n" + f" FQN keys: {fqn_state_keys}\n" + f" named_parameters: {sorted(cp_names)}" + ) + assert fqn_pg_params == sorted(cp_names), ( + f"FQN param_group params don't match named_parameters().\n" + f" param_groups: {fqn_pg_params}\n" + f" named_parameters: {sorted(cp_names)}" + ) + + # Verify FQN keys match between serial and distributed models. + for p in serial_model.parameters(): + p.grad = torch.randn_like(p) + serial_opt.step() + + serial_fqn_sd = get_optimizer_state_dict(serial_model, serial_opt) + serial_fqn_keys = sorted(serial_fqn_sd["state"].keys()) + assert fqn_state_keys == serial_fqn_keys, ( + f"FQN keys mismatch between serial and distributed models.\n" + f" Distributed: {fqn_state_keys}\n" + f" Serial: {serial_fqn_keys}" + ) + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp2-cp3x3"], +) +def test_optimizer_param_ordering_serial_vs_distributed(setup_env): + """Goals: optimizer parameter ordering and FQN key consistency. + + - named_parameters() yields identical key lists for both model types + - Optimizers see the same number of parameters + - get_optimizer_state_dict produces FQN string keys matching named_parameters() + - FQN keys are identical between serial and distributed models + - Guards against silent optimizer state misalignment on cross-topology resume + + See also: test_optimizer_param_ordering_boltz2 for full Boltz-2 model + coverage, which catches subtle registration-order bugs that the tiny + harness cannot. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + spawn_multiprocessing( + _parallel_assert_optimizer_param_ordering, + world_size, + (grid_group_sizes, device_type, backend, env_per_rank), + ) + + +def _parallel_assert_boltz2_optimizer_param_ordering(rank: int, payload: tuple[Any, ...]) -> None: + """Assert optimizer parameter consistency for the full Boltz-2 model wrapper. + + The Boltz2Distributed wrapper renames placeholder module parameters with + a ``._serial.`` prefix and may register submodules in a different order + than the serial model (wrapped modules first, then placeholders). + + Since ``BoltzContextParallelStrategy`` uses FQN-keyed (name-based) + optimizer state dicts, parameter *ordering* does not need to match. + This test verifies what does matter for correctness: + + 1. The canonical parameter name *sets* (stripping ``._serial.``) are + identical between serial and distributed models — no missing or extra + parameters. + 2. The optimizer covers exactly the set of trainable parameters. + 3. ``get_optimizer_state_dict`` produces well-formed FQN string keys + matching ``named_parameters()``. + 4. Canonical FQN keys match between serial and distributed models, + ensuring cross-topology checkpoint portability. + """ + grid_group_sizes, device_type, backend, env_per_rank, serial_state_dict, serial_hparams = payload + + monkeypatch = pytest.MonkeyPatch() + for key, value in env_per_rank.items(): + monkeypatch.setenv(key, f"{rank}" if value == "" else value) + + try: + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + serial_model = SerialBoltz2(**serial_hparams) + serial_model.load_state_dict(serial_state_dict, strict=True) + serial_model = serial_model.to(manager.device) + + cp_model = Boltz2Distributed(serial_model, manager) + cp_model = cp_model.to(manager.device) + + # Re-create a fresh serial model for comparison (the original was + # consumed by the distributed wrapper). + serial_model2 = SerialBoltz2(**serial_hparams) + serial_model2.load_state_dict(serial_state_dict, strict=True) + + # Canonical name set comparison: strip ._serial. from distributed names + cp_names = [name for name, _ in cp_model.named_parameters()] + serial_names = [name for name, _ in serial_model2.named_parameters()] + + cp_names_canon_set = {n.replace("._serial.", ".") for n in cp_names} + serial_names_set = set(serial_names) + + missing = serial_names_set - cp_names_canon_set + extra = cp_names_canon_set - serial_names_set + assert not missing and not extra, ( + f"Parameter name set mismatch between serial and distributed Boltz2.\n" + f" Missing from distributed: {sorted(missing)[:5]}\n" + f" Extra in distributed: {sorted(extra)[:5]}" + ) + + # Build per-name param counts from the serial model and verify + # the per-name sum equals the optimizer's total. + cp_result = cp_model.configure_optimizers() + serial_result = serial_model2.configure_optimizers() + + # Both return ([optimizer], [scheduler_dict]) with lr_scheduler="af3" + cp_opt = cp_result[0][0] if isinstance(cp_result, tuple) else cp_result + serial_opt = serial_result[0][0] if isinstance(serial_result, tuple) else serial_result + + serial_trainable_names = {n for n, p in serial_model2.named_parameters() if p.requires_grad} + serial_opt_total = sum(len(g["params"]) for g in serial_opt.param_groups) + assert len(serial_trainable_names) == serial_opt_total, ( + f"Serial optimizer param count ({serial_opt_total}) doesn't match " + f"serial trainable named_parameters count ({len(serial_trainable_names)})" + ) + + # Identify which serial param names correspond to placeholder + # modules in the distributed model. + placeholder_prefixes = tuple( + f"{name}." for name, mod in cp_model.named_modules() if isinstance(mod, _PlaceholderModule) + ) + serial_placeholder_names = {n for n in serial_trainable_names if n.startswith(placeholder_prefixes)} + serial_non_placeholder_names = serial_trainable_names - serial_placeholder_names + + # The distributed optimizer should cover exactly the non-placeholder + # (i.e. trainable) params. + cp_opt_count = sum(len(g["params"]) for g in cp_opt.param_groups) + cp_trainable_names = {n for n, p in cp_model.named_parameters() if p.requires_grad} + assert cp_opt_count == len(cp_trainable_names), ( + f"Distributed optimizer param count ({cp_opt_count}) doesn't match " + f"distributed trainable param count ({len(cp_trainable_names)})" + ) + assert cp_opt_count > 0, "Distributed optimizer has zero params" + + # Param count on the non-placeholder subset must match between + # serial and distributed. + assert cp_opt_count == len(serial_non_placeholder_names), ( + f"Distributed optimizer has {cp_opt_count} params but serial model " + f"has {len(serial_non_placeholder_names)} non-placeholder trainable params " + f"(total serial trainable: {len(serial_trainable_names)}, " + f"placeholder: {len(serial_placeholder_names)})" + ) + + # Canonicalize distributed trainable names and compare against + # the serial non-placeholder set. + cp_trainable_canon = {n.replace("._serial.", ".") for n in cp_trainable_names} + assert cp_trainable_canon == serial_non_placeholder_names, ( + f"Trainable distributed params don't match serial non-placeholder params.\n" + f" Only in distributed: {sorted(cp_trainable_canon - serial_non_placeholder_names)[:5]}\n" + f" Only in serial: {sorted(serial_non_placeholder_names - cp_trainable_canon)[:5]}" + ) + + # Parameter shapes must match between serial and distributed. + # DTensor params report their global shape, which should equal the + # serial shape. + serial_shapes = {n: p.shape for n, p in serial_model2.named_parameters() if p.requires_grad} + shape_mismatches = [] + for n, p in cp_model.named_parameters(): + if not p.requires_grad: + continue + canon = n.replace("._serial.", ".") + serial_shape = serial_shapes.get(canon) + if serial_shape is None: + continue + cp_shape = p.shape + if cp_shape != serial_shape: + shape_mismatches.append((canon, cp_shape, serial_shape)) + assert not shape_mismatches, "Parameter shape mismatches (distributed vs serial):\n" + "\n".join( + f" {n}: {cs} vs {ss}" for n, cs, ss in shape_mismatches[:10] + ) + + # FQN keys from get_optimizer_state_dict + for p in cp_model.parameters(): + p.grad = torch.randn_like(p.to_local() if isinstance(p, DTensor) else p) + if isinstance(p, DTensor): + p.grad = distribute_tensor(p.grad, device_mesh=p.device_mesh, placements=p.placements) + cp_opt.step() + + fqn_sd = get_optimizer_state_dict(cp_model, cp_opt) + fqn_state_keys = sorted(fqn_sd["state"].keys()) + + assert all(isinstance(k, str) for k in fqn_state_keys), ( + f"get_optimizer_state_dict should return FQN string keys, " + f"got types {[type(k).__name__ for k in fqn_state_keys[:3]]}" + ) + # FQN state keys only cover trainable params (those in the optimizer) + cp_trainable_names = sorted(n for n, p in cp_model.named_parameters() if p.requires_grad) + assert fqn_state_keys == cp_trainable_names, ( + f"FQN optimizer state keys don't match trainable named_parameters().\n" + f" FQN keys (first 5): {fqn_state_keys[:5]}\n" + f" trainable params (first 5): {cp_trainable_names[:5]}" + ) + + # Cross-model FQN key comparison: distributed trainable params + # (canonicalized) must be a subset of serial optimizer params. + # Placeholder modules have their params frozen, so only the + # "ready" distributed submodules appear in the distributed optimizer. + for p in serial_model2.parameters(): + p.grad = torch.randn_like(p) + serial_opt.step() + + serial_fqn_sd = get_optimizer_state_dict(serial_model2, serial_opt) + serial_fqn_keys = set(serial_fqn_sd["state"].keys()) + + cp_fqn_canon = {k.replace("._serial.", ".") for k in fqn_state_keys} + not_in_serial = cp_fqn_canon - serial_fqn_keys + assert not not_in_serial, ( + f"Distributed optimizer FQN keys not found in serial optimizer:\n" f" {sorted(not_in_serial)[:5]}" + ) + finally: + DistributedManager.cleanup() + DistributedManager._state = {} + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (3, 3)), True, "cpu", "ENV"), + ((1, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp2-cp3x3", "cuda-dp1-cp2x2"], +) +def test_optimizer_param_ordering_boltz2(setup_env): + """Boltz-2 model: optimizer parameter ordering and FQN key consistency. + + Extends the tiny-distogram harness test to the full Boltz-2 wrapper, + catching registration-order bugs across the many submodules (ready + + placeholder) that the smaller harness cannot exercise. + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + serial_state_dict, serial_hparams = _prepare_serial_model(ema=False) + + spawn_multiprocessing( + _parallel_assert_boltz2_optimizer_param_ordering, + world_size, + (grid_group_sizes, device_type, backend, env_per_rank, serial_state_dict, serial_hparams), + ) + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=["cpu-dp2-cp3x3"], +) +def test_cross_mode_stop_and_go(setup_env, tmp_path): + """Goals: cross-mode checkpoint interop in both directions. + + - Serial→distributed: 1-epoch serial + resume as distributed matches + 2-epoch continuous distributed baseline + - Distributed→serial: 1-epoch distributed + resume as serial matches + 2-epoch continuous serial baseline + - Validates BoltzContextParallelStrategy strips DTensor metadata for + portable checkpoints (the train-with-CP/deploy-with-serial workflow) + - Verifies serial checkpoints use integer optimizer keys + - Verifies distributed checkpoints use FQN string optimizer keys + - Verifies cross-format loading: serial int-key ckpt loaded by distributed + strategy (legacy path), and distributed FQN-key ckpt loaded by serial + model (PyTorch positional mapping in optimizer.load_state_dict) + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need {world_size} GPUs, have {torch.cuda.device_count()}") + output_dir = tmp_path / "cross_mode_output" + payload = (grid_group_sizes, device_type, backend, env_per_rank, str(output_dir)) + spawn_multiprocessing(_parallel_assert_cross_mode_stop_and_go, world_size, payload) diff --git a/tests/distributed/test_dtensor_train_utils.py b/tests/distributed/test_dtensor_train_utils.py new file mode 100644 index 000000000..9b815ddcd --- /dev/null +++ b/tests/distributed/test_dtensor_train_utils.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Unit tests for distributed train.py utility functions. + +These tests cover the pure-logic helpers that don't require a running +distributed process group, catching configuration errors before a full +training launch. +""" + +import os +from unittest.mock import MagicMock + +import pytest +import torch + +import boltz.distributed.train as train_module +from boltz.distributed.model.modules.utils import Precision, setup_tf32_env +from boltz.distributed.train import ( + DistributedTrainConfig, + _apply_matmul_precision, + _create_dist_manager, + _parse_precision, +) + + +def _make_config(**overrides: object) -> DistributedTrainConfig: + return DistributedTrainConfig( + data=MagicMock(), + model=MagicMock(), + output="/tmp/test", + **overrides, + ) + + +# --------------------------------------------------------------------------- +# _parse_precision +# --------------------------------------------------------------------------- +def test_parse_precision_returns_precision_if_already_precision() -> None: + """Goals: Precision enum passthrough.""" + assert _parse_precision(Precision.BF16) is Precision.BF16 + + +@pytest.mark.parametrize( + "name,expected", + [ + ("FP32", Precision.FP32), + ("BF16", Precision.BF16), + ("BF16_MIXED", Precision.BF16_MIXED), + ("FP16", Precision.FP16), + ("TF32", Precision.TF32), + ("FP64", Precision.FP64), + ], +) +def test_parse_precision_from_string(name: str, expected: Precision) -> None: + """Goals: string name → Precision enum conversion.""" + assert _parse_precision(name) is expected + + +@pytest.mark.parametrize("bad_value", ["INVALID", 42, None]) +def test_parse_precision_raises_for_unsupported_value(bad_value: object) -> None: + """Goals: ValueError for invalid precision inputs.""" + with pytest.raises(ValueError, match="Unsupported precision value"): + _parse_precision(bad_value) + + +# --------------------------------------------------------------------------- +# _apply_matmul_precision +# --------------------------------------------------------------------------- +def test_apply_matmul_precision_noop_when_unset(monkeypatch: pytest.MonkeyPatch) -> None: + """Goals: None value does not call set_float32_matmul_precision.""" + calls: list[str] = [] + + def _record(value: str) -> None: + calls.append(value) + + monkeypatch.setattr(torch, "set_float32_matmul_precision", _record) + _apply_matmul_precision(None) + assert calls == [] + + +def test_apply_matmul_precision_applies_when_set(monkeypatch: pytest.MonkeyPatch) -> None: + """Goals: string value is forwarded to set_float32_matmul_precision.""" + calls: list[str] = [] + + def _record(value: str) -> None: + calls.append(value) + + monkeypatch.setattr(torch, "set_float32_matmul_precision", _record) + _apply_matmul_precision("high") + assert calls == ["high"] + + +# --------------------------------------------------------------------------- +# DistributedTrainConfig +# --------------------------------------------------------------------------- +def test_distributed_train_config_defaults() -> None: + """Goals: DistributedTrainConfig defaults match expected Boltz-2 production values.""" + cfg = _make_config() + assert cfg.precision is Precision.FP32 + assert cfg.seed is None + assert cfg.matmul_precision is None + assert cfg.find_unused_parameters is False + assert cfg.save_top_k == 1 + assert cfg.validation_only is False + + +# --------------------------------------------------------------------------- +# _create_dist_manager (already initialized mismatch guards) +# --------------------------------------------------------------------------- +def test_create_dist_manager_raises_on_initialized_device_type_mismatch(monkeypatch: pytest.MonkeyPatch) -> None: + """Goals: explicit error when reusing singleton with different accelerator.""" + + class _FakeDistManager: + @staticmethod + def is_initialized() -> bool: + return True + + def __init__(self) -> None: + self.device = torch.device("cpu") + self.group_ranks = {"dp": [0], "cp": [0]} + + monkeypatch.setattr(train_module, "DistributedManager", _FakeDistManager) + cfg = _make_config(trainer={"accelerator": "gpu"}, parallel_size={"size_dp": 1, "size_cp": 1}) + with pytest.raises(ValueError, match="Cannot change device type"): + _create_dist_manager(cfg) + + +def test_create_dist_manager_raises_on_initialized_topology_mismatch(monkeypatch: pytest.MonkeyPatch) -> None: + """Goals: explicit error when reusing singleton with different dp/cp sizes.""" + + class _FakeDistManager: + @staticmethod + def is_initialized() -> bool: + return True + + def __init__(self) -> None: + self.device = torch.device("cpu") + self.group_ranks = {"dp": [0], "cp": [0]} + + monkeypatch.setattr(train_module, "DistributedManager", _FakeDistManager) + cfg = _make_config(trainer={"accelerator": "cpu"}, parallel_size={"size_dp": 2, "size_cp": 1}) + with pytest.raises(ValueError, match="Cannot change topology"): + _create_dist_manager(cfg) + + +# --------------------------------------------------------------------------- +# setup_tf32_env +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=False) +def _tf32_snapshot(): + """Save and restore TF32 global state around each test that uses it.""" + orig_matmul = torch.backends.cuda.matmul.allow_tf32 + orig_cudnn = torch.backends.cudnn.allow_tf32 + orig_env = os.environ.get("NVIDIA_TF32_OVERRIDE") + yield orig_matmul, orig_cudnn, orig_env + torch.backends.cuda.matmul.allow_tf32 = orig_matmul + torch.backends.cudnn.allow_tf32 = orig_cudnn + if orig_env is not None: + os.environ["NVIDIA_TF32_OVERRIDE"] = orig_env + else: + os.environ.pop("NVIDIA_TF32_OVERRIDE", None) + + +def test_setup_tf32_env_tf32_precision_enables_tf32(_tf32_snapshot) -> None: + """Goals: TF32 precision sets NVIDIA_TF32_OVERRIDE=1 and enables matmul/cudnn TF32.""" + orig_matmul, orig_cudnn, orig_env = _tf32_snapshot + with setup_tf32_env(Precision.TF32): + assert os.environ.get("NVIDIA_TF32_OVERRIDE") == "1" + assert torch.backends.cuda.matmul.allow_tf32 is True + assert torch.backends.cudnn.allow_tf32 is True + # Context manager should restore originals + assert torch.backends.cuda.matmul.allow_tf32 == orig_matmul + assert torch.backends.cudnn.allow_tf32 == orig_cudnn + if orig_env is not None: + assert os.environ.get("NVIDIA_TF32_OVERRIDE") == orig_env + else: + assert "NVIDIA_TF32_OVERRIDE" not in os.environ + + +def test_setup_tf32_env_fp32_precision_disables_tf32(_tf32_snapshot) -> None: + """Goals: FP32 precision explicitly disables TF32.""" + with setup_tf32_env(Precision.FP32): + assert os.environ.get("NVIDIA_TF32_OVERRIDE") == "0" + assert torch.backends.cuda.matmul.allow_tf32 is False + assert torch.backends.cudnn.allow_tf32 is False + + +def test_setup_tf32_env_bf16_precision_leaves_tf32_unchanged(_tf32_snapshot) -> None: + """Goals: BF16 precision does not modify TF32 state.""" + orig_matmul, orig_cudnn, _ = _tf32_snapshot + with setup_tf32_env(Precision.BF16): + assert torch.backends.cuda.matmul.allow_tf32 == orig_matmul + assert torch.backends.cudnn.allow_tf32 == orig_cudnn + + +def test_setup_tf32_env_restores_on_exception(_tf32_snapshot) -> None: + """Goals: TF32 state is restored even when the context manager body raises.""" + orig_matmul, orig_cudnn, _ = _tf32_snapshot + with pytest.raises(RuntimeError, match="boom"): + with setup_tf32_env(Precision.TF32): + raise RuntimeError("boom") + assert torch.backends.cuda.matmul.allow_tf32 == orig_matmul + assert torch.backends.cudnn.allow_tf32 == orig_cudnn diff --git a/tests/distributed/test_layoutmap.py b/tests/distributed/test_layoutmap.py new file mode 100644 index 000000000..7c71c63d0 --- /dev/null +++ b/tests/distributed/test_layoutmap.py @@ -0,0 +1,310 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +from collections import namedtuple +from itertools import product + +import numpy as np +import pytest +import torch + +from boltz.distributed.utils import LayoutMap, update_exhaustive_strides + +# Define a named tuple for the test parameters +LayoutMapParams = namedtuple("LayoutMapParams", ["strides", "shape", "offset", "slices", "layoutmap"]) + + +@pytest.fixture( + params=[ + # Test cases without slices + ((), (), 3, None), # rank-0 layout + ((1,), (10,), 1, None), # rank-1 layout + ((1, 1), (2, 1), 1, None), # rank-2 LayoutRightMap with all-1 strides + ((1, 1, 1), (1, 2, 1), 0, None), # rank-3 LayoutRightMap with all-1 strides + ((5, 5, 1), (3, 1, 5), 0, None), # rank-3 LayoutRightMap with singleton axis + ((1, 2, 6), (2, 3, 4), 2, None), # LayoutLeftMap + ((12, 4, 1), (2, 3, 4), 3, None), # LayoutRightMap + ((3, 1, 6), (2, 3, 4), 4, None), # Layout strided + ((3, 1, 24), (4, 2, 3), 0, None), # LayoutMap(strides=(), shape=())[:4, :2, :3] + ((12, 2, 72), (2, 2, 2), 2, None), # LayoutMap(strides=(), shape=())[::4, ::2, ::3] + # Test cases with slices + ((1, 2, 6), (2, 3, 4), 2, (slice(1, 2, 1),)), # Single slice on first dimension (test slice padding) + ((12, 4, 1), (2, 3, 4), 3, (0, slice(2, 3, 1))), # Int index + slice + ((3, 1, 6), (2, 3, 4), 4, (slice(1, 2, 1), 1, slice(2, 4, 1))), # Multiple slices and int + ( + (12, 3, 24), + (2, 2, 3), + 0, + (slice(None, None, 2), 0, slice(None, None, 2)), + ), # LayoutMap(strides=(6, 1, 24), shape=(4, 6, 5))[::2, ::3, :3][::2, 0, ::2] + ], + ids=lambda p: f"stride = {p[0]}, shape = {p[1]}, offset = {p[2]}, slices = {p[3]}", +) +def layout_map_fixture(request): + strides, shape, offset, slices = request.param + + layout_map = LayoutMap(strides, shape, offset=offset) + # save the input parameters for debugging purpose + params = LayoutMapParams(strides, shape, offset, slices, layout_map) + + if len(strides) == 0: + required_span_size = 1 + flat_indices_nd = torch.tensor(offset) + else: + required_span_size = 1 + sum((shape[i] - 1) * strides[i] for i in range(len(strides))) + flat_indices = torch.arange(required_span_size) + offset + flat_indices_nd = torch.as_strided(flat_indices, size=shape, stride=strides) + + # Apply slicing if specified + if slices is not None: + try: + # Slice the flat_indices_nd tensor + flat_indices_nd = flat_indices_nd[slices] + + # Slice the layout_map + layout_map = layout_map[slices] + + # Update the shape and strides to match the sliced layout + shape = layout_map.shape + strides = layout_map.strides + offset = layout_map.offset + required_span_size = layout_map.required_span_size + except (ValueError, RuntimeError) as e: + # Skip this test case if slicing raises an exception + pytest.skip(f"Slicing failed with error: {e}") + + return params, strides, shape, offset, layout_map, required_span_size, flat_indices_nd + + +def test_layout_map_init(layout_map_fixture): + params, strides, shape, offset, layout_map, required_span_size, flat_indices_nd = layout_map_fixture + assert isinstance(layout_map.shape, tuple) + assert isinstance(layout_map.strides, tuple) + assert strides == layout_map.strides + assert layout_map.shape == shape + assert layout_map.numel == np.prod(shape) + assert required_span_size == layout_map.required_span_size + # for exhaustive layout, the flat_indices_nd should be arange upon sorting + is_exhaustive_expected = ( + required_span_size == flat_indices_nd.numel() + and torch.all( + flat_indices_nd.flatten().sort(stable=True).values == torch.arange(required_span_size) + offset + ).item() + ) + assert layout_map.is_exhaustive == is_exhaustive_expected + + +def test_layout_map_init_invalid_strides(layout_map_fixture): + _, _, shape, _, _, _, _ = layout_map_fixture + invalid_strides = (1, 1, 1) + if len(shape) != 3: + with pytest.raises(ValueError, match=r"^.* must have the same length"): + LayoutMap(invalid_strides, shape) + else: + # the only unique layout with all-1 strides is the one where + # only 1 dimension of the shape is larger or equal to 1 + # while the rest are 1. + values, counts = np.unique(shape, return_counts=True) + is_unique_layout = ( + values.size == 2 and values[0] == 1 and values[1] > 1 and counts[0] == len(shape) - 1 and counts[1] == 1 + ) + if not is_unique_layout: + with pytest.raises(ValueError, match=r"^.* do not give unique layout"): + LayoutMap(invalid_strides, shape) + + +def test_layout_map_init_invalid_shape(layout_map_fixture): + _, strides, _, _, _, _, _ = layout_map_fixture + invalid_shape = (99, 99, 99) + with pytest.raises(ValueError): + LayoutMap(strides, invalid_shape) + if len(strides) != 3: + with pytest.raises(ValueError, match=r"^.* must have the same length"): + LayoutMap(strides, invalid_shape) + else: + with pytest.raises(ValueError, match=r"^.* do not give unique layout"): + LayoutMap(strides, invalid_shape) + negative_shape = (1, 2, -3) + with pytest.raises(ValueError, match=r"^.* contain negative values"): + LayoutMap(strides, negative_shape) + zero_shape = (1, 2, 0) + with pytest.raises(ValueError, match=r"^.* contain zero values"): + LayoutMap(strides, zero_shape) + + +def test_layout_map_call(layout_map_fixture): + _, _, shape, _, layout_map, _, flat_indices_nd = layout_map_fixture + for idx in product(*[range(s) for s in shape]): + flat_idx_expected = flat_indices_nd[idx] + flat_idx_result = layout_map(idx) + assert flat_idx_expected == flat_idx_result + + +def test_layout_map_call_invalid_idx(layout_map_fixture): + _, _, shape, _, layout_map, _, _ = layout_map_fixture + idx = tuple(i for i in range(len(shape) + 1)) + with pytest.raises(ValueError, match=r"Expected .* elements in ids but got only .*"): + layout_map(idx) + if len(shape) > 0: + # test out of bounds + idx_oob = tuple(shape[i] for i in range(len(shape))) + with pytest.raises(ValueError, match=r"Expected ids to satisfy 0 <= ids\[.*\] <= .* but found ids\[.*\] == .*"): + layout_map(idx_oob) + + +def test_layout_map_unravel(layout_map_fixture): + _, _, shape, _, layout_map, _, flat_indices_nd = layout_map_fixture + for idx_expected in product(*[range(s) for s in shape]): + flat_idx = flat_indices_nd[idx_expected].item() + idx_result = layout_map.unravel(flat_idx) + assert idx_result == idx_expected + + +def test_layout_map_unravel_invalid_flat_idx(layout_map_fixture): + _, _, _, _, layout_map, required_span_size, _ = layout_map_fixture + flat_idx_oob = required_span_size + layout_map.offset + with pytest.raises(ValueError, match=r"Expected flat_index in range"): + layout_map.unravel(flat_idx_oob) + + +# Add a test specifically for slicing with padding +def test_layout_map_getitem_padding(): + """Test that slicing with fewer dimensions than the layout properly pads with ':' slices.""" + layout_map = LayoutMap((12, 4, 1), (2, 3, 4)) + + # Test partial slicing + partial_slice_1 = layout_map[1] # Should be equivalent to layout_map[1, :, :] + full_slice_1 = layout_map[1, slice(None), slice(None)] + + assert partial_slice_1.shape == full_slice_1.shape + assert partial_slice_1.strides == full_slice_1.strides + assert partial_slice_1.offset == full_slice_1.offset + + # Test with 2 dimensions only + partial_slice_2 = layout_map[1, 2] # Should be equivalent to layout_map[1, 2, :] + full_slice_2 = layout_map[1, 2, slice(None)] + + assert partial_slice_2.shape == full_slice_2.shape + assert partial_slice_2.strides == full_slice_2.strides + assert partial_slice_2.offset == full_slice_2.offset + + # Test with a single slice + partial_slice_3 = layout_map[slice(0, 2, 1)] # Should be equivalent to layout_map[0:2, :, :] + full_slice_3 = layout_map[slice(0, 2, 1), slice(None), slice(None)] + + assert partial_slice_3.shape == full_slice_3.shape + assert partial_slice_3.strides == full_slice_3.strides + assert partial_slice_3.offset == full_slice_3.offset + + +# Add a test that specifically checks the slicing behavior +def test_layout_map_getitem(layout_map_fixture): + """Test that slicing behavior is as expected for cases with slices.""" + params, _, _, _, layout_map, _, flat_indices_nd = layout_map_fixture + + # Skip test cases without slices + if params.slices is None: + return + + # For cases with slices, verify that indices after slicing still match + for idx in product(*[range(s) for s in layout_map.shape]): + flat_idx_expected = flat_indices_nd[idx].item() + flat_idx_result = layout_map(idx) + assert flat_idx_expected == flat_idx_result + + +# Test for invalid slicing cases +def test_layout_map_invalid_slicing(): + """Test the __getitem__ method with invalid slices.""" + layout_map = LayoutMap((1, 2, 6), (2, 3, 4)) + + # Test with negative step + with pytest.raises(ValueError, match="Negative or zero steps"): + layout_map[slice(1, 0, -1)] + + # Test with start >= stop + with pytest.raises(ValueError, match="start not smaller than stop"): + layout_map[slice(1, 1, 1)] + + with pytest.raises(ValueError, match="start not smaller than stop"): + layout_map[slice(2, 1, 1)] + + # Test with invalid slice type + with pytest.raises(TypeError, match="Unsupported slice type"): + layout_map[("invalid_type",)] + + +def test_update_exhaustive_strides(layout_map_fixture): + params, strides, shape, offset, layout_map, required_span_size, flat_indices_nd = layout_map_fixture + + # Create common label for test case debugging + label_test_case = f"shape_original={shape}, strides_original={strides}" + + # Skip test cases that are not exhaustive or have slices applied + if not layout_map.is_exhaustive: + # The function should raise ValueError + with pytest.raises(ValueError, match="Input layout .* is not exhaustive"): + update_exhaustive_strides(shape, strides, shape) + return + + # Generate a new shape with the same number of dimensions + # Use different sizes but keep the same dimensionality + shape_new = tuple(max(1, s + 1) for s in shape) # Increment each dimension by 1 + + # Call the function under test + strides_new = update_exhaustive_strides(shape, strides, shape_new) + + # Verify the output + assert len(strides_new) == len(shape_new), ( + f"Output strides should have same length as new shape. " + f"{label_test_case}, shape_new={shape_new}, strides_new={strides_new}" + ) + assert all( + isinstance(s, (int, np.integer)) and s > 0 for s in strides_new + ), f"All strides should be positive integers. {label_test_case}, shape_new={shape_new}, strides_new={strides_new}" + + # Create a new LayoutMap with the returned strides and new shape + layout_new = LayoutMap(strides_new, shape_new) + + # Verify that the new layout is exhaustive + assert ( + layout_new.is_exhaustive + ), f"The new layout should be exhaustive. {label_test_case}, shape_new={shape_new}, strides_new={strides_new}" + + # Test with a different new shape to ensure generality + shape_new_2 = tuple(max(1, s * 2) for s in shape) # Double each dimension + strides_new_2 = update_exhaustive_strides(shape, strides, shape_new_2) + layout_new_2 = LayoutMap(strides_new_2, shape_new_2) + + assert layout_new_2.is_exhaustive, ( + f"The second new layout should be exhaustive. " + f"{label_test_case}, shape_new_2={shape_new_2}, strides_new_2={strides_new_2}" + ) + assert layout_new_2.is_unique, ( + f"The second new layout should be unique. " + f"{label_test_case}, shape_new_2={shape_new_2}, strides_new_2={strides_new_2}" + ) + + # self reflective test + strides_new_3 = update_exhaustive_strides(shape, strides, shape) + assert ( + strides_new_3 == strides + ), f"The self reflective test should return the original strides. {label_test_case}, strides_new_3={strides_new_3}" diff --git a/tests/distributed/test_lightning_strategy.py b/tests/distributed/test_lightning_strategy.py new file mode 100644 index 000000000..7f225432d --- /dev/null +++ b/tests/distributed/test_lightning_strategy.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Error-path unit tests for BoltzContextParallelStrategy. + +These test input-validation guards (missing model, wrong optimizer count, +etc.) that integration tests would not naturally trigger. + +The strategy's DTensor-aware happy-path behaviour — checkpoint conversion, +optimizer-state redistribution, save/load roundtrip — is exercised +end-to-end by: + +- ``test_dtensor_train.py`` (one-step optimizer parity with real DTensor model) +- ``test_dtensor_stop_and_go.py`` (full save/resume cycle through ``train()``) +""" + +from dataclasses import dataclass + +import pytest +import pytorch_lightning as pl +import torch + +import boltz.distributed.lightning_strategy as strategy_module + + +@dataclass +class _DummyDistManager: + device: torch.device = torch.device("cpu") + rank: int = 0 + local_rank: int = 0 + world_size: int = 1 + + +class _TinyLightningModule(pl.LightningModule): + def __init__(self) -> None: + super().__init__() + self.layer = torch.nn.Linear(3, 2) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.25) + + +def _new_strategy() -> strategy_module.BoltzContextParallelStrategy: + return strategy_module.BoltzContextParallelStrategy(dist_manager=_DummyDistManager()) + + +def _attach_lightning_module( + strategy: strategy_module.BoltzContextParallelStrategy, module: _TinyLightningModule +) -> None: + strategy.model = module + strategy._lightning_module = module + + +# --------------------------------------------------------------------------- +# lightning_module_state_dict +# --------------------------------------------------------------------------- + + +def test_module_state_dict_raises_without_model(): + """Goals: RuntimeError when model not attached to strategy.""" + strategy = _new_strategy() + with pytest.raises(RuntimeError, match="model is not set"): + strategy.lightning_module_state_dict() + + +# --------------------------------------------------------------------------- +# load_model_state_dict +# --------------------------------------------------------------------------- + + +def test_load_model_state_dict_raises_without_lightning_module(): + """Goals: RuntimeError when lightning_module not set.""" + strategy = _new_strategy() + strategy.model = _TinyLightningModule() + checkpoint = {"state_dict": {"layer.weight": torch.randn(2, 3), "layer.bias": torch.randn(2)}} + with pytest.raises(RuntimeError, match="lightning_module is not set"): + strategy.load_model_state_dict(checkpoint, strict=False) + + +# --------------------------------------------------------------------------- +# load_checkpoint +# --------------------------------------------------------------------------- + + +def test_load_checkpoint_uses_checkpoint_io_with_cpu_map_location(monkeypatch): + """Goals: checkpoint load routes through checkpoint_io with CPU remap.""" + strategy = _new_strategy() + expected = {"state_dict": {"layer.weight": torch.randn(2, 3)}} + calls: dict[str, object] = {} + + class _FakeCheckpointIO: + def load_checkpoint(self, checkpoint_path, map_location=None): + calls["checkpoint_path"] = checkpoint_path + calls["map_location"] = map_location + return expected + + monkeypatch.setattr(strategy_module.torch, "load", lambda *_a, **_k: (_ for _ in ()).throw(AssertionError())) + strategy.checkpoint_io = _FakeCheckpointIO() + + checkpoint_path = "dummy.ckpt" + result = strategy.load_checkpoint(checkpoint_path) + + assert result is expected + assert calls["checkpoint_path"] == checkpoint_path + assert calls["map_location"] == "cpu" + + +# --------------------------------------------------------------------------- +# load_optimizer_state_dict +# --------------------------------------------------------------------------- + + +def test_load_optimizer_state_dict_raises_without_optimizer_states(): + """Goals: ValueError when checkpoint has no optimizer_states key.""" + strategy = _new_strategy() + module = _TinyLightningModule() + _attach_lightning_module(strategy, module) + strategy.optimizers = [torch.optim.SGD(module.parameters(), lr=0.25)] + with pytest.raises(ValueError, match="no optimizer_states found"): + strategy.load_optimizer_state_dict({}) + + +def test_load_optimizer_state_dict_raises_on_length_mismatch(): + """Goals: ValueError when checkpoint has wrong number of optimizer states.""" + strategy = _new_strategy() + module = _TinyLightningModule() + _attach_lightning_module(strategy, module) + strategy.optimizers = [ + torch.optim.SGD(module.parameters(), lr=0.1), + torch.optim.Adam(module.parameters(), lr=0.1), + ] + with pytest.raises(ValueError, match="length mismatch"): + strategy.load_optimizer_state_dict({"optimizer_states": [strategy.optimizers[0].state_dict()]}) + + +def test_load_optimizer_state_dict_raises_on_non_sequence(): + """Goals: TypeError when optimizer_states is not a list/tuple.""" + strategy = _new_strategy() + module = _TinyLightningModule() + _attach_lightning_module(strategy, module) + strategy.optimizers = [torch.optim.SGD(module.parameters(), lr=0.1)] + with pytest.raises(TypeError, match="must be a list/tuple"): + strategy.load_optimizer_state_dict({"optimizer_states": {"state": {}}}) + + +def test_load_optimizer_state_dict_no_optimizers(): + """Goals: silently return when no optimizers are attached.""" + strategy = _new_strategy() + strategy.optimizers = [] + # Should not raise. + strategy.load_optimizer_state_dict({"optimizer_states": [{"state": {}}]}) + + +# --------------------------------------------------------------------------- +# barrier +# --------------------------------------------------------------------------- + + +def test_barrier_no_distributed(monkeypatch): + """Goals: barrier() is a no-op when torch.distributed is not initialized.""" + strategy = _new_strategy() + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) + # Should not raise. + strategy.barrier() diff --git a/tests/distributed/test_manager.py b/tests/distributed/test_manager.py new file mode 100644 index 000000000..9bfeba2a2 --- /dev/null +++ b/tests/distributed/test_manager.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import itertools +import os +from typing import Dict, Optional + +import pytest +import torch + +from boltz.distributed.manager import DistributedManager + + +def test_manager_singleton(monkeypatch): + # Test distributed manager singleton functions as expected + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "45678") + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("WORLD_SIZE", "1") + DistributedManager.initialize({"dp": 1, "cp": 1}, "cpu", "gloo") + + manager_1 = DistributedManager() + manager_1.random_property = "random_string" + manager_2 = DistributedManager() + + # Compare attributes + for attr in manager_1.__dict__.keys(): + assert getattr(manager_1, attr) == getattr(manager_2, attr) + assert manager_1.random_property == manager_2.random_property + DistributedManager.cleanup() + + +def create_manager_and_assert( + rank_expected: int, + world_size_expected: int, + grid_group_sizes_expected: Dict[str, int], + device_type_expected: str, + backend_expected: str, + method_init_expected: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank_expected}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes_expected, device_type=device_type_expected, backend=backend_expected) + manager = DistributedManager() + assert ( + manager.has_dist == torch.distributed.is_available() + ), "DistributedManager.has_dist inconsistent with torch.distributed's availability" + assert manager.rank == rank_expected + assert manager.world_size == world_size_expected + assert manager.group["world"] == torch.distributed.group.WORLD + assert manager.group_rank["world"] == rank_expected + assert manager.group_ranks["world"] == list(range(world_size_expected)) + # by default, the underlying DeviceMesh should use layout-right for the grid groups + layoutMap = manager.layout_device_mesh + grid_coords_expected = layoutMap.unravel(rank_expected) + has_subgroups = False + for i_group, (name_group, size_group) in enumerate(grid_group_sizes_expected.items()): + if isinstance(size_group, tuple) and all(isinstance(size_group_i, int) for size_group_i in size_group): + has_subgroups = True + layoutMap_subgroups = manager.layout_device_mesh_subgroups + grid_coords_expected_subgroups = layoutMap_subgroups.unravel(rank_expected) + # check if the layout of the subgroups is correct + slices_subgroup = list(grid_coords_expected_subgroups) + for i_subgroup, size_subgroup in enumerate(size_group): + name_subgroup = f"{name_group}_axis_{i_subgroup}" + # check if the subgroups' ranks are set consistently with the layout + assert len(manager.group_ranks[name_subgroup]) == size_subgroup + assert manager.group_rank[name_subgroup] == grid_coords_expected_subgroups[i_group + i_subgroup] + # check if the parent group is mapped correctly to the subgroups + assert manager.subgroups[name_group][i_subgroup] is manager.group[name_subgroup] + assert manager.subgroups_ranks[name_group][i_subgroup] == manager.group_ranks[name_subgroup] + assert manager.subgroups_rank[name_group][i_subgroup] == manager.group_rank[name_subgroup] + slices_subgroup[i_group + i_subgroup] = slice(None) + # check if the layout of the subgroups is correct + layoutMap_subgroup = layoutMap_subgroups[*slices_subgroup] + assert manager.layout_subgroups[name_group].shape == layoutMap_subgroup.shape + assert manager.layout_subgroups[name_group].strides == layoutMap_subgroup.strides + assert manager.layout_subgroups[name_group].offset == 0 + elif isinstance(size_group, int): + assert len(manager.group_ranks[name_group]) == size_group + assert manager.group_rank[name_group] == grid_coords_expected[i_group] + else: + raise ValueError(f"Invalid group size type: {type(size_group)}") + assert manager.has_subgroups == has_subgroups + + assert manager.backend == backend_expected + assert manager.device.type == device_type_expected + assert manager.method_init == method_init_expected + DistributedManager.cleanup() + + monkeypatch.undo() + + +def create_default_manager_can_raise( + rank_expected: int, + grid_group_sizes_expected: Dict[str, int], + device_type_expected: str, + backend_expected: str, + env_map: Optional[Dict[str, str]] = None, +): + monkeypatch = pytest.MonkeyPatch() + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank_expected}") + continue + monkeypatch.setenv(var_name, value) + if "BOLTZ_DISTRIBUTED_INIT_METHOD" in os.environ: + # setting BOLTZ_DISTRIBUTED_INIT_METHOD without the other + # relevant env vars for world_size and rank etc will trigger + # a RuntimeError + with pytest.raises(RuntimeError): + DistributedManager.initialize( + grid_group_sizes_expected, device_type=device_type_expected, backend=backend_expected + ) + else: + # default initialization should happen + DistributedManager.initialize( + grid_group_sizes_expected, device_type=device_type_expected, backend=backend_expected + ) + manager = DistributedManager() + assert manager.initialized + assert not manager.has_dist + assert DistributedManager().rank == 0 + assert DistributedManager().world_size == 1 + assert DistributedManager().local_rank == 0 + assert DistributedManager().device == torch.device("cpu") + assert DistributedManager().backend is None + assert DistributedManager().method_init is None + assert DistributedManager().group == {} + assert DistributedManager().group_rank == {} + assert DistributedManager().group_ranks == {} + assert DistributedManager().device_mesh is None + assert DistributedManager().device_mesh_subgroups is None + assert DistributedManager().layout_device_mesh is None + assert DistributedManager().layout_device_mesh_subgroups is None + assert DistributedManager().has_subgroups is False + assert DistributedManager().subgroups == {} + assert DistributedManager().subgroups_ranks == {} + assert DistributedManager().subgroups_rank == {} + assert DistributedManager().layout_subgroups == {} + + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(1, 2)], [False], ["cpu", "cuda"], [None]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]} method_init={x[3]}", +) +def test_manager_default(setup_env): + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + # device_type and backend don't matter since there is no valid set of distributed environment variables + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=create_default_manager_can_raise, + args=(grid_group_sizes, device_type, backend, env_per_rank), + nprocs=world_size, + join=True, + ) + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(1, 1), (1, 2), (2, (2, 2)), (1, (4, 4))], [True, False], ["cpu", "cuda"], ["ENV", "SLURM"]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]} method_init={x[3]}", +) +def test_manager(setup_env): + grid_group_sizes, world_size, device_type, backend, method_init, env_per_rank = setup_env + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip("skip cuda test because torch.cuda.device_count() != world_size") + + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=create_manager_and_assert, + args=(world_size, grid_group_sizes, device_type, backend, method_init, env_per_rank), + nprocs=world_size, + join=True, + ) + + +def create_manager_and_group( + rank: int, + grid_group_sizes: Dict[str, int], + device_type: str, + backend: str, + env_map: Optional[Dict[str, str]] = None, +): + if env_map is not None: + for var_name, value in env_map.items(): + if value == "": + os.environ[var_name] = f"{rank}" + continue + os.environ[var_name] = value + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + assert "cp" in grid_group_sizes, "Must have group 'cp' in the input grid_group_sizes" + name_group = "new_group_test" + DistributedManager.create_group(name_group, manager.group_ranks["cp"], use_local_synchronization=True) + assert name_group in manager.group + assert manager.group_ranks[name_group] == manager.group_ranks["cp"] + assert manager.group_rank[name_group] == manager.group_rank["cp"] + DistributedManager.cleanup() + + +@pytest.mark.parametrize( + "setup_env", + itertools.product([(2, 4)], [False], ["cpu"], ["ENV"]), + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, specify_method={x[1]}, device_type={x[2]} method_init={x[3]}", +) +def test_manager_create_group(setup_env): + grid_group_sizes, world_size, device_type, backend, method_init, env_per_rank = setup_env + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip("skip cuda test because torch.cuda.device_count() != world_size") + + torch.multiprocessing.set_start_method("spawn", force=True) + torch.multiprocessing.spawn( + fn=create_manager_and_group, + args=(grid_group_sizes, device_type, backend, env_per_rank), + nprocs=world_size, + join=True, + ) diff --git a/tests/distributed/test_tiled_softmax_attn_update.py b/tests/distributed/test_tiled_softmax_attn_update.py new file mode 100644 index 000000000..e92394c49 --- /dev/null +++ b/tests/distributed/test_tiled_softmax_attn_update.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import pytest +import torch + +from boltz.distributed.model.modules.utils import DTYPE_TO_PRECISION +from boltz.distributed.utils import tiled_softmax_attention_update +from boltz.testing.utils import PRECISION_TO_INF + + +def expected_accum(a, v, dim_softmax, return_amax: bool = True): + """Compute expected accumulation for validation purposes.""" + dtype_input = a.dtype + a = a.double() + v = v.double() + o = torch.einsum("...i, ...i -> ...", torch.softmax(a, dim=dim_softmax), v) + amax = a.max(dim=dim_softmax, keepdim=True)[0] + lse_m = torch.logsumexp(a - amax, dim=dim_softmax, keepdim=True) + if return_amax: + return lse_m.to(dtype_input), o.reshape_as(lse_m).to(dtype_input), amax.to(dtype_input) + else: + # when return_amax is False, we return lse = lse_m + amax + # but the lse_m and amax were computed as return_amax is True + lse = lse_m + amax + return lse.to(dtype_input), o.reshape_as(lse).to(dtype_input), None + + +@pytest.fixture +def softmax_test_tensors(): + """Fixture providing test tensors for online softmax accumulation tests.""" + torch.manual_seed(42) + dtype = torch.float32 + + size_chunk = 3 + n_chunks = 21 + + n_elems = size_chunk * n_chunks + size_batch = 3 + + device = torch.device("cuda:0") + a = torch.randn((size_batch, n_elems), dtype=dtype, device=device) + v = torch.randn_like(a, device=device) + + # generate chunks with specific patterns including inf and -inf + inf = 1e9 + a = a.reshape(size_batch, n_chunks, size_chunk) + + # Set up specific patterns for batch 0: + # chunk 0: -inf + # chunk 1: -inf + # chunk 2: inf + # chunk 3: inf + # chunk 4: -inf + # chunk 5: -inf + # chunk 6: inf + # chunk 7: inf first then -inf + # chunk 8: -inf first then inf + # chunk 9: inf first then randn + # chunk 10: randn first then inf + # chunk 11: -inf first then randn + # chunk 12: randn first then -inf + # chunk 13: randn + # chunk 14: randn + # chunk 15: randn + # chunk 16: inf + # chunk 17: inf + # chunk 18: -inf + # chunk 19: -inf + # chunk 20: -inf + a[0, 0:2] = -inf + a[0, 2:4] = inf + a[0, 4:6] = -inf + a[0, 6] = inf + a[0, 7, : size_chunk // 2] = inf + a[0, 7, size_chunk // 2 :] = -inf + a[0, 8, : size_chunk // 2] = -inf + a[0, 8, size_chunk // 2 :] = inf + a[0, 9, : size_chunk // 2] = inf + a[0, 10, size_chunk // 2 :] = inf + a[0, 11, : size_chunk // 2] = -inf + a[0, 12, size_chunk // 2 :] = -inf + a[0, 16:18] = inf + a[0, 18:21] = -inf + + # inverse the pattern of batch 0 for batch 1 + a[1] = -a[0] + + # a typical masked softmax pattern for batch 2 + a[2, 10:] = -inf + + a = a.flatten(start_dim=1) + + return { + "a": a, + "v": v, + "size_chunk": size_chunk, + "n_chunks": n_chunks, + "dim_softmax": 1, + "device": device, + "dtype": dtype, + } + + +@pytest.mark.parametrize("has_amax", [True, False], ids=lambda x: f"has_amax:{x}") +def test_tiled_softmax_attention_update_correctness(softmax_test_tensors, has_amax): + """Test that tiled softmax attention update produces correct results.""" + a = softmax_test_tensors["a"] + v = softmax_test_tensors["v"] + size_chunk = softmax_test_tensors["size_chunk"] + n_chunks = softmax_test_tensors["n_chunks"] + dim_softmax = softmax_test_tensors["dim_softmax"] + device = softmax_test_tensors["device"] + + if not has_amax: + # Without amax to take away contribution from extreme values, + # lse_m is actually lse = lse_m + amax, where the sum often results in + # catastrophic cancellation. In this case, the tiled_softmax_attention_update + # can only work if the "-inf" padding in the attention score only shows up after + # those normal values but not preceding them, i.e., we have a lse pattern of: + # [... , -inf, -inf, ..., -inf]. See tiled_softmax_attention_update's + # comments for more details. + torch.manual_seed(42) + inf = PRECISION_TO_INF[DTYPE_TO_PRECISION[a.dtype]] + a = torch.randn_like(a, device=device) + a[:, a.shape[1] // 2 :] = -inf + + lse_m, o = None, None + amax = None + + for i_chunk in range(n_chunks): + i_begin = i_chunk * size_chunk + i_end = (i_chunk + 1) * size_chunk + ids_chunk = torch.arange(i_begin, i_end, device=device) + ids_cum_chunk = torch.arange(0, i_end, device=device) + + a_chunk = a.index_select(dim_softmax, ids_chunk) + v_chunk = v.index_select(dim_softmax, ids_chunk) + + # Compute expected cumulative results + lse_m_cum_expected, o_cum_expected, amax_cum_expected = expected_accum( + a.index_select(dim_softmax, ids_cum_chunk), + v.index_select(dim_softmax, ids_cum_chunk), + dim_softmax, + has_amax, + ) + + # Perform online softmax accumulation + # Step 1: compute per-block amax, lse_m, and o + amax_chunk = a_chunk.amax(dim=dim_softmax, keepdim=True) + a_chunk_delta = a_chunk - amax_chunk + + # Subtract out the current chunk's amax and keep it away from the + # following computation until absolutely impossible to do so any further + lse_m_chunk = torch.logsumexp(a_chunk_delta, dim=dim_softmax, keepdim=True) + s_chunk = torch.exp(a_chunk_delta - lse_m_chunk) + o_chunk = torch.einsum("...i, ...i -> ...", s_chunk, v_chunk).reshape_as(amax_chunk) + + if not has_amax: + # when has_amax is False, we use lse instead of lse_m for accumulation across chunks + lse_m_chunk = lse_m_chunk + amax_chunk + amax_chunk = None + + # Update accumulated values + o, lse_m, amax = tiled_softmax_attention_update(o_chunk, lse_m_chunk, amax_chunk, o, lse_m, amax) + + # Verify correctness against expected results + torch.testing.assert_close(lse_m, lse_m_cum_expected) + torch.testing.assert_close(o, o_cum_expected) + if has_amax: + torch.testing.assert_close(amax, amax_cum_expected) + else: + assert amax is None + assert amax_cum_expected is None + + +def test_tiled_softmax_attention_update_error_cases(): + """Test error handling for invalid inputs.""" + device = torch.device("cuda:0") + dtype = torch.float32 + + # Create some test tensors + o_chunk = torch.randn(3, 5, device=device, dtype=dtype) + lse_m_chunk = torch.randn(3, 1, device=device, dtype=dtype) + amax_chunk = torch.randn(3, 1, device=device, dtype=dtype) + + # Test case 1: Inconsistent None/not-None parameters for o and lse_m + with pytest.raises(ValueError, match="o and lse_m must both be None or both be not None"): + tiled_softmax_attention_update(o_chunk, lse_m_chunk, amax_chunk, o_chunk, None, None) + + # Test case 2: Shape mismatch between lse_m_chunk and amax_chunk + wrong_shape_amax = torch.randn(3, 2, device=device, dtype=dtype) + with pytest.raises(ValueError, match="lse_m_chunk and amax_chunk must have the same shape"): + tiled_softmax_attention_update(o_chunk, lse_m_chunk, wrong_shape_amax, None, None, None) + + # Test case 3: lse_m_chunk doesn't have last dimension of size 1 + wrong_lse_m = torch.randn(3, 2, device=device, dtype=dtype) + wrong_amax = torch.randn(3, 2, device=device, dtype=dtype) + with pytest.raises(ValueError, match="lse_m_chunk must have shape \\(\\.\\.\\.\\, 1\\)"): + tiled_softmax_attention_update(o_chunk, wrong_lse_m, wrong_amax, None, None, None) + + # Test case 4: Different number of dimensions between o_chunk and lse_m_chunk + wrong_dim_lse_m = torch.randn(3, 1, 1, device=device, dtype=dtype) + wrong_dim_amax = torch.randn(3, 1, 1, device=device, dtype=dtype) + with pytest.raises(ValueError, match="o_chunk and lse_m_chunk must have the same number of dimensions"): + tiled_softmax_attention_update(o_chunk, wrong_dim_lse_m, wrong_dim_amax, None, None, None) + + # Test case 5: Shape mismatch between o_chunk and lse_m_chunk (except last dimension) + wrong_batch_o = torch.randn(4, 5, device=device, dtype=dtype) # Different batch size + with pytest.raises( + ValueError, match="o_chunk and lse_m_chunk must have the same shape except for the last dimension" + ): + tiled_softmax_attention_update(wrong_batch_o, lse_m_chunk, amax_chunk, None, None, None) + + # Test case 6: Shape mismatch between o_chunk and o (non-initial chunk) + o_accum = torch.randn(3, 5, device=device, dtype=dtype) + lse_m_accum = torch.randn(3, 1, device=device, dtype=dtype) + amax_accum = torch.randn(3, 1, device=device, dtype=dtype) + wrong_shape_o_chunk = torch.randn(3, 6, device=device, dtype=dtype) # Different feature dimension + wrong_shape_lse_m_chunk = torch.randn(3, 1, device=device, dtype=dtype) + with pytest.raises(ValueError, match="o_chunk and o must have the same shape"): + tiled_softmax_attention_update( + wrong_shape_o_chunk, wrong_shape_lse_m_chunk, amax_chunk, o_accum, lse_m_accum, amax_accum + ) + + # Test case 7: Shape mismatch between lse_m_chunk and lse_m (non-initial chunk) + wrong_shape_lse_m_chunk = torch.randn(4, 1, device=device, dtype=dtype) # Different batch size + with pytest.raises(ValueError, match="lse_m_chunk and amax_chunk must have the same shape"): + tiled_softmax_attention_update(o_chunk, wrong_shape_lse_m_chunk, amax_chunk, o_accum, lse_m_accum, amax_accum) + + # Test case 8: Inconsistent amax and amax_chunk (non-initial chunk) + with pytest.raises( + ValueError, match="amax and amax_chunk must both be None or both be not None for non-initial chunks" + ): + tiled_softmax_attention_update(o_chunk, lse_m_chunk, amax_chunk, o_accum, lse_m_accum, None) + + # Test case 9: Shape mismatch between amax_chunk and amax (non-initial chunk) + wrong_shape_amax_accum = torch.randn(3, 2, device=device, dtype=dtype) + with pytest.raises(ValueError, match="amax_chunk and amax must have the same shape"): + tiled_softmax_attention_update(o_chunk, lse_m_chunk, amax_chunk, o_accum, lse_m_accum, wrong_shape_amax_accum) + + # Test case 10: Shape mismatch between lse_m_chunk and lse_m (non-initial chunk) + wrong_shape_lse_m_accum = torch.randn(3, 2, device=device, dtype=dtype) + with pytest.raises(ValueError, match="lse_m_chunk and lse_m must have the same shape"): + tiled_softmax_attention_update(o_chunk, lse_m_chunk, amax_chunk, o_accum, wrong_shape_lse_m_accum, amax_accum) + + +@pytest.mark.parametrize( + "dtype", + [torch.bfloat16, torch.float16], + ids=lambda d: str(d).split(".")[-1], +) +@pytest.mark.parametrize("has_amax", [True, False], ids=lambda x: f"has_amax:{x}") +def test_tiled_softmax_attention_update_dtype_preservation(dtype, has_amax): + """Regression: under autocast, torch.logsumexp promotes BF16/FP16 → FP32. + + torch.logsumexp preserves dtype without autocast but promotes to FP32 + when autocast is active (logsumexp is on autocast's FP32-promotion list). + In production, _RingMultiHeadTriangleAttentionImpl.forward uses + @custom_fwd without cast_inputs, which preserves the caller's autocast + context — so the promotion occurs during BF16-mixed training. + + Without the fix, the has_amax=True path promoted lse_m to FP32 on step 1 + (via logsumexp), which then infected delta_lse → sigmoid → o on step 2+. + The has_amax=False path (using logsigmoid) was not affected. + + This test wraps the calls in autocast and feeds 5 chunks in the input + dtype, asserting that o, lse_m, and amax preserve that dtype after every + step — including step 2+ where the cascade used to occur. + """ + device = torch.device("cuda:0") + n_chunks = 5 + batch, feat = 4, 8 + torch.manual_seed(0) + + o, lse_m, amax = None, None, None + with torch.amp.autocast("cuda", dtype=dtype): + for _ in range(n_chunks): + o_chunk = torch.randn(batch, feat, device=device, dtype=dtype) + lse_m_chunk = torch.randn(batch, 1, device=device, dtype=dtype) + amax_chunk = torch.randn(batch, 1, device=device, dtype=dtype) if has_amax else None + + o, lse_m, amax = tiled_softmax_attention_update(o_chunk, lse_m_chunk, amax_chunk, o, lse_m, amax) + + assert o.dtype == dtype, f"o promoted to {o.dtype}" + assert lse_m.dtype == dtype, f"lse_m promoted to {lse_m.dtype}" + if has_amax: + assert amax.dtype == dtype, f"amax promoted to {amax.dtype}" + else: + assert amax is None + + assert o.isfinite().all(), "o contains non-finite values" + assert lse_m.isfinite().all(), "lse_m contains non-finite values" diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py new file mode 100644 index 000000000..d2a28151d --- /dev/null +++ b/tests/distributed/test_utils.py @@ -0,0 +1,735 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import unittest +from random import randint +from unittest.mock import patch + +import pytest +import torch +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor + +from boltz.distributed.manager import DistributedManager +from boltz.distributed.utils import create_distributed_randn +from boltz.testing.utils import ( + assert_all_identical, + distribute_atom_features, + homogenize_shard_shapes, + seed_by_rank, + spawn_multiprocessing, +) + + +def assert_homogenize_shard_shapes_worker(rank: int, grid_group_sizes, world_size, device_type, backend, env_per_rank): + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + device_mesh = manager.device_mesh_subgroups + + # Define test cases inside the worker function + test_cases = [ + # (global_shape, placements, value_to_pad) + ( + (2 * device_mesh.size(0), 5 * device_mesh.size(1), 4 * device_mesh.size(2)), + (Shard(0), Shard(1), Shard(2)), + 0.0, + ), + ( + (3 * device_mesh.size(0), 4 * device_mesh.size(1)), + (Shard(0), Shard(1), Replicate()), + None, + ), + ( + (3 * device_mesh.size(0), 4 * device_mesh.size(1)), + (Shard(0), Replicate(), Shard(1)), + None, + ), + ( + (5 * device_mesh.size(0), 7 * device_mesh.size(1)), + (Replicate(), Shard(0), Shard(1)), + 0.0, + ), + ( + (5 * device_mesh.size(0), 7 * device_mesh.size(1)), + (Replicate(), Replicate(), Replicate()), + 0.0, + ), + ] + + seed_by_rank(manager.group_rank["world"], seed=42) + + for global_shape, placements, value_to_pad in test_cases: + label_test = f"global_shape={global_shape}, placements={placements}, value_to_pad={value_to_pad}" + # Create a global tensor with known values for verification + global_tensor = torch.randn(global_shape, dtype=torch.float32, device=manager.device) + + global_dtensor = distribute_tensor( + global_tensor, device_mesh=device_mesh, placements=placements, src_data_rank=0 + ) + + # slice local shard to make an expected tensor + shard = global_dtensor.to_local() + + slice_local = [] + for i in range(shard.ndim): + # randomly generate a slice towards the beginning along each tensor axis + # so that the resulting equivalent padding is towards the end + slice_local.append(slice(None, max(1, randint(1, shard.shape[i] - 1)), None)) + shard_sliced = shard[slice_local] + + if value_to_pad is None: + shard_expected = torch.zeros_like(shard) + else: + shard_expected = torch.full_like(shard, torch.tensor(value_to_pad, dtype=shard.dtype)) + # consistent with the target function homogenize_shard_shapes's padding pattern: + # always pad towards the end (or the last element) along each tensor axis + shard_expected[slice_local] = shard_sliced + + global_heterogeneous_dtensor = DTensor.from_local( + shard_sliced, + device_mesh=global_dtensor.device_mesh, + placements=global_dtensor.placements, + shape=global_dtensor.shape, + stride=global_dtensor.stride(), + ) + + expected = DTensor.from_local( + shard_expected, + device_mesh=global_dtensor.device_mesh, + placements=global_dtensor.placements, + shape=global_dtensor.shape, + stride=global_dtensor.stride(), + ) + + results = homogenize_shard_shapes(global_heterogeneous_dtensor, value_to_pad=value_to_pad) + + torch.testing.assert_close( + results.to_local(), + expected.to_local(), + atol=0, + rtol=0, + msg=lambda msg: f"{label_test}\n{msg}", + ) + + expected_full = expected.full_tensor() + results_full = results.full_tensor() + torch.testing.assert_close( + expected_full, + results_full, + atol=0, + rtol=0, + msg=lambda msg: f"{label_test}\n{msg}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((2, (2, 2)), True, "cuda", "ENV"), + ((3, (3, 3)), True, "cpu", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}", +) +def test_homogenize_shard_shapes(setup_env): + """Test homogenize_shard_shapes function with various sharding configurations.""" + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("skip cuda test because torch.cuda.is_available == False") + if torch.cuda.device_count() < world_size: + pytest.skip(f"skip cuda test because torch.cuda.device_count() != {world_size}") + + spawn_multiprocessing( + assert_homogenize_shard_shapes_worker, + world_size, + grid_group_sizes, + world_size, + device_type, + backend, + env_per_rank, + ) + + +def assert_create_distributed_randn( + global_rank: int, + shape: tuple[int, ...], + placements: tuple[Shard | Replicate, ...], + grid_group_sizes: dict[str, int], + device_type: str, + backend: str, + env_per_rank: dict[str, str], + dtype: torch.dtype, + scale: float, + seed: int, + expected_source_ranks: set[int], +): + """Assert correctness of create_distributed_randn on each rank.""" + monkeypatch = pytest.MonkeyPatch() + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{global_rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + + # Use device_mesh_subgroups which has 3D structure: (dp, cp_axis_0, cp_axis_1) + device_mesh = manager.device_mesh_subgroups + + # Set seed for reproducibility + torch.manual_seed(seed + global_rank) + + # Use the expected source rank from parametrization + is_source_rank = global_rank in expected_source_ranks + + # Track whether torch.randn was called + randn_was_called = False + original_randn = torch.randn + + def mock_randn(*args, **kwargs): + nonlocal randn_was_called # does not leak to other subprocesses + randn_was_called = True + return original_randn(*args, **kwargs) + + # Mock torch.randn and create distributed random tensor + with patch("torch.randn", side_effect=mock_randn): + dtensor = create_distributed_randn( + shape=shape, + device_mesh=device_mesh, + placements=placements, + dtype=dtype, + scale=scale, + ) + + # Verify that torch.randn was called only on source ranks + if is_source_rank: + assert randn_was_called, f"Source rank {global_rank} should have called torch.randn" + else: + assert not randn_was_called, f"Non-source rank {global_rank} should not have called torch.randn" + + # Get local tensor + local_tensor = dtensor.to_local() + + # Verify DTensor properties + assert dtensor.device_mesh == device_mesh + assert dtensor.placements == placements + assert dtensor.shape == shape + assert dtensor.dtype == dtype + + # Verify local shape is correct based on placements + expected_local_shape = list(shape) + for i_dim_mesh, placement in enumerate(placements): + if placement.is_shard(): + expected_local_shape[placement.dim] = shape[placement.dim] // device_mesh.shape[i_dim_mesh] + assert local_tensor.shape == tuple(expected_local_shape) + + # Verify replication: for each mesh dimension with Replicate placement, + # all ranks along that mesh dimension should have identical values + for mesh_dim, placement in enumerate(placements): + if not placement.is_shard(): + # This mesh dimension is replicated, verify all ranks have identical local tensors + assert_all_identical( + local_tensor, + device_mesh.get_group(mesh_dim), + check_stride=False, + check_grad=False, + check_grad_fn=False, + check_storage_offset=False, + check_storage_pointer=False, + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env,placements,shape,expected_source_ranks", + [ + # ===== 3D tensors with 3D mesh ===== + # DP=1, CP=(2,2) tests - 3D mesh (1, 2, 2) + # Mesh coordinates: rank 0:(0,0,0), 1:(0,0,1), 2:(0,1,0), 3:(0,1,1) + ([(1, (2, 2)), True, "cpu", "ENV"], (Replicate(), Replicate(), Replicate()), (8, 16, 32), {0}), + ([(1, (2, 2)), True, "cpu", "ENV"], (Replicate(), Shard(1), Replicate()), (8, 16, 32), {0, 2}), + ([(1, (2, 2)), True, "cpu", "ENV"], (Replicate(), Replicate(), Shard(2)), (8, 16, 32), {0, 1}), + ([(1, (2, 2)), True, "cpu", "ENV"], (Replicate(), Shard(0), Shard(1)), (8, 16, 32), {0, 1, 2, 3}), + # DP=2, CP=(2,2) tests - 3D mesh (2, 2, 2) + # Mesh coordinates: rank 0:(0,0,0), 1:(0,0,1), 2:(0,1,0), 3:(0,1,1), 4:(1,0,0), 5:(1,0,1), 6:(1,1,0), 7:(1,1,1) + ([(2, (2, 2)), True, "cpu", "ENV"], (Shard(0), Replicate(), Replicate()), (8, 16, 32), {0, 4}), + ([(2, (2, 2)), True, "cpu", "ENV"], (Shard(0), Shard(1), Shard(2)), (8, 16, 32), {0, 1, 2, 3, 4, 5, 6, 7}), + ([(2, (2, 2)), True, "cpu", "ENV"], (Shard(0), Replicate(), Replicate()), (2,), {0, 4}), + ([(2, (2, 2)), True, "cpu", "ENV"], (Shard(0), Replicate(), Replicate()), (16,), {0, 4}), + # CUDA tests + ([(1, (2, 2)), True, "cuda", "ENV"], (Replicate(), Shard(0), Shard(1)), (8, 16, 32), {0, 1, 2, 3}), + ([(2, (2, 2)), True, "cuda", "ENV"], (Shard(0), Shard(1), Shard(2)), (8, 16, 32), {0, 1, 2, 3, 4, 5, 6, 7}), + ([(2, (2, 2)), True, "cuda", "ENV"], (Shard(0), Replicate(), Replicate()), (2,), {0, 4}), + ], + indirect=("setup_env",), + ids=[ + "dp:1_3d_all_replicated_cpu", + "dp:1_3d_shard_dim1_cp0_cpu", + "dp:1_3d_shard_dim2_cp1_cpu", + "dp:1_3d_shard_dim01_cp01_cpu", + "dp:2_3d_shard_dp_cpu", + "dp:2_3d_shard_all_cpu", + "dp:2_1d_batch2_shard_dp_cpu", + "dp:2_1d_batch16_shard_dp_cpu", + "dp:1_3d_shard_dim01_cp01_cuda", + "dp:2_3d_shard_all_cuda", + "dp:2_1d_batch2_shard_dp_cuda", + ], +) +def test_create_distributed_randn( + setup_env: dict[str, int], + placements: tuple[Shard | Replicate, ...], + shape: tuple[int, ...], + expected_source_ranks: set[int], + dtype: torch.dtype = torch.float32, + scale: float = 1.0, + seed: int = 42, +): + """Test create_distributed_randn with various configurations. + + This test covers: + - 3D tensors with various sharding patterns + - 1D tensors (diffusion noise use case) with batch_size=1 edge case + - Per-rank seeds to mimic training scenario + - Verification of replication correctness across mesh dimensions + - Verification that only source ranks call torch.randn + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda" and torch.cuda.device_count() < world_size: + pytest.skip("Not enough GPUs available") + + # Run test across all ranks + spawn_multiprocessing( + assert_create_distributed_randn, + world_size, + shape, + placements, + grid_group_sizes, + device_type, + backend, + env_per_rank, + dtype, + scale, + seed, + expected_source_ranks, + ) + + +def parallel_assert_distribute_atom_features( + rank: int, + grid_group_sizes: dict[str, int], + world_size: int, + device_type: str, + backend: str, + env_per_rank: dict[str, str], + # Test data passed from main test function + feats_global_host: dict[str, torch.Tensor], + placements_cp: dict[str, tuple], + placements_dp_cp: dict[str, tuple], + multiplicities: dict[str, int] | None, +): + """Parallel worker for testing distribute_atom_features. + + Tests that: + 1. DTensors are created correctly + 2. full_tensor() reconstructs data that matches original (accounting for interspersed padding) + 3. Non-padded values are preserved correctly + """ + monkeypatch = pytest.MonkeyPatch() + if env_per_rank is not None: + for var_name, value in env_per_rank.items(): + if value == "": + monkeypatch.setenv(var_name, f"{rank}") + continue + monkeypatch.setenv(var_name, value) + + DistributedManager.initialize(grid_group_sizes, device_type=device_type, backend=backend) + manager = DistributedManager() + device_mesh = manager.device_mesh_subgroups + + # Get test parameters + size_batch = feats_global_host["atom_pad_mask"].shape[0] + n_tokens = feats_global_host["token_pad_mask"].shape[1] + atom_counts_per_token = feats_global_host["atom_counts_per_token"] + + # Convert inputs to device + dtype = torch.float64 + inputs = { + k: v.to(device=manager.device, dtype=dtype if v.dtype.is_floating_point else v.dtype) + for k, v in feats_global_host.items() + if k in placements_cp + } + + # Call distribute_atom_features + feats_dtensor = distribute_atom_features( + inputs=inputs, + placements_cp=placements_cp, + placements_dp_cp=placements_dp_cp, + device_mesh=device_mesh, + cp_group=manager.group["cp"], + multiplicities=multiplicities, + ) + + # Compute shard metadata for extracting non-padded values + n_rows = device_mesh.size(1) # cp_axis_0 + n_tokens_per_shard = n_tokens // n_rows + token_atom_count_cumsum = torch.cat([torch.tensor([0]), atom_counts_per_token[0].cumsum(dim=0)]) + shard_atom_counts = atom_counts_per_token[0].unflatten(0, (n_rows, n_tokens_per_shard)).sum(dim=1) + max_atoms_per_shard = shard_atom_counts.max().item() + n_atoms_padded = max_atoms_per_shard * n_rows + + # Build mapping from padded indices to original indices for atom dimension + # Interspersed padding: shard i has atoms at padded indices [i*max_atoms_per_shard : i*max_atoms_per_shard + actual_atoms_in_shard] + # and these correspond to original indices [token_atom_count_cumsum[i*n_tokens_per_shard] : token_atom_count_cumsum[(i+1)*n_tokens_per_shard]] + padded_to_original = torch.full((n_atoms_padded,), -1, dtype=torch.long) + for i_shard in range(n_rows): + token_start = n_tokens_per_shard * i_shard + token_end = n_tokens_per_shard * (i_shard + 1) + atom_start_orig = token_atom_count_cumsum[token_start].item() + atom_end_orig = token_atom_count_cumsum[token_end].item() + actual_atoms = atom_end_orig - atom_start_orig + + padded_start = i_shard * max_atoms_per_shard + for j in range(actual_atoms): + padded_to_original[padded_start + j] = atom_start_orig + j + + # Extract valid (non-padding) indices + valid_atom_mask = padded_to_original >= 0 + valid_padded_indices = torch.where(valid_atom_mask)[0] + valid_original_indices = padded_to_original[valid_atom_mask] + + # Test each output DTensor + for k, dtensor in feats_dtensor.items(): + if k in {"atom_to_token", "token_to_rep_atom"}: + # Block-diagonal matrices need special handling + # These matrices have placement (Shard(0), Shard(1), Replicate()) but their structure + # means each shard (i, j) contains the diagonal block (i, i). + # The token dimension in the local shard is local (N_tokens_per_shard), not global (n_tokens). + # full_tensor() reconstructs [B, n_atoms_padded, N_tokens_per_shard] (NOT n_tokens). + # + # For proper comparison, we extract local shards and compare with the corresponding + # diagonal blocks from the reference. + + ref = feats_global_host[k].cpu() + local_shard = dtensor.to_local().cpu() + + # Get this rank's position in the mesh + rank_dp = manager.group_rank["dp"] + rank_cp_0 = manager.group_rank["cp_axis_0"] + + # For block-diagonal: the token range is determined by cp_axis_0 (row index) + token_start = n_tokens_per_shard * rank_cp_0 + token_end = n_tokens_per_shard * (rank_cp_0 + 1) + atom_start_orig = token_atom_count_cumsum[token_start].item() + atom_end_orig = token_atom_count_cumsum[token_end].item() + actual_atoms = atom_end_orig - atom_start_orig + + # local_shard shape: [1, max_atoms_per_shard, N_tokens_per_shard] or [1, N_tokens_per_shard, max_atoms_per_shard] + if k == "atom_to_token": + # Shape: [1, max_atoms_per_shard, N_tokens_per_shard] + # Extract non-padded part + local_valid = local_shard[0, :actual_atoms, :] + + # Reference block + ref_block = ref[rank_dp, atom_start_orig:atom_end_orig, token_start:token_end] + + torch.testing.assert_close( + local_valid, + ref_block.to(local_valid.dtype), + atol=0, + rtol=0, + msg=lambda m, k=k: f"feature {k}:\n {m}", + ) + + elif k == "token_to_rep_atom": + # Shape: [1, N_tokens_per_shard, max_atoms_per_shard] + # Extract non-padded part + local_valid = local_shard[0, :, :actual_atoms] + + # Reference block + ref_block = ref[rank_dp, token_start:token_end, atom_start_orig:atom_end_orig] + + torch.testing.assert_close( + local_valid, + ref_block.to(local_valid.dtype), + atol=0, + rtol=0, + msg=lambda m, k=k: f"feature {k}:\n {m}", + ) + + elif k == "pair_mask": + full_tensor = dtensor.full_tensor().cpu() + # pair_mask has shape [B, n_atoms_padded, n_atoms_padded] after full_tensor + # Need to extract valid [atom_i, atom_j] pairs + ref = feats_global_host[k].cpu() + for b in range(size_batch): + for i_shard in range(n_rows): + for j_shard in range(n_rows): + # Row shard atoms + row_token_start = n_tokens_per_shard * i_shard + row_token_end = n_tokens_per_shard * (i_shard + 1) + shard_atom_start = token_atom_count_cumsum[row_token_start].item() + shard_atom_end = token_atom_count_cumsum[row_token_end].item() + shard_actual_atoms = shard_atom_end - shard_atom_start + row_padded_start = i_shard * max_atoms_per_shard + + # Col shard atoms + col_token_start = n_tokens_per_shard * j_shard + col_token_end = n_tokens_per_shard * (j_shard + 1) + col_atom_start = token_atom_count_cumsum[col_token_start].item() + col_atom_end = token_atom_count_cumsum[col_token_end].item() + col_actual_atoms = col_atom_end - col_atom_start + col_padded_start = j_shard * max_atoms_per_shard + + # Extract block + block_full = full_tensor[ + b, + row_padded_start : row_padded_start + shard_actual_atoms, + col_padded_start : col_padded_start + col_actual_atoms, + ] + block_ref = ref[b, shard_atom_start:shard_atom_end, col_atom_start:col_atom_end] + + torch.testing.assert_close( + block_full, + block_ref.to(block_full.dtype), + atol=0, + rtol=0, + msg=lambda m, k=k: f"feature {k}:\n {m}", + ) + + else: + # Single representation features (atom_pad_mask, ref_pos, ref_charge, etc.) + # Shape: [B, n_atoms_padded, ...] - extract valid atoms using valid_padded_indices + full_tensor = dtensor.full_tensor().cpu() + ref = feats_global_host[k].cpu() + + # Handle different tensor dimensions + if full_tensor.ndim == 2: + # Shape: [B, n_atoms_padded] + extracted = full_tensor[:, valid_padded_indices] + # Reorder to original order + reordered = torch.empty_like(ref, dtype=full_tensor.dtype) + for idx, orig_idx in enumerate(valid_original_indices): + reordered[:, orig_idx] = extracted[:, idx] + + torch.testing.assert_close( + reordered, + ref.to(reordered.dtype), + atol=0, + rtol=0, + msg=lambda m, k=k: f"feature {k}:\n {m}", + ) + + elif full_tensor.ndim == 3: + # Shape: [B, n_atoms_padded, D] (e.g., ref_pos with D=3) + extracted = full_tensor[:, valid_padded_indices, :] + reordered = torch.empty_like(ref, dtype=full_tensor.dtype) + for idx, orig_idx in enumerate(valid_original_indices): + reordered[:, orig_idx, :] = extracted[:, idx, :] + + torch.testing.assert_close( + reordered, + ref.to(reordered.dtype), + atol=0, + rtol=0, + msg=lambda m, k=k: f"feature {k}:\n {m}", + ) + + elif full_tensor.ndim == 4: + # Shape: [B, n_atoms_padded, D1, D2] (e.g., ref_atom_name_chars) + extracted = full_tensor[:, valid_padded_indices, :, :] + reordered = torch.empty_like(ref, dtype=full_tensor.dtype) + for idx, orig_idx in enumerate(valid_original_indices): + reordered[:, orig_idx, :, :] = extracted[:, idx, :, :] + + torch.testing.assert_close( + reordered, + ref.to(reordered.dtype), + atol=0, + rtol=0, + msg=lambda m, k=k: f"feature {k}:\n {m}", + ) + + elif full_tensor.ndim == 5: + # Shape: [B, n_atoms_padded, D1, D2, D3] (e.g., ref_element one-hot) + extracted = full_tensor[:, valid_padded_indices, :, :, :] + reordered = torch.empty_like(ref, dtype=full_tensor.dtype) + for idx, orig_idx in enumerate(valid_original_indices): + reordered[:, orig_idx, :, :, :] = extracted[:, idx, :, :, :] + + torch.testing.assert_close( + reordered, + ref.to(reordered.dtype), + atol=0, + rtol=0, + msg=lambda m, k=k: f"feature {k}:\n {m}", + ) + + DistributedManager.cleanup() + monkeypatch.undo() + + +@pytest.mark.parametrize( + "setup_env", + [ + ((1, (2, 2)), True, "cuda", "ENV"), + ((2, (2, 2)), True, "cuda", "ENV"), + ], + indirect=("setup_env",), + ids=lambda x: f"dp={x[0][0]}, cp={x[0][1]}, device_type={x[2]}", +) +def test_distribute_atom_features(setup_env): + """Test distribute_atom_features with various atom feature types. + + Tests: + 1. Single representation features (atom_pad_mask, ref_pos, ref_charge, etc.) + 2. Block-diagonal features (atom_to_token, token_to_rep_atom) + 3. Pair representation features (pair_mask) + + Verifies that after distribute_atom_features + full_tensor(): + - Non-padded values match the original input + - Interspersed padding structure is correct + """ + grid_group_sizes, world_size, device_type, backend, _, env_per_rank = setup_env + + if device_type == "cuda": + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Not enough GPUs: need {world_size}, have {torch.cuda.device_count()}") + + # Use dp_size as batch size (required by distribute_atom_features) + dp_size = grid_group_sizes["dp"] + cp_tuple = grid_group_sizes["cp"] + + # Test parameters - tokens must be divisible by CP row size + n_tokens = 8 * cp_tuple[0] # e.g., 16 for cp=(2,2) + # n_atoms must be large enough to accommodate the random atom counts per token + # With (min=1, max=4) atoms per token, we need at least n_tokens * max_atoms + some buffer for last token + n_msa = 4 + atom_counts_per_token_range = (1, 3) # 1-3 atoms per token + max_atoms_per_token = atom_counts_per_token_range[1] + # Ensure enough atoms: (n_tokens - 1) * max + at least 1 for last token + n_atoms = n_tokens * max_atoms_per_token + max_atoms_per_token # extra buffer + + # Generate random features + from boltz.testing.utils import random_features + + # Seed for reproducibility + torch.manual_seed(42) + + feats_global_host = random_features( + size_batch=dp_size, + n_tokens=n_tokens, + n_atoms=n_atoms, + n_msa=n_msa, + atom_counts_per_token_range=atom_counts_per_token_range, + device=torch.device("cpu"), + float_value_range=(-1.0, 1.0), + selected_keys=[ + "atom_counts_per_token", + "atom_pad_mask", + "atom_to_token", + "token_to_rep_atom", + "pair_mask", + "ref_pos", + "ref_charge", + "ref_element", + "ref_atom_name_chars", + "ref_space_uid", + "token_pad_mask", # Needed for reference + ], + ) + + # Get actual n_atoms after random_features adjustment + n_atoms = feats_global_host["atom_pad_mask"].shape[1] + + # Define placements for CP submesh (2-tuple) + placements_cp_single = (Shard(0), Replicate()) + placements_cp_pair = (Shard(0), Shard(1)) + + placements_cp = { + "atom_counts_per_token": placements_cp_single, + "atom_pad_mask": placements_cp_single, + "atom_to_token": placements_cp_single, # Block-diagonal, stored as single + "token_to_rep_atom": placements_cp_single, # Block-diagonal, stored as single + "pair_mask": placements_cp_pair, + "ref_pos": placements_cp_single, + "ref_charge": placements_cp_single, + "ref_element": placements_cp_single, + "ref_atom_name_chars": placements_cp_single, + "ref_space_uid": placements_cp_single, + } + + # Define placements for full mesh (3-tuple: dp, cp_0, cp_1) + placements_single = (Shard(0), Shard(1), Replicate()) + placements_pair = (Shard(0), Shard(1), Shard(2)) + + placements_dp_cp = { + "atom_pad_mask": placements_single, + "atom_to_token": placements_single, + "token_to_rep_atom": placements_single, + "pair_mask": placements_pair, + "ref_pos": placements_single, + "ref_charge": placements_single, + "ref_element": placements_single, + "ref_atom_name_chars": placements_single, + "ref_space_uid": placements_single, + } + + # No multiplicity for this basic test + multiplicities = None + + spawn_multiprocessing( + parallel_assert_distribute_atom_features, + world_size, + grid_group_sizes, + world_size, + device_type, + backend, + env_per_rank, + feats_global_host, + placements_cp, + placements_dp_cp, + multiplicities, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/model/layers/test_outer_product_mean.py b/tests/model/layers/test_outer_product_mean.py index a9623e453..6010b64d7 100644 --- a/tests/model/layers/test_outer_product_mean.py +++ b/tests/model/layers/test_outer_product_mean.py @@ -1,4 +1,27 @@ -import pytorch_lightning +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# ruff: noqa +# fmt: off + import torch import torch.nn as nn @@ -13,16 +36,22 @@ def setUp(self): self.c_hidden = 16 self.c_out = 64 - torch.set_grad_enabled(False) - pytorch_lightning.seed_everything(1100) - self.layer = OuterProductMean(self.c_in, self.c_hidden, self.c_out) + # Use torch.random.fork_rng + torch.no_grad context managers instead of + # pytorch_lightning.seed_everything + torch.set_grad_enabled(False) to + # avoid polluting global RNG state and leaking disabled gradients into + # subsequent tests in the suite (the latter caused "element 0 of tensors + # does not require grad" failures in test_triattn_kernel and + # test_distogramv2). + with torch.random.fork_rng(), torch.no_grad(): + torch.manual_seed(1100) + self.layer = OuterProductMean(self.c_in, self.c_hidden, self.c_out) - # Initialize layer - for name, param in self.layer.named_parameters(): - nn.init.normal_(param, mean=1.0, std=1.0) + # Initialize layer + for name, param in self.layer.named_parameters(): + nn.init.normal_(param, mean=1.0, std=1.0) - # Set to eval mode - self.layer.eval() + # Set to eval mode + self.layer.eval() def test_chunk(self): chunk_sizes = [16, 33, 64, 83, 100] diff --git a/tests/model/layers/test_triangle_attention.py b/tests/model/layers/test_triangle_attention.py index e9fa51183..21ac311a2 100644 --- a/tests/model/layers/test_triangle_attention.py +++ b/tests/model/layers/test_triangle_attention.py @@ -1,4 +1,27 @@ -import pytorch_lightning +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# ruff: noqa +# fmt: off + import torch import torch.nn as nn @@ -13,13 +36,19 @@ def setUp(self): self.c_hidden = 32 self.no_heads = 1 - torch.set_grad_enabled(False) - pytorch_lightning.seed_everything(1100) - self.layer = TriangleAttention(self.c_in, self.c_hidden, self.no_heads) + # Use torch.random.fork_rng + torch.no_grad context managers instead of + # pytorch_lightning.seed_everything + torch.set_grad_enabled(False) to + # avoid polluting global RNG state and leaking disabled gradients into + # subsequent tests in the suite (the latter caused "element 0 of tensors + # does not require grad" failures in test_triattn_kernel and + # test_distogramv2). + with torch.random.fork_rng(), torch.no_grad(): + torch.manual_seed(1100) + self.layer = TriangleAttention(self.c_in, self.c_hidden, self.no_heads) - # Initialize layer - for name, param in self.layer.named_parameters(): - nn.init.normal_(param, mean=1.0, std=1.0) + # Initialize layer + for name, param in self.layer.named_parameters(): + nn.init.normal_(param, mean=1.0, std=1.0) def test_chunk(self): chunk_sizes = [16, 33, 64, 100] diff --git a/tests/model/layers/test_triattn_kernel.py b/tests/model/layers/test_triattn_kernel.py new file mode 100644 index 000000000..6f5c15944 --- /dev/null +++ b/tests/model/layers/test_triattn_kernel.py @@ -0,0 +1,759 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +import math + +import cuequivariance_torch.primitives.triangle as cueq_triangle +import pytest +import torch +from trifast.torch import _triangle_attention as trifast_triangle_attention +from trifast.torch import triangle_attention_bwd as trifast_triangle_attention_bwd + +from boltz.distributed.model.modules.utils import PRECISION_TO_DTYPE, Precision, TriAttnBackend, setup_tf32_env +from boltz.model.layers.triangular_attention.primitives import _attention, cueq_is_installed, trifast_is_installed +from boltz.testing.utils import ( + PRECISION_TO_INF, + assert_no_percentile_upshift, + init_tensors_normal, + init_tensors_uniform, +) + + +def view_tensor_strided(t: torch.Tensor, argsort_strides: torch.Tensor): + """ + Create a strided view of a tensor with custom stride ordering. + + This function reorders the strides of a tensor based on the provided argsort + indices, allowing for custom memory layout configurations without copying data. + + Args: + t: Input tensor to create a strided view of. + argsort_strides: A 1D tensor containing indices that specify the desired + stride ordering. Must have the same number of elements as the number + of dimensions in `t`. + + Returns: + A strided view of the input tensor with reordered strides according to + the specified ordering. + + Raises: + AssertionError: If the number of dimensions in `t` doesn't match the + number of elements in `argsort_strides`. + """ + assert t.ndim == argsort_strides.numel(), "shape and argsort_strides must have the same length" + shape_sorted_by_strides = torch.tensor(t.shape)[argsort_strides[:-1]] + strides_sorted = torch.tensor([1] + shape_sorted_by_strides.tolist()).cumprod(dim=0) + # inverse argsort_strides + argsort_strides_inv = torch.argsort(argsort_strides) + strides_new = strides_sorted[argsort_strides_inv] + ans = torch.as_strided(t.flatten(), t.shape, strides_new.tolist()) + return ans + + +def make_input( + B, + H, + Q_len, # Query/output sequence length + KV_len, # Key/Value sequence length + C_hidden, + device, + dtype, + seed: int, + min_max: tuple[float, float] | None = None, + inf: float = 1e9, + use_mask: bool = True, + argsort_strides: torch.Tensor | None = None, +): + """Create test input tensors for triangular attention. + + Args: + B: Batch size + H: Number of heads + Q_len: Query sequence length (I in the original formulation) + KV_len: Key/Value sequence length (J in the original formulation) + C_hidden: Hidden dimension + device: Device + dtype: Data type + seed: Random seed + min_max: If provided, randomly sample from [min, max] uniformly for all + the returned tensors; otherwise, sample from standard normal distribution + inf: Infinity value for mask + use_mask: Whether to use mask + argsort_strides: If provided, reorder the strides of the returned tensors + according to the provided indices. In effect, the returned + tensors' stride()[argsort_strides] will be a exclusive cumulative product of + tensors' shape[argsort_strides]. + """ + N = Q_len + J = KV_len + Q = Q_len + K = KV_len + V = KV_len + torch.manual_seed(seed) + + # Create mask + if use_mask: + mask = torch.randint(0, 2, (B, N, 1, 1, J), device=device, dtype=dtype, requires_grad=False) + # Set some regions to zero for testing masking + if mask.shape[1] > 1: + mask[0, mask.shape[1] // 2 :, :, :, :] = 0 + if mask.shape[-1] > 1: + mask[0, :, :, :, mask.shape[-1] // 2 :] = 0 + else: + mask = None + + q = torch.empty(B, N, H, Q, C_hidden, device=device, dtype=dtype, requires_grad=True) + k = torch.empty(B, N, H, K, C_hidden, device=device, dtype=dtype, requires_grad=True) + v = torch.empty(B, N, H, V, C_hidden, device=device, dtype=dtype, requires_grad=True) + triangle_bias = torch.empty(B, 1, H, N, J, device=device, dtype=dtype, requires_grad=True) + + do = torch.empty_like(q) + + if min_max is None: + init_tensors_normal([q, k, v, triangle_bias, do]) + else: + init_tensors_uniform([q, k, v, triangle_bias, do], min_max[0], min_max[1]) + + # zero-initialize do for the invalid elements + if mask is not None: + with torch.no_grad(): + do = do * mask.any(dim=-1, keepdim=True) + + if argsort_strides is not None: + assert ( + argsort_strides.ndim == 1 and argsort_strides.numel() == 5 + ), "argsort_strides must be a 1D tensor of length 5" + q = view_tensor_strided(q, argsort_strides) + k = view_tensor_strided(k, argsort_strides) + v = view_tensor_strided(v, argsort_strides) + mask = view_tensor_strided(mask, argsort_strides) + triangle_bias = view_tensor_strided(triangle_bias, argsort_strides) + do = view_tensor_strided(do, argsort_strides) + + return q, k, v, mask, triangle_bias, do + + +def run_triangle_attention( + q, + k, + v, + triangle_bias, + do, + mask: torch.Tensor | None = None, + backend: TriAttnBackend = TriAttnBackend.REFERENCE, + precision: Precision = Precision.FP32, + check_bwd: bool = True, + scale: float = 1.0, + dtype_triangle_bias: torch.dtype | None = None, +): + """Run triangle attention operation with specified backend and precision. + + This function executes the triangle attention mechanism, which applies attention + with an additional triangle bias term. It supports a PyTorch reference implementation + and optimized GPU kernels (CUEQ, TRIFAST), and can run in various precisions. + + Args: + q: Query tensor of shape [batch, num_heads, seq_len, head_dim]. + k: Key tensor of shape [batch, num_heads, seq_len, head_dim]. + v: Value tensor of shape [batch, num_heads, seq_len, head_dim]. + triangle_bias: Triangle bias tensor of shape [batch, num_heads, seq_len, seq_len]. + do: Gradient output tensor of shape [batch, num_heads, seq_len, head_dim], + used for backward pass. + mask: Optional boolean mask tensor of shape [batch, 1, seq_len, seq_len] or + broadcastable shape. If provided, attention is masked where mask is False. + Defaults to None. + backend: Backend implementation to use (TriAttnBackend.REFERENCE, TriAttnBackend.CUEQ, + or TriAttnBackend.TRIFAST). Defaults to TriAttnBackend.REFERENCE. + precision: Precision mode for computation (Precision.FP16, Precision.BF16, + Precision.TF32, Precision.FP32, or Precision.FP64). FP64 is only supported + with the reference backend. Defaults to Precision.FP32. + check_bwd: Whether to run the backward pass and compute gradients. + Defaults to True. + scale: Scaling factor applied to attention scores. Defaults to 1.0. + dtype_triangle_bias: Optional dtype override for triangle_bias. If None, + uses the target dtype from precision. Defaults to None. + + Returns: + tuple: A 4-tuple containing: + - output: Attention output tensor of shape [batch, num_heads, seq_len, head_dim]. + - lse_m: Log-sum-exp values (shifted by max) of shape + [batch, num_heads, seq_len, 1]. + - amax: Maximum attention scores of shape [batch, num_heads, seq_len, 1]. + - input_grads: Dictionary with gradients for 'q', 'k', 'v', and 'triangle_bias'. + If check_bwd is False, all gradients are None. + + Raises: + ValueError: If precision is FP64 and backend is not REFERENCE. + ValueError: If backend is unknown. + + Note: + - Input tensors are cloned and converted to the target precision internally. + - The function uses appropriate TF32 environment settings based on precision (except TRIFAST). + - For CUEQ and TRIFAST backends, LSE/amax values are reshaped/adjusted to match reference format. + """ + if precision == Precision.FP64 and backend != TriAttnBackend.REFERENCE: + raise ValueError("FP64 is only supported for reference backend") + + if precision == Precision.TF32 and backend == TriAttnBackend.TRIFAST: + # trifast hardcodes the "input_precision" to "ieee", i.e., FP32, + # in tl.dot call, which locks down to doing FP32 matmul when input dtype is FP32. + raise ValueError("TF32 is not supported for TRIFAST backend") + + device = q.device + + target_dtype = PRECISION_TO_DTYPE[precision] + + # Convert inputs to target precision + q_work = q.detach().clone().to(dtype=target_dtype, device=device).requires_grad_(True) + k_work = k.detach().clone().to(dtype=target_dtype, device=device).requires_grad_(True) + v_work = v.detach().clone().to(dtype=target_dtype, device=device).requires_grad_(True) + triangle_bias_work = ( + triangle_bias.detach() + .clone() + .to( + dtype=target_dtype if dtype_triangle_bias is None else dtype_triangle_bias, + device=device, + ) + .requires_grad_(True) + ) + + if mask is None: + mask_work = None + else: + mask_work = mask.to(dtype=bool, device=device) + + do_work = do.detach().clone().to(dtype=target_dtype, device=device) + + # Must not change the input tensors' memory layout here + assert q_work.stride() == q.stride(), "q_work.stride() must be the same as q.stride()" + assert k_work.stride() == k.stride(), "k_work.stride() must be the same as k.stride()" + assert v_work.stride() == v.stride(), "v_work.stride() must be the same as v.stride()" + assert ( + triangle_bias_work.stride() == triangle_bias.stride() + ), "triangle_bias_work.stride() must be the same as triangle_bias.stride()" + if mask_work is not None: + ( + mask_work.stride() == mask.stride(), + "mask_work.stride() must be the same as mask.stride()", + ) + assert do_work.stride() == do.stride(), "do_work.stride() must be the same as do.stride()" + + # Run forward pass based on backend + if backend == TriAttnBackend.REFERENCE: + # reference implementation uses mask bias instead of mask + inf = PRECISION_TO_INF[precision] + if mask_work is None: + mask_bias = torch.zeros( + (q_work.shape[0], q_work.shape[1], 1, 1, k_work.shape[3]), dtype=target_dtype, device=device + ) + else: + mask_bias = inf * (mask_work.to(dtype=target_dtype) - 1.0) + biases = [mask_bias, triangle_bias_work] + with setup_tf32_env(precision): + output, lse_m, amax = _attention(q_work * scale, k_work, v_work, biases, return_lse=True) + # Run backward pass if requested + if check_bwd: + output.backward(do_work) + # Collect gradients + input_grads = { + "q": q_work.grad, + "k": k_work.grad, + "v": v_work.grad, + "triangle_bias": triangle_bias_work.grad, + } + elif backend == TriAttnBackend.CUEQ_FWD_TRIFAST_BWD: + # Forward pass uses CUEQ, backward pass uses TRIFAST + with setup_tf32_env(precision): + with torch.no_grad(): + # Run CUEQ forward + output, lse, amax = cueq_triangle.triangle_attention( + q_work, + k_work, + v_work, + triangle_bias_work, + mask=mask_work, + scale=scale, + return_aux=True, + ) + # add back the singleton K axis resulting from the max reduction + lse_reshaped = lse.unsqueeze(-1) + # need to return amax unsqueezed for comparison with reference + amax = amax.unsqueeze(-1) + lse_m = lse_reshaped - amax + + # Run backward pass if requested + if check_bwd: + # Reshape tensors from CUEQ format to TRIFAST format for backward + # q: [B, I, H, Q, C_hidden] --> [B, H, I, Q, C_hidden] + q_trifast = q_work.detach().transpose(-3, -4).contiguous().requires_grad_(True) + # k: [B, I, H, K, C_hidden] --> [B, H, I, K, C_hidden] + k_trifast = k_work.detach().transpose(-3, -4).contiguous().requires_grad_(True) + # v: [B, I, H, V, C_hidden] --> [B, H, I, V, C_hidden] + v_trifast = v_work.detach().transpose(-3, -4).contiguous().requires_grad_(True) + # triangle_bias: [B, 1, H, I, J] --> [B, H, I, J] + triangle_bias_trifast = triangle_bias_work.detach().squeeze(-4).contiguous().requires_grad_(True) + # output: [B, I, H, V, C_hidden] --> [B, H, I, V, C_hidden] + o_trifast = output.detach().transpose(-3, -4).contiguous() + # lse: [B, I, H, Q, 1] --> [B, H, I, Q, 1], and use lse instead of lse_m + lse_trifast = (lse_m + amax).detach().transpose(-3, -4).contiguous() + # do: [B, I, H, Q, C_hidden] --> [B, H, I, Q, C_hidden] + do_trifast = do_work.detach().transpose(-3, -4).contiguous() + + # TRIFAST mask convention: True for invalid positions, False for valid positions + # mask: [B, I, 1, 1, J] --> [B, I, J] + mask_trifast = ~(mask_work.detach().squeeze((-2, -3)).contiguous()) + + # Call TRIFAST backward + dq_trifast, dk_trifast, dv_trifast, dtriangle_bias_trifast, _ = trifast_triangle_attention_bwd( + do_trifast, + q_trifast, + k_trifast, + v_trifast, + triangle_bias_trifast, + o_trifast, + lse_trifast.squeeze(-1).to(dtype=torch.float32), + mask_trifast, + ) + + # Reshape gradients back to CUEQ format + # dq: [B, H, I, Q, C_hidden] --> [B, I, H, Q, C_hidden] + dq_cueq = dq_trifast.transpose(-3, -4).contiguous() + # dk: [B, H, I, K, C_hidden] --> [B, I, H, K, C_hidden] + dk_cueq = dk_trifast.transpose(-3, -4).contiguous() + # dv: [B, H, I, V, C_hidden] --> [B, I, H, V, C_hidden] + dv_cueq = dv_trifast.transpose(-3, -4).contiguous() + # dtriangle_bias: [B, H, I, J] --> [B, 1, H, I, J] + dtriangle_bias_cueq = dtriangle_bias_trifast.unsqueeze(-4).contiguous() + + # Manually set gradients + q_work.grad = dq_cueq + k_work.grad = dk_cueq + v_work.grad = dv_cueq + triangle_bias_work.grad = dtriangle_bias_cueq + + input_grads = { + "q": q_work.grad, + "k": k_work.grad, + "v": v_work.grad, + "triangle_bias": triangle_bias_work.grad, + } + else: + input_grads = {"q": None, "k": None, "v": None, "triangle_bias": None} + elif backend == TriAttnBackend.CUEQ: + with setup_tf32_env(precision): + output, lse, amax = cueq_triangle.triangle_attention( + q_work, + k_work, + v_work, + triangle_bias_work, + mask=mask_work, + scale=scale, + return_aux=True, + ) + # add back the singleton K axis resulting from the max reduction + lse_reshaped = lse.unsqueeze(-1) + # need to return amax unsqueezed for comparison with reference + amax = amax.unsqueeze(-1) + lse_m = lse_reshaped - amax + # manually call backward pass to emulate CP usage if requested + if check_bwd: + output.backward(do_work) + # Collect gradients + input_grads = { + "q": q_work.grad, + "k": k_work.grad, + "v": v_work.grad, + "triangle_bias": triangle_bias_work.grad, + } + else: + input_grads = {"q": None, "k": None, "v": None, "triangle_bias": None} + elif backend == TriAttnBackend.TRIFAST: + # No need to setup TF32 environment for TRIFAST + # as it hardcodes the "input_precision" to "ieee" (FP32) in tl.dot calls. + q_trifast = q_work.transpose(1, 2) + k_trifast = k_work.transpose(1, 2) + v_trifast = v_work.transpose(1, 2) + triangle_bias_trifast = triangle_bias_work.squeeze(1) + # TRIFAST mask convention: True for invalid positions, False for valid positions + if mask_work is None: + mask_trifast = torch.zeros( + q_trifast.shape[0], q_trifast.shape[2], q_trifast.shape[3], device=device, dtype=torch.bool + ) + else: + # mask: [B, I, 1, 1, K] --> [B, I, K] + mask_trifast = ~(mask_work.squeeze(dim=(2, 3)).to(dtype=torch.bool)) + output_trifast, lse = trifast_triangle_attention( + q_trifast, k_trifast, v_trifast, triangle_bias_trifast, mask_trifast + ) + # output: [B, H, I, J, C_hidden] --> [B, I, H, J, C_hidden] + output = output_trifast.transpose(1, 2) + if check_bwd: + output.backward(do_work) + # Collect gradients + input_grads = { + "q": q_work.grad, + "k": k_work.grad, + "v": v_work.grad, + "triangle_bias": triangle_bias_work.grad, + } + # TRIFAST's _triangle_attention API returns lse instead of lse - amax + # (i.e., lse_m). We return lse_m as-is and set amax to None to indicate this difference. + # lse: [B, H, I, J] --> [B, I, H, J, 1] + lse_m = lse.transpose(1, 2).unsqueeze(-1) + amax = None + else: + raise ValueError(f"Unknown backend: {backend}") + + return output, lse_m, amax, input_grads + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +@pytest.mark.parametrize( + "backend", + [TriAttnBackend.CUEQ, TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD], + ids=lambda x: f"backend:{x.value}", +) +@pytest.mark.parametrize("precision", [Precision.FP16, Precision.BF16, Precision.TF32, Precision.FP32]) +@pytest.mark.parametrize( + "argsort_strides", + [ + None, # LayoutRight + torch.tensor([4, 2, 3, 1, 0]), # D->H->S->I->B, commonly seen layout in OF/AF impl + torch.tensor([4, 3, 2, 1, 0]), # LayoutLeft + ], + ids=lambda x: f"argsort_strides:{'-'.join(map(str, x.tolist())) if x is not None else 'None'}", +) +def test_triangle_attention_kernel(backend, precision, argsort_strides): + """Test triangle attention kernels against reference implementation with error histogram analysis. + + This test validates triangle attention backend implementations (CUEQ, TRIFAST) against + the reference PyTorch implementation using error histogram analysis. It compares: + 1. FP64 reference (high precision baseline) + 2. Target precision reference (FP16/BF16/TF32/FP32) + 3. Backend kernel result at target precision (CUEQ or TRIFAST) + + The test also verifies: + - Forward pass output accuracy + - Backward pass gradient accuracy (when supported) + - Memory layout independence (contiguous vs strided) + - Masking correctness (zero gradients for masked regions) + """ + + if backend == TriAttnBackend.CUEQ and not cueq_is_installed: + pytest.skip("cuequivariance_torch is not installed") + + if backend == TriAttnBackend.TRIFAST and not trifast_is_installed: + pytest.skip("trifast is not installed") + + if backend == TriAttnBackend.CUEQ_FWD_TRIFAST_BWD: + if not cueq_is_installed: + pytest.skip("cuequivariance_torch is not installed") + if not trifast_is_installed: + pytest.skip("trifast is not installed") + if precision != Precision.FP32: + pytest.skip("CUEQ_FWD_TRIFAST_BWD only supports FP32 precision") + + if precision == Precision.TF32 and backend in (TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD): + # TRIFAST hardcodes the "input_precision" to "ieee" (FP32) in tl.dot calls, + # which locks down to FP32 matmul when input dtype is FP32. + pytest.skip("TF32 is not supported for TRIFAST backend") + + # Test parameters + H = 4 + N = 64 + C_hidden = 32 + B = 3 + device = "cuda:0" + scale = 1 / math.sqrt(C_hidden) + + # Skip backward pass for FP32 (without TF32) with CUEQ backend as it doesn't support it + # TRIFAST and CUEQ_FWD_TRIFAST_BWD support backward pass for all precisions + check_bwd = backend in (TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD) or ( + backend == TriAttnBackend.CUEQ + and (precision == Precision.TF32 or precision == Precision.BF16 or precision == Precision.FP16) + ) + + if precision == Precision.TF32: + if backend == TriAttnBackend.CUEQ: + # NOTE: The numerical error from CUEQ triangle attention with TF32 is significant. + # Use smaller input values to keep gradients in a reasonable range (1e-6 to 1e-5) + # for accurate gradient checking. + min_val = -0.02 + max_val = 0.02 + elif backend in (TriAttnBackend.TRIFAST, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD): + min_val = -0.05 + max_val = 0.05 + elif precision == Precision.BF16: + min_val = -0.5 + max_val = 0.5 + elif precision == Precision.FP16: + min_val = -0.5 + max_val = 0.5 + elif precision == Precision.FP32: + min_val = -0.5 + max_val = 0.5 + else: + raise ValueError(f"Unsupported precision: {precision}") + + # Create test inputs in FP64 for highest precision, then cast as needed in run_triangle_attention + seed = 42 + q, k, v, mask, triangle_bias, do = make_input( + B, + H, + N, + N, + C_hidden, + device, + torch.float64, + seed, + min_max=(min_val, max_val), + inf=1e18, + use_mask=True, + argsort_strides=argsort_strides, + ) + + # === RUN COMPUTATIONS WITH DIFFERENT BACKENDS AND PRECISIONS === + + # Run FP64 reference (high precision baseline) + o_expected_fp64, lse_m_expected_fp64, amax_expected_fp64, grads_fp64 = run_triangle_attention( + q, + k, + v, + triangle_bias, + do, + mask=mask, + backend=TriAttnBackend.REFERENCE, + precision=Precision.FP64, + check_bwd=check_bwd, + scale=scale, + ) + + # Run FP32/TF32 reference (alternative precision) + o_expected_alt, lse_m_expected_alt, amax_expected_alt, grads_alt = run_triangle_attention( + q, + k, + v, + triangle_bias, + do, + mask=mask, + backend=TriAttnBackend.REFERENCE, + precision=precision, + check_bwd=check_bwd, + scale=scale, + ) + + # Run test backend implementation + o_result, lse_m_result, amax_result, grads_result = run_triangle_attention( + q, + k, + v, + triangle_bias, + do, + mask=mask, + backend=backend, + precision=precision, + check_bwd=check_bwd, + scale=scale, + ) + + # Compute masks for proper comparison + if mask is not None: + # mask: [B, I, 1, 1, K] --> [B, I, 1, K, 1] for dk and dv masking + mask_kv = mask.to(dtype=mask.dtype)[:, :, :, 0, :, None] + # mask_i: [B, I, 1, 1, 1] for output, lse, amax masking + mask_i = mask.any(dim=-1, keepdim=True) + if grads_result["triangle_bias"] is not None: + # mask_j: [B, 1, 1, 1, K] for dtriangle_bias masking + mask_j = mask.any(dim=1, keepdim=True).to(dtype=grads_result["triangle_bias"].dtype) + else: + mask_j = torch.ones((B, 1, 1, 1, N), dtype=o_result.dtype, requires_grad=False, device=o_result.device) + else: + mask_kv = torch.ones( + (B, N, 1, N, 1), + dtype=k.dtype, + requires_grad=False, + device=o_result.device, + ) + mask_i = torch.ones( + (B, N, 1, 1, 1), + dtype=q.dtype, + requires_grad=False, + device=o_result.device, + ) + mask_j = torch.ones( + (B, 1, 1, 1, N), + dtype=triangle_bias.dtype, + requires_grad=False, + device=o_result.device, + ) + + if precision != Precision.FP32: + # === ERROR HISTOGRAM ANALYSIS === + # Compare kernel implementations (CUEQ/TRIFAST) against PyTorch reference + # Different backends use different algorithmic approaches, so precision characteristics will differ + + # use mask to select valid elements for comparison + # to avoid numerical noise from invalid elements + mask_i = mask_i.to(dtype=bool) + # convert to FP32 for difference calculation + assert_no_percentile_upshift( + o_result[mask_i.expand_as(o_result)], + o_expected_fp64[mask_i.expand_as(o_expected_fp64)].to(dtype=torch.float32), + o_expected_alt[mask_i.expand_as(o_expected_alt)], + names_input=(f"o_{backend.value}_{precision}", "o_ref_fp64", f"o_ref_{precision}"), + ) + + # Test for lse_m and amax (backend-specific due to different return conventions) + if backend == TriAttnBackend.CUEQ: + assert_no_percentile_upshift( + lse_m_result[mask_i.expand_as(lse_m_result)], + lse_m_expected_fp64[mask_i.expand_as(lse_m_expected_fp64)].to(dtype=torch.float32), + lse_m_expected_alt[mask_i.expand_as(lse_m_expected_alt)], + names_input=(f"lse_m_{backend.value}_{precision}", "lse_m_ref_fp64", f"lse_m_ref_{precision}"), + ) + assert_no_percentile_upshift( + amax_result[mask_i.expand_as(amax_result)], + amax_expected_fp64[mask_i.expand_as(amax_expected_fp64)].to(dtype=torch.float32), + amax_expected_alt[mask_i.expand_as(amax_expected_alt)], + names_input=(f"amax_{backend.value}_{precision}", "amax_ref_fp64", f"amax_ref_{precision}"), + ) + elif backend == TriAttnBackend.TRIFAST: + # TRIFAST returns lse directly (not lse - amax), so compare against lse_m + amax + assert amax_result is None, "amax should be None for TRIFAST backend" + assert_no_percentile_upshift( + lse_m_result[mask_i.expand_as(lse_m_result)], + (lse_m_expected_fp64 + amax_expected_fp64)[mask_i.expand_as(lse_m_expected_fp64)].to( + dtype=torch.float32 + ), + (lse_m_expected_alt + amax_expected_alt)[mask_i.expand_as(lse_m_expected_alt)], + names_input=(f"lse_{backend.value}_{precision}", "lse_ref_fp64", f"lse_ref_{precision}"), + ) + else: + raise ValueError(f"Unknown backend: {backend}") + + if check_bwd: + mask_kv = mask_kv.to(dtype=bool) + mask_j = mask_j.to(dtype=bool) + # Test gradient error histograms + for grad_name in ["q", "k", "v", "triangle_bias"]: + grad_result = grads_result[grad_name] + grad_expected_fp64 = grads_fp64[grad_name] + grad_expected_alt = grads_alt[grad_name] + if grad_name == "q": + m = mask_i.expand_as(grad_result) + elif grad_name == "k" or grad_name == "v": + m = mask_kv.expand_as(grad_result) + elif grad_name == "triangle_bias": + m = mask_j.expand_as(grad_result) + else: + raise ValueError(f"Unknown gradient name: {grad_name}") + assert_no_percentile_upshift( + grad_result[m], + grad_expected_fp64[m].to(dtype=torch.float32), + grad_expected_alt[m], + names_input=( + f"d_{grad_name}_{backend.value}_{precision}", + f"d_{grad_name}_ref_fp64", + f"d_{grad_name}_ref_{precision}", + ), + ) + else: + # === SIMPLE TOLERANCE TESTING (FP32 only) === + # For FP32 without TF32, use simple assertion with default tolerances + torch.testing.assert_close(o_result * mask_i, o_expected_fp64.to(dtype=o_result.dtype) * mask_i) + + if backend in (TriAttnBackend.CUEQ, TriAttnBackend.CUEQ_FWD_TRIFAST_BWD): + torch.testing.assert_close(lse_m_result * mask_i, lse_m_expected_fp64.to(dtype=lse_m_result.dtype) * mask_i) + torch.testing.assert_close(amax_result * mask_i, amax_expected_fp64.to(dtype=amax_result.dtype) * mask_i) + elif backend == TriAttnBackend.TRIFAST: + assert amax_result is None, "amax should be None for TRIFAST backend" + # TRIFAST returns lse directly (not lse - amax), so compare against lse_m + amax + torch.testing.assert_close( + lse_m_result * mask_i, (lse_m_expected_fp64 + amax_expected_fp64).to(dtype=lse_m_result.dtype) * mask_i + ) + else: + raise ValueError(f"Unknown backend: {backend}") + + if check_bwd: + for grad_name in ["q", "k", "v", "triangle_bias"]: + grad_result = grads_result[grad_name] + grad_expected_fp64 = grads_fp64[grad_name] + torch.testing.assert_close( + grad_result, grad_expected_fp64.to(dtype=grad_result.dtype), msg=lambda m: f"d_{grad_name}:\n{m}" + ) + + # check gradients are zero in masked regions to prevent backprop invalid elements' gradient upstream + if check_bwd: + for grad_name in ["q", "k", "v", "triangle_bias"]: + grad_result = grads_result[grad_name] + grad_expected_fp64 = grads_fp64[grad_name] + grad_expected_alt = grads_alt[grad_name] + if grad_name == "q": + m = mask_i + elif grad_name == "k" or grad_name == "v": + m = mask_kv + elif grad_name == "triangle_bias": + m = mask_j + else: + raise ValueError(f"Unknown gradient name: {grad_name}") + torch.testing.assert_close(grad_result * ~(m.bool()), torch.zeros_like(grad_result), atol=0, rtol=0) + + # === CONTIGUITY CONSISTENCY TEST === + # Test self-consistency between contiguous vs non-contiguous layout + if argsort_strides is not None: + # Run the same backend with contiguous layout for comparison + q_contiguous = q.contiguous().detach().clone() + k_contiguous = k.contiguous().detach().clone() + v_contiguous = v.contiguous().detach().clone() + triangle_bias_contiguous = triangle_bias.contiguous().detach().clone() + do_contiguous = do.contiguous().detach().clone() + mask_contiguous = mask.contiguous().detach().clone() + + o_result_contiguous, lse_m_result_contiguous, amax_result_contiguous, grads_result_contiguous = ( + run_triangle_attention( + q_contiguous, + k_contiguous, + v_contiguous, + triangle_bias_contiguous, + do_contiguous, + mask=mask_contiguous, + backend=backend, + precision=precision, + check_bwd=check_bwd, + scale=scale, + ) + ) + + torch.testing.assert_close(o_result_contiguous, o_result, atol=0, rtol=0) + torch.testing.assert_close(lse_m_result_contiguous, lse_m_result, atol=0, rtol=0) + torch.testing.assert_close(amax_result_contiguous, amax_result, atol=0, rtol=0) + + if check_bwd: + for grad_name in ["q", "k", "v", "triangle_bias"]: + # triangle_bias's gradient is not binary identical across layouts for CUEQ backend + # due to atomic operations in the kernel implementation + atol, rtol = ( + (None, None) if (grad_name == "triangle_bias" and backend == TriAttnBackend.CUEQ) else (0, 0) + ) + torch.testing.assert_close( + grads_result_contiguous[grad_name], grads_result[grad_name], atol=atol, rtol=rtol + ) diff --git a/tests/model/loss/__init__.py b/tests/model/loss/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/tests/model/loss/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/tests/model/loss/test_cdist_lddt_validation.py b/tests/model/loss/test_cdist_lddt_validation.py new file mode 100644 index 000000000..741697f41 --- /dev/null +++ b/tests/model/loss/test_cdist_lddt_validation.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import pytest +import torch + +from boltz.distributed.model.loss.validation import clash_score +from boltz.distributed.model.loss.validation import factored_lddt_loss as triton_factored_lddt_loss +from boltz.model.loss.validation import factored_lddt_loss +from boltz.testing.utils import random_features + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("use_cardinality_weighted", [True, False]) +@pytest.mark.parametrize("repeat_atom_mask", [True, False]) +def test_factored_lddt_loss_cdist_consistency(use_cardinality_weighted, repeat_atom_mask): + device = torch.device("cuda") + rng = torch.Generator(device=device) + rng.manual_seed(0) + rng_features = torch.Generator(device=device) + rng_features.manual_seed(0) + + batch = 2 + num_tokens = 256 + num_atoms = num_tokens * 9 + multiplicity = 5 + + true_atom_coords = torch.rand(batch, num_atoms, 3, generator=rng, device=device) + pred_atom_coords = torch.rand(batch, num_atoms, 3, generator=rng, device=device) + + atom_mask_base = torch.randint(0, 2, (batch, num_atoms), generator=rng, device=device, dtype=torch.float32) + + true_atom_coords = true_atom_coords.repeat_interleave(multiplicity, 0) + pred_atom_coords = pred_atom_coords.repeat_interleave(multiplicity, 0) + atom_mask = atom_mask_base.repeat_interleave(multiplicity, 0) if repeat_atom_mask else atom_mask_base + + feats = random_features( + size_batch=batch, + n_tokens=num_tokens, + n_atoms=num_atoms, + n_msa=1, + atom_counts_per_token_range=(1, 9), + device=device, + float_value_range=(0.0, 1.0), + selected_keys=["atom_to_token", "asym_id", "mol_type"], + rng=rng_features, + ) + + feats = { + "atom_to_token": feats["atom_to_token"], + "mol_type": feats["mol_type"], + "asym_id": feats["asym_id"], + } + + atom_mask_ref = atom_mask_base.repeat_interleave(multiplicity, 0) + ref_lddt, ref_total = factored_lddt_loss( + true_atom_coords=true_atom_coords, + pred_atom_coords=pred_atom_coords, + feats=feats, + atom_mask=atom_mask_ref, + multiplicity=multiplicity, + cardinality_weighted=use_cardinality_weighted, + ) + triton_lddt, triton_total = triton_factored_lddt_loss( + true_atom_coords=true_atom_coords, + pred_atom_coords=pred_atom_coords, + feats=feats, + atom_mask=atom_mask, + multiplicity=multiplicity, + cardinality_weighted=use_cardinality_weighted, + ) + + saw_zero_total = False + for key in ref_lddt: + torch.testing.assert_close(triton_lddt[key], ref_lddt[key]) + torch.testing.assert_close(triton_total[key], ref_total[key]) + zero_total_mask = ref_total[key] == 0 + if torch.any(zero_total_mask): + saw_zero_total = True + torch.testing.assert_close(ref_lddt[key][zero_total_mask], torch.ones_like(ref_lddt[key][zero_total_mask])) + torch.testing.assert_close( + triton_lddt[key][zero_total_mask], torch.ones_like(triton_lddt[key][zero_total_mask]) + ) + assert saw_zero_total, "Expected at least one modality to have zero total." + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test runs on triton kernel but CUDA is not available") +def test_clash_score_counts_and_fraction(): + device = torch.device("cuda") + clash_cutoff = 2.0 + multiplicity = 2 + + coords_repr = torch.tensor( + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0], [9.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [6.0, 0.0, 0.0], [9.0, 0.0, 0.0]], + ], + device=device, + dtype=torch.float32, + ) + token_pad_mask = torch.tensor([[True, True, True, False]], device=device) + + clash_atoms_count, clash_atoms_fraction = clash_score( + coords_repr=coords_repr, + token_pad_mask=token_pad_mask, + multiplicity=multiplicity, + clash_cutoff=clash_cutoff, + ) + + expected_count = torch.tensor([2, 0], device=device, dtype=clash_atoms_count.dtype) + expected_fraction = torch.tensor([2.0 / 3.0, 0.0], device=device, dtype=clash_atoms_fraction.dtype) + + torch.testing.assert_close(clash_atoms_count, expected_count) + torch.testing.assert_close(clash_atoms_fraction, expected_fraction) diff --git a/tests/model/loss/test_distogramv2.py b/tests/model/loss/test_distogramv2.py new file mode 100644 index 000000000..0d41ccb19 --- /dev/null +++ b/tests/model/loss/test_distogramv2.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for the Boltz-2 distogramv2 serial loss function. + +Tests verify: +1. v2 matches v1 for K=1, D=1 (backward-compatible legacy behavior) +2. Aggregate invariant: K identical conformers == K=1 +3. Aggregate error: D>1 raises +4. Non-aggregate semantic: min-over-D selects best prediction +5. Masking: diagonal + padded tokens get zero gradient; fully masked batch → ~0 loss +6. Gradient correctness: finite across all (K,D,agg) configs + numerical gradcheck +7. Autocast: loss stays float32 +""" + +import pytest +import torch + +from boltz.model.loss.distogram import distogram_loss as distogram_loss_v1 +from boltz.model.loss.distogramv2 import distogram_loss as distogram_loss_v2 + + +def test_aggregate_k1_matches_v1(): + """v2 aggregate with K=1, D=1 must produce identical results to v1 (forward + backward).""" + B, N, num_bins = 2, 16, 64 + D = 1 + + with torch.random.fork_rng(): + torch.manual_seed(42) + + pred_v2 = torch.randn(B, N, N, D, num_bins, requires_grad=True) + pred_v1 = pred_v2.squeeze(3).detach().clone().requires_grad_(True) + + target_idx = torch.randint(0, num_bins, (B, N, N)) + target_onehot = torch.nn.functional.one_hot(target_idx, num_classes=num_bins).float() + target_v2 = target_onehot.unsqueeze(3) + target_v1 = target_onehot + + mask = torch.ones(B, N) + mask[0, 12:] = 0 + mask[1, 8:] = 0 + + global_loss_v2, batch_loss_v2 = distogram_loss_v2( + {"pdistogram": pred_v2}, {"disto_target": target_v2, "token_disto_mask": mask}, aggregate_distogram=True + ) + global_loss_v1, batch_loss_v1 = distogram_loss_v1( + {"pdistogram": pred_v1}, {"disto_target": target_v1, "token_disto_mask": mask} + ) + + torch.testing.assert_close(global_loss_v2, global_loss_v1) + torch.testing.assert_close(batch_loss_v2, batch_loss_v1) + + global_loss_v2.backward() + global_loss_v1.backward() + torch.testing.assert_close(pred_v2.grad.squeeze(3), pred_v1.grad) + + +def test_aggregate_identical_conformers(): + """K identical conformers must produce the same loss and gradient as K=1.""" + B, N, num_bins = 2, 8, 16 + K, D = 3, 1 + + with torch.random.fork_rng(): + torch.manual_seed(42) + + pred = torch.randn(B, N, N, D, num_bins, requires_grad=True) + pred_single = pred.detach().clone().requires_grad_(True) + + target_idx = torch.randint(0, num_bins, (B, N, N)) + target_onehot = torch.nn.functional.one_hot(target_idx, num_classes=num_bins).float() + target_multi = target_onehot.unsqueeze(3).repeat(1, 1, 1, K, 1) + target_single = target_onehot.unsqueeze(3) + + mask = torch.ones(B, N) + + loss_multi, _ = distogram_loss_v2( + {"pdistogram": pred}, {"disto_target": target_multi, "token_disto_mask": mask}, aggregate_distogram=True + ) + loss_single, _ = distogram_loss_v2( + {"pdistogram": pred_single}, + {"disto_target": target_single, "token_disto_mask": mask}, + aggregate_distogram=True, + ) + + torch.testing.assert_close(loss_multi, loss_single) + + loss_multi.backward() + loss_single.backward() + torch.testing.assert_close(pred.grad, pred_single.grad) + + +def test_aggregate_rejects_multi_distogram(): + """Aggregate mode must reject D>1.""" + pred = torch.randn(2, 8, 8, 2, 16) + target = torch.rand(2, 8, 8, 1, 16) + mask = torch.ones(2, 8) + + with pytest.raises(AssertionError, match="Cannot aggregate GT distogram when num_distograms > 1"): + distogram_loss_v2( + {"pdistogram": pred}, {"disto_target": target, "token_disto_mask": mask}, aggregate_distogram=True + ) + + +def test_non_aggregate_min_selects_best_prediction(): + """Non-aggregate min-over-D should yield low loss when one prediction matches the target.""" + B, N, num_bins = 1, 4, 8 + D = 2 + + with torch.random.fork_rng(): + torch.manual_seed(42) + + target_idx = torch.randint(0, num_bins, (B, N, N)) + target = torch.nn.functional.one_hot(target_idx, num_classes=num_bins).float().unsqueeze(3) + + pred = torch.randn(B, N, N, D, num_bins) + # First prediction gets high logit for correct bin; second is random + pred_good = torch.zeros(B, N, N, num_bins) + pred_good.scatter_(-1, target_idx.unsqueeze(-1), 10.0) + with torch.no_grad(): + pred[:, :, :, 0, :] = pred_good + pred[:, :, :, 1, :] = torch.randn(B, N, N, num_bins) + pred.requires_grad_(True) + + mask = torch.ones(B, N) + global_loss, _ = distogram_loss_v2( + {"pdistogram": pred}, {"disto_target": target, "token_disto_mask": mask}, aggregate_distogram=False + ) + + assert global_loss < 1.0, f"Loss should be low when one prediction matches, got {global_loss}" + + +def test_masking_zeroes_gradients(): + """Diagonal, padded-token, and fully-masked-batch positions must get zero gradient / ~0 loss.""" + B, N, num_bins = 2, 8, 16 + K, D = 1, 1 + + with torch.random.fork_rng(): + torch.manual_seed(42) + + # --- Part 1: diagonal + padded tokens --- + pred = torch.randn(B, N, N, D, num_bins, requires_grad=True) + target = torch.rand(B, N, N, K, num_bins) + + mask = torch.ones(B, N) + mask[:, N // 2 :] = 0 # mask out second half of tokens + + global_loss, _ = distogram_loss_v2( + {"pdistogram": pred}, {"disto_target": target, "token_disto_mask": mask}, aggregate_distogram=True + ) + global_loss.backward() + + # Diagonal elements should have zero gradient + for b in range(B): + for i in range(N): + assert torch.allclose( + pred.grad[b, i, i, :, :], torch.zeros_like(pred.grad[b, i, i, :, :]), atol=1e-6 + ), f"Diagonal gradient at [{b},{i},{i}] should be zero" + + # Masked rows and columns should have zero gradient + for b in range(B): + for i in range(N // 2, N): + assert torch.allclose( + pred.grad[b, i, :, :, :], torch.zeros_like(pred.grad[b, i, :, :, :]), atol=1e-6 + ), f"Gradient for masked row {i} should be zero" + assert torch.allclose( + pred.grad[b, :, i, :, :], torch.zeros_like(pred.grad[b, :, i, :, :]), atol=1e-6 + ), f"Gradient for masked column {i} should be zero" + + # --- Part 2: fully masked batch element --- + pred2 = torch.randn(B, N, N, D, num_bins, requires_grad=True) + target2 = torch.rand(B, N, N, K, num_bins) + mask2 = torch.ones(B, N) + mask2[1, :] = 0 + + _, batch_loss = distogram_loss_v2( + {"pdistogram": pred2}, {"disto_target": target2, "token_disto_mask": mask2}, aggregate_distogram=True + ) + + assert batch_loss[1] < 1e-4, f"Fully masked batch should have ~0 loss, got {batch_loss[1]}" + + +@pytest.mark.parametrize( + "loss_config", + [ + (1, 1, True), + (3, 1, True), + (1, 2, False), + (3, 2, False), + ], + ids=lambda c: f"K={c[0]}|D={c[1]}|agg={c[2]}", +) +def test_gradient_finite(loss_config): + """Gradients must be finite for all (K, D, aggregate) configurations.""" + K, D, aggregate = loss_config + B, N, num_bins = 2, 16, 64 + + with torch.random.fork_rng(): + torch.manual_seed(42) + + pred = torch.randn(B, N, N, D, num_bins, requires_grad=True) + target = torch.rand(B, N, N, K, num_bins) + target = target / target.sum(dim=-1, keepdim=True).clamp(min=1e-8) + mask = torch.ones(B, N) + mask[0, N // 2 :] = 0 + + global_loss, batch_loss = distogram_loss_v2( + {"pdistogram": pred}, {"disto_target": target, "token_disto_mask": mask}, aggregate_distogram=aggregate + ) + + # Shape sanity (replaces removed test_output_shapes) + assert global_loss.shape == (), f"Global loss should be scalar, got {global_loss.shape}" + assert batch_loss.shape == (B,), f"Batch loss should be [B], got {batch_loss.shape}" + + global_loss.backward() + assert not torch.isnan(pred.grad).any(), f"NaN in gradients for {loss_config}" + assert not torch.isinf(pred.grad).any(), f"Inf in gradients for {loss_config}" + + +def test_gradient_numerical_check(): + """Numerical gradient check (torch.autograd.gradcheck) for aggregate mode.""" + B, N, num_bins = 1, 4, 8 + + with torch.random.fork_rng(): + torch.manual_seed(42) + + pred = torch.randn(B, N, N, 1, num_bins, requires_grad=True, dtype=torch.float64) + target = torch.rand(B, N, N, 1, num_bins, dtype=torch.float64) + target = target / target.sum(dim=-1, keepdim=True).clamp(min=1e-8) + mask = torch.ones(B, N, dtype=torch.float64) + + def loss_fn(p): + return distogram_loss_v2( + {"pdistogram": p}, {"disto_target": target, "token_disto_mask": mask}, aggregate_distogram=True + )[0] + + result = torch.autograd.gradcheck(loss_fn, pred, eps=1e-5, atol=1e-3, rtol=1e-2, nondet_tol=1e-3) + assert result, "Gradient check failed for aggregate mode" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for autocast test") +def test_autocast_disabled_in_loss(): + """Loss computation must stay float32 even under autocast.""" + B, N, num_bins = 2, 8, 16 + + with torch.random.fork_rng(devices=["cuda"]): + torch.manual_seed(42) + + pred = torch.randn(B, N, N, 1, num_bins, device="cuda", requires_grad=True) + target = torch.rand(B, N, N, 1, num_bins, device="cuda") + mask = torch.ones(B, N, device="cuda") + + with torch.autocast("cuda", dtype=torch.float16): + global_loss, batch_loss = distogram_loss_v2( + {"pdistogram": pred}, {"disto_target": target, "token_disto_mask": mask}, aggregate_distogram=True + ) + + assert global_loss.dtype == torch.float32, f"Global loss should be float32, got {global_loss.dtype}" + assert batch_loss.dtype == torch.float32, f"Batch loss should be float32, got {batch_loss.dtype}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/model/loss/test_factored_token_lddt_dist_loss.py b/tests/model/loss/test_factored_token_lddt_dist_loss.py new file mode 100644 index 000000000..83d2e514b --- /dev/null +++ b/tests/model/loss/test_factored_token_lddt_dist_loss.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import pytest +import torch + +from boltz.data import const +from boltz.distributed.model.loss.validation import factored_token_lddt_dist_loss_triton +from boltz.model.loss.validation import factored_token_lddt_dist_loss +from boltz.testing.utils import random_features + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("use_cardinality_weighted", [True, False]) +@pytest.mark.parametrize("pass_pred_d", [True, False]) +@pytest.mark.parametrize("pass_true_d", [True, False]) +def test_factored_token_lddt_dist_loss_triton_parity(use_cardinality_weighted, pass_pred_d, pass_true_d): + """Verify factored_token_lddt_dist_loss_triton matches the serial factored_token_lddt_dist_loss.""" + device = torch.device("cuda") + rng = torch.Generator(device=device) + rng.manual_seed(42) + rng_features = torch.Generator(device=device) + rng_features.manual_seed(42) + + batch = 2 + num_tokens = 256 + multiplicity = 1 + + pred_token_coords = torch.rand(batch, num_tokens, 3, generator=rng, device=device) + true_token_coords = torch.rand(batch, num_tokens, 3, generator=rng, device=device) + + feats = random_features( + size_batch=batch, + n_tokens=num_tokens, + n_atoms=num_tokens * 3, + n_msa=1, + atom_counts_per_token_range=(1, 3), + device=device, + float_value_range=(0.0, 1.0), + selected_keys=["mol_type", "asym_id", "token_disto_mask"], + rng=rng_features, + ) + + mol_type = feats["mol_type"].long() + token_disto_mask = feats["token_disto_mask"].float() + asym_id = feats["asym_id"].long() + + pred_d = torch.cdist(pred_token_coords, pred_token_coords) + true_d = torch.cdist(true_token_coords, true_token_coords) + + serial_feats = { + "mol_type": mol_type, + "token_disto_mask": token_disto_mask, + "asym_id": asym_id, + } + ref_lddt, ref_total = factored_token_lddt_dist_loss( + true_d=true_d, + pred_d=pred_d, + feats=serial_feats, + cardinality_weighted=use_cardinality_weighted, + ) + + triton_lddt, triton_total = factored_token_lddt_dist_loss_triton( + pred_token_coords=pred_token_coords, + true_token_coords=true_token_coords, + mol_type=mol_type, + token_disto_mask=token_disto_mask, + asym_id=asym_id, + multiplicity=multiplicity, + cardinality_weighted=use_cardinality_weighted, + pred_d=pred_d if pass_pred_d else None, + true_d=true_d if pass_true_d else None, + ) + + for key in ref_lddt: + torch.testing.assert_close(triton_lddt[key], ref_lddt[key], atol=1e-5, rtol=1e-5) + torch.testing.assert_close(triton_total[key], ref_total[key], atol=1e-5, rtol=1e-5) + zero_total_mask = ref_total[key] == 0 + if torch.any(zero_total_mask): + torch.testing.assert_close(ref_lddt[key][zero_total_mask], torch.ones_like(ref_lddt[key][zero_total_mask])) + torch.testing.assert_close( + triton_lddt[key][zero_total_mask], torch.ones_like(triton_lddt[key][zero_total_mask]) + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_factored_token_lddt_dist_loss_triton_multiplicity(): + """Verify triton token LDDT works with multiplicity > 1.""" + device = torch.device("cuda") + rng = torch.Generator(device=device) + rng.manual_seed(99) + rng_features = torch.Generator(device=device) + rng_features.manual_seed(99) + + batch = 2 + num_tokens = 128 + multiplicity = 3 + + pred_base = torch.rand(batch, num_tokens, 3, generator=rng, device=device) + true_base = torch.rand(batch, num_tokens, 3, generator=rng, device=device) + + pred_token_coords = pred_base.repeat_interleave(multiplicity, 0) + true_token_coords = true_base.repeat_interleave(multiplicity, 0) + + feats = random_features( + size_batch=batch, + n_tokens=num_tokens, + n_atoms=num_tokens * 3, + n_msa=1, + atom_counts_per_token_range=(1, 3), + device=device, + float_value_range=(0.0, 1.0), + selected_keys=["mol_type", "asym_id", "token_disto_mask"], + rng=rng_features, + ) + + mol_type = feats["mol_type"].long() + token_disto_mask = feats["token_disto_mask"].float() + asym_id = feats["asym_id"].long() + + triton_lddt, triton_total = factored_token_lddt_dist_loss_triton( + pred_token_coords=pred_token_coords, + true_token_coords=true_token_coords, + mol_type=mol_type, + token_disto_mask=token_disto_mask, + asym_id=asym_id, + multiplicity=multiplicity, + ) + + pred_d = torch.cdist(pred_base, pred_base) + true_d = torch.cdist(true_base, true_base) + + serial_feats = { + "mol_type": mol_type, + "token_disto_mask": token_disto_mask, + "asym_id": asym_id, + } + ref_lddt, ref_total = factored_token_lddt_dist_loss( + true_d=true_d, + pred_d=pred_d, + feats=serial_feats, + ) + + for key in ref_lddt: + triton_per_sample = triton_lddt[key].reshape(batch, multiplicity) + for m in range(multiplicity): + torch.testing.assert_close(triton_per_sample[:, m], ref_lddt[key], atol=1e-5, rtol=1e-5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_factored_token_lddt_dist_loss_triton_zero_total_defaults_to_one(): + """When a modality has no valid pairs, score must be 1.0 (not NaN or 0).""" + device = torch.device("cuda") + B = 1 + N = 8 + + pred = torch.rand(B, N, 3, device=device) + true = torch.rand(B, N, 3, device=device) + + mol_type = torch.full((B, N), const.chain_type_ids["PROTEIN"], dtype=torch.long, device=device) + token_disto_mask = torch.ones(B, N, device=device) + asym_id = torch.zeros(B, N, dtype=torch.long, device=device) + + lddt_dict, total_dict = factored_token_lddt_dist_loss_triton( + pred_token_coords=pred, + true_token_coords=true, + mol_type=mol_type, + token_disto_mask=token_disto_mask, + asym_id=asym_id, + ) + + empty_modalities = [ + "dna_protein", + "rna_protein", + "dna_ligand", + "rna_ligand", + "ligand_protein", + "intra_ligand", + "intra_dna", + "intra_rna", + "protein_protein", + ] + for key in empty_modalities: + assert total_dict[key].item() == 0, f"{key} should have zero total" + assert lddt_dict[key].item() == 1.0, f"{key} should default to 1.0 when no pairs exist" diff --git a/tests/model/validation/test_validator.py b/tests/model/validation/test_validator.py new file mode 100644 index 000000000..8610c831b --- /dev/null +++ b/tests/model/validation/test_validator.py @@ -0,0 +1,467 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for Validator metric logging in common_on_epoch_end. + +Validates that all metrics computed and accumulated during validation_step +are properly read, reset, and logged when the epoch ends. The key concern +is that confidence-ranked lDDT, PDE MAE, and PAE MAE metrics were being +accumulated but never logged — a silent data loss bug. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +import torch + +from boltz.data import const +from boltz.model.validation.validator import Validator + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# Pair-type keys returned by factored_lddt_loss / compute_pde_mae / compute_pae_mae. +# These are const.out_types minus "modified". +PAIR_METRIC_KEYS = [m for m in const.out_types if m != "modified"] + +# Folding metric iteration keys (includes pocket/contact variants). +FOLDING_METRIC_KEYS = [*const.out_types, "pocket_ligand_protein", "contact_protein_protein"] + +CONFIDENCE_PREFIXES = [ + "top1", + "iplddt_top1", + "ipde_top1", + "pde_top1", + "ptm_top1", + "iptm_top1", + "ligand_iptm_top1", + "protein_iptm_top1", + "avg", +] + + +def _make_mock_model(confidence_prediction: bool = True) -> tuple[MagicMock, dict[str, float]]: + """Create a mock LightningModule that records log() calls. + + Returns + ------- + model : MagicMock + Mock with `.log()` wired to record calls. + logged : dict + Mapping from metric name to logged value, populated by `model.log()`. + """ + model = MagicMock() + model.confidence_prediction = confidence_prediction + logged: dict[str, float] = {} + + def _log(name: str, value, **kwargs): + logged[name] = value + + model.log = _log + return model, logged + + +def _make_validator(confidence_prediction: bool = True) -> Validator: + """Create a Validator with a single 'RCSB' validation dataset.""" + return Validator( + val_names=["RCSB"], + confidence_prediction=confidence_prediction, + physicalism_metrics=False, + ) + + +def _populate_folding_metrics(validator: Validator) -> None: + """Feed known values into all folding MeanMetrics so compute() is non-NaN.""" + idx = 0 + for m_ in FOLDING_METRIC_KEYS: + validator.folding_metrics["lddt"][idx][m_].update(torch.tensor(0.5), torch.tensor(1.0)) + validator.folding_metrics["disto_lddt"][idx][m_].update(torch.tensor(0.4), torch.tensor(1.0)) + validator.folding_metrics["complex_lddt"][idx][m_].update(torch.tensor(0.6), torch.tensor(1.0)) + validator.folding_metrics["disto_loss"][idx]["disto_loss"].update(torch.tensor(0.1)) + + +def _populate_confidence_metrics(validator: Validator) -> None: + """Feed known values into all confidence MeanMetrics that exist. + + Populates plddt_mae (always initialized) and — if they have been + initialized — the confidence-ranked lDDT, pde_mae, and pae_mae metrics. + """ + idx = 0 + # plddt_mae — always initialized + for m in const.out_single_types: + validator.confidence_metrics["plddt_mae"][idx][m].update(torch.tensor(0.05), torch.tensor(1.0)) + + # Confidence-ranked lDDT metrics + for prefix in CONFIDENCE_PREFIXES: + label = f"{prefix}_lddt" + for key in PAIR_METRIC_KEYS: + if key in validator.confidence_metrics[label][idx]: + validator.confidence_metrics[label][idx][key].update(torch.tensor(0.45), torch.tensor(1.0)) + + # PDE MAE and PAE MAE + for mae_label in ["pde_mae", "pae_mae"]: + for key in PAIR_METRIC_KEYS: + if key in validator.confidence_metrics[mae_label][idx]: + validator.confidence_metrics[mae_label][idx][key].update(torch.tensor(0.15), torch.tensor(1.0)) + + +# --------------------------------------------------------------------------- +# Expected metric names +# --------------------------------------------------------------------------- + + +def _expected_folding_metric_names() -> set[str]: + """Return the set of metric names that folding metrics should produce.""" + names: set[str] = set() + for m_ in FOLDING_METRIC_KEYS: + names.add(f"val/lddt_{m_}") + names.add(f"val/disto_lddt_{m_}") + names.add(f"val/complex_lddt_{m_}") + names.add("val/disto_loss") + names.add("val/disto_lddt") + names.add("val/lddt") + names.add("val/complex_lddt") + return names + + +def _expected_plddt_mae_metric_names() -> set[str]: + """Return the set of metric names for plddt_mae.""" + return {f"val/MAE_plddt_{m}" for m in const.out_single_types} + + +def _expected_confidence_lddt_metric_names() -> set[str]: + """Return the set of metric names for confidence-ranked lDDT.""" + names: set[str] = set() + for prefix in CONFIDENCE_PREFIXES: + for key in PAIR_METRIC_KEYS: + names.add(f"val/{prefix}_lddt_{key}") + return names + + +def _expected_pde_mae_metric_names() -> set[str]: + """Return the set of metric names for PDE MAE.""" + return {f"val/MAE_pde_{key}" for key in PAIR_METRIC_KEYS} + + +def _expected_pae_mae_metric_names() -> set[str]: + """Return the set of metric names for PAE MAE.""" + return {f"val/MAE_pae_{key}" for key in PAIR_METRIC_KEYS} + + +def _all_expected_confidence_metric_names() -> set[str]: + """All confidence metrics that should be logged (union of plddt/pde/pae/lddt).""" + return ( + _expected_plddt_mae_metric_names() + | _expected_confidence_lddt_metric_names() + | _expected_pde_mae_metric_names() + | _expected_pae_mae_metric_names() + ) + + +# --------------------------------------------------------------------------- +# Tests — Folding metrics (should pass before fix) +# --------------------------------------------------------------------------- + + +def test_folding_metrics_logged_without_confidence(): + validator = _make_validator(confidence_prediction=False) + model, logged = _make_mock_model(confidence_prediction=False) + _populate_folding_metrics(validator) + + validator.common_on_epoch_end(model) + + expected = _expected_folding_metric_names() + missing = expected - set(logged.keys()) + assert not missing, f"Missing folding metrics: {sorted(missing)}" + + +def test_folding_metrics_have_correct_values(): + validator = _make_validator(confidence_prediction=False) + model, logged = _make_mock_model(confidence_prediction=False) + _populate_folding_metrics(validator) + + validator.common_on_epoch_end(model) + + assert logged["val/lddt_dna_protein"] == pytest.approx(0.5) + assert logged["val/disto_lddt_dna_protein"] == pytest.approx(0.4) + assert logged["val/complex_lddt_dna_protein"] == pytest.approx(0.6) + assert logged["val/disto_loss"] == pytest.approx(0.1, abs=1e-5) + + +def test_no_confidence_metrics_when_disabled(): + """When confidence_prediction=False, no confidence metrics should be logged.""" + validator = _make_validator(confidence_prediction=False) + model, logged = _make_mock_model(confidence_prediction=False) + _populate_folding_metrics(validator) + + validator.common_on_epoch_end(model) + + confidence_names = {"MAE_plddt", "MAE_pde", "MAE_pae", "top1_lddt", "avg_lddt"} + for name in logged: + for cn in confidence_names: + assert cn not in name, f"Unexpected confidence metric logged: {name}" + + +# --------------------------------------------------------------------------- +# Tests — Confidence metrics (main TDD tests — should FAIL before fix) +# --------------------------------------------------------------------------- + + +def test_plddt_mae_logged(): + """plddt_mae should already be logged (pre-existing behavior).""" + validator = _make_validator(confidence_prediction=True) + model, logged = _make_mock_model(confidence_prediction=True) + _populate_folding_metrics(validator) + _populate_confidence_metrics(validator) + + validator.common_on_epoch_end(model) + + expected = _expected_plddt_mae_metric_names() + missing = expected - set(logged.keys()) + assert not missing, f"Missing plddt_mae metrics: {sorted(missing)}" + + +def test_confidence_lddt_metrics_initialized(): + """Confidence-ranked lDDT MeanMetrics must exist in the ModuleDict.""" + validator = _make_validator(confidence_prediction=True) + + for prefix in CONFIDENCE_PREFIXES: + label = f"{prefix}_lddt" + for key in PAIR_METRIC_KEYS: + assert ( + key in validator.confidence_metrics[label][0] + ), f"MeanMetric not initialized for confidence_metrics['{label}'][0]['{key}']" + + +def test_pde_mae_metrics_initialized(): + """PDE MAE MeanMetrics must exist in the ModuleDict.""" + validator = _make_validator(confidence_prediction=True) + for key in PAIR_METRIC_KEYS: + assert ( + key in validator.confidence_metrics["pde_mae"][0] + ), f"MeanMetric not initialized for confidence_metrics['pde_mae'][0]['{key}']" + + +def test_pae_mae_metrics_initialized(): + """PAE MAE MeanMetrics must exist in the ModuleDict.""" + validator = _make_validator(confidence_prediction=True) + for key in PAIR_METRIC_KEYS: + assert ( + key in validator.confidence_metrics["pae_mae"][0] + ), f"MeanMetric not initialized for confidence_metrics['pae_mae'][0]['{key}']" + + +def test_confidence_lddt_metrics_logged(): + """All confidence-ranked lDDT metrics must appear in logged output.""" + validator = _make_validator(confidence_prediction=True) + model, logged = _make_mock_model(confidence_prediction=True) + _populate_folding_metrics(validator) + _populate_confidence_metrics(validator) + + validator.common_on_epoch_end(model) + + expected = _expected_confidence_lddt_metric_names() + missing = expected - set(logged.keys()) + assert not missing, f"Missing confidence lDDT metrics: {sorted(missing)}" + + +def test_pde_mae_logged(): + """PDE MAE metrics must appear in logged output.""" + validator = _make_validator(confidence_prediction=True) + model, logged = _make_mock_model(confidence_prediction=True) + _populate_folding_metrics(validator) + _populate_confidence_metrics(validator) + + validator.common_on_epoch_end(model) + + expected = _expected_pde_mae_metric_names() + missing = expected - set(logged.keys()) + assert not missing, f"Missing PDE MAE metrics: {sorted(missing)}" + + +def test_pae_mae_logged(): + """PAE MAE metrics must appear in logged output.""" + validator = _make_validator(confidence_prediction=True) + model, logged = _make_mock_model(confidence_prediction=True) + _populate_folding_metrics(validator) + _populate_confidence_metrics(validator) + + validator.common_on_epoch_end(model) + + expected = _expected_pae_mae_metric_names() + missing = expected - set(logged.keys()) + assert not missing, f"Missing PAE MAE metrics: {sorted(missing)}" + + +def test_all_confidence_metrics_logged(): + """Comprehensive check: every confidence metric must be logged.""" + validator = _make_validator(confidence_prediction=True) + model, logged = _make_mock_model(confidence_prediction=True) + _populate_folding_metrics(validator) + _populate_confidence_metrics(validator) + + validator.common_on_epoch_end(model) + + expected = _all_expected_confidence_metric_names() + missing = expected - set(logged.keys()) + assert not missing, f"Missing confidence metrics ({len(missing)}): {sorted(missing)}" + + +def test_no_metric_name_collisions(): + """All logged metric names must be unique (no overwrites).""" + validator = _make_validator(confidence_prediction=True) + logged_names: list[str] = [] + + def _log(name: str, value, **kwargs): + logged_names.append(name) + + model = MagicMock() + model.confidence_prediction = True + model.log = _log + + _populate_folding_metrics(validator) + _populate_confidence_metrics(validator) + validator.common_on_epoch_end(model) + + assert len(logged_names) == len( + set(logged_names) + ), f"Duplicate metric names: {[n for n in logged_names if logged_names.count(n) > 1]}" + + +# --------------------------------------------------------------------------- +# Tests — Metric reset lifecycle +# --------------------------------------------------------------------------- + + +def test_folding_metrics_reset_after_epoch_end(): + """After common_on_epoch_end, folding MeanMetrics should be reset.""" + validator = _make_validator(confidence_prediction=False) + model, _ = _make_mock_model(confidence_prediction=False) + _populate_folding_metrics(validator) + + validator.common_on_epoch_end(model) + + for m_ in FOLDING_METRIC_KEYS: + val = validator.folding_metrics["lddt"][0][m_].compute() + assert torch.isnan(val), f"lddt[{m_}] not reset: {val}" + + +def test_confidence_metrics_reset_after_epoch_end(): + """After common_on_epoch_end, confidence MeanMetrics should be reset.""" + validator = _make_validator(confidence_prediction=True) + model, _ = _make_mock_model(confidence_prediction=True) + _populate_folding_metrics(validator) + _populate_confidence_metrics(validator) + + validator.common_on_epoch_end(model) + + for m in const.out_single_types: + val = validator.confidence_metrics["plddt_mae"][0][m].compute() + assert torch.isnan(val), f"plddt_mae[{m}] not reset: {val}" + + for mae_label in ["pde_mae", "pae_mae"]: + for key in PAIR_METRIC_KEYS: + if key in validator.confidence_metrics[mae_label][0]: + val = validator.confidence_metrics[mae_label][0][key].compute() + assert torch.isnan(val), f"{mae_label}[{key}] not reset: {val}" + + for prefix in CONFIDENCE_PREFIXES: + label = f"{prefix}_lddt" + for key in PAIR_METRIC_KEYS: + if key in validator.confidence_metrics[label][0]: + val = validator.confidence_metrics[label][0][key].compute() + assert torch.isnan(val), f"{label}[{key}] not reset: {val}" + + +def test_two_epoch_cycle(): + """Simulate two validation epochs and verify independent accumulation. + + This tests the core lifecycle: + 1. Epoch 1: populate -> epoch_end -> log values -> reset + 2. Epoch 2: populate with different values -> epoch_end -> log new values + """ + validator = _make_validator(confidence_prediction=False) + + # --- Epoch 1 --- + model1, logged1 = _make_mock_model(confidence_prediction=False) + _populate_folding_metrics(validator) + validator.common_on_epoch_end(model1) + epoch1_lddt = logged1["val/lddt_dna_protein"] + + # --- Epoch 2: different values --- + model2, logged2 = _make_mock_model(confidence_prediction=False) + for m_ in FOLDING_METRIC_KEYS: + validator.folding_metrics["lddt"][0][m_].update(torch.tensor(0.9), torch.tensor(1.0)) + validator.folding_metrics["disto_lddt"][0][m_].update(torch.tensor(0.8), torch.tensor(1.0)) + validator.folding_metrics["complex_lddt"][0][m_].update(torch.tensor(0.7), torch.tensor(1.0)) + validator.folding_metrics["disto_loss"][0]["disto_loss"].update(torch.tensor(0.01)) + validator.common_on_epoch_end(model2) + epoch2_lddt = logged2["val/lddt_dna_protein"] + + assert epoch1_lddt == pytest.approx(0.5) + assert epoch2_lddt == pytest.approx(0.9) + assert epoch1_lddt != epoch2_lddt, "Epoch 2 should reflect new values, not stale ones" + + +def test_unpopulated_metrics_log_zero(): + """If no validation batches run (e.g. empty val set), metrics should log 0.0. + + This is relevant for resume scenarios where the validator is freshly + created but no val batches have run yet. + """ + validator = _make_validator(confidence_prediction=False) + model, logged = _make_mock_model(confidence_prediction=False) + + validator.common_on_epoch_end(model) + + for m_ in FOLDING_METRIC_KEYS: + assert logged[f"val/lddt_{m_}"] == 0.0, f"Unpopulated lddt_{m_} should be 0.0 (NaN->0.0 fallback)" + + +# --------------------------------------------------------------------------- +# Tests — Dataset name suffix +# --------------------------------------------------------------------------- + + +def test_rcsb_has_no_suffix(): + validator = _make_validator(confidence_prediction=False) + model, logged = _make_mock_model(confidence_prediction=False) + _populate_folding_metrics(validator) + validator.common_on_epoch_end(model) + + assert "val/lddt_dna_protein" in logged + + +def test_non_rcsb_has_suffix(): + validator = Validator( + val_names=["CUSTOM"], + confidence_prediction=False, + physicalism_metrics=False, + ) + model, logged = _make_mock_model(confidence_prediction=False) + _populate_folding_metrics(validator) + validator.common_on_epoch_end(model) + + assert "val/lddt_dna_protein__CUSTOM" in logged + assert "val/lddt_dna_protein" not in logged diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 000000000..b1ddbb2da --- /dev/null +++ b/tests/scripts/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. diff --git a/tests/scripts/test_cluster.py b/tests/scripts/test_cluster.py new file mode 100644 index 000000000..b4be5710f --- /dev/null +++ b/tests/scripts/test_cluster.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Tests for scripts/process/cluster.py OS_CMD_INJECTION fix. + +Verifies that main() uses subprocess argument lists (no shell=True) and that +the full clustering pipeline (FASTA parsing, mmseqs call, short/nucleotide/ +ligand handling, JSON output) is preserved. +""" + +import argparse +import hashlib +import json +import pickle +import sys +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts" / "process")) +import cluster as cluster_module # noqa: E402 + + +def _hash(seq: str) -> str: + return hashlib.sha256(seq.encode()).hexdigest() + + +def _write_fasta(path: Path, sequences: list[tuple[str, str]]) -> None: + """Write a minimal FASTA file. Each item is (header, sequence).""" + lines = [] + for header, seq in sequences: + lines.append(f">{header}") + lines.append(seq) + path.write_text("\n".join(lines)) + + +def _write_cluster_tsv(path: Path, mapping: list[tuple[str, str]]) -> None: + """Write a fake mmseqs clust_prot_cluster.tsv (tab-separated, no header).""" + lines = [f"{rep}\t{member}" for rep, member in mapping] + path.write_text("\n".join(lines)) + + +def _write_ccd_pickle(path: Path, ligand_codes: list[str]) -> None: + """Write a minimal CCD pickle (a dict keyed by ligand code).""" + data = dict.fromkeys(ligand_codes) + with path.open("wb") as f: + pickle.dump(data, f) + + +class TestMainCallsMmseqs: + """Verify subprocess.run invocation and full pipeline output.""" + + def test_calls_mmseqs_with_correct_args(self, tmp_path): + seq_a = "MKTAYIAKQRQISFVKSHFSRQ" + seq_b = "MLLSALVLLLSESGLSGAGGL" + fasta_path = tmp_path / "input.fasta" + _write_fasta(fasta_path, [("seqA", seq_a), ("seqB", seq_b)]) + + ccd_path = tmp_path / "ccd.pkl" + _write_ccd_pickle(ccd_path, ["ATP", "NAG"]) + + outdir = tmp_path / "out" + + hash_a = _hash(seq_a) + hash_b = _hash(seq_b) + + args = argparse.Namespace( + sequences=str(fasta_path), + ccd=str(ccd_path), + outdir=str(outdir), + mmseqs="/usr/bin/mmseqs", + ) + + def _fake_run(cmd, **kwargs): + outdir.mkdir(parents=True, exist_ok=True) + _write_cluster_tsv( + outdir / "clust_prot_cluster.tsv", + [(hash_a, hash_a), (hash_a, hash_b)], + ) + + with patch("cluster.subprocess.run", side_effect=_fake_run) as mock_run: + cluster_module.main(args) + + mock_run.assert_called_once() + call_args, call_kwargs = mock_run.call_args + + cmd = call_args[0] + assert isinstance(cmd, list), f"Expected list, got {type(cmd)}" + assert cmd[0] == "/usr/bin/mmseqs" + assert cmd[1] == "easy-cluster" + assert cmd[-2:] == ["--min-seq-id", "0.4"] + assert "shell" not in call_kwargs, "shell=True must not be passed" + assert call_kwargs.get("check") is True + + clustering_file = outdir / "clustering.json" + assert clustering_file.exists() + with clustering_file.open() as f: + clustering = json.load(f) + + assert hash_a in clustering + assert hash_b in clustering + assert clustering[hash_a] == hash_a + assert clustering[hash_b] == hash_a + assert "ATP" in clustering + assert "NAG" in clustering + assert clustering["ATP"] == "ATP" + + def test_handles_short_and_nucleotide_sequences(self, tmp_path): + protein = "MKTAYIAKQRQISFVKSHFSRQ" + short = "MKTAY" + nucleotide = "ACGUACGU" + + fasta_path = tmp_path / "input.fasta" + _write_fasta( + fasta_path, + [ + ("prot1", protein), + ("short1", short), + ("nucl1", nucleotide), + ], + ) + + ccd_path = tmp_path / "ccd.pkl" + _write_ccd_pickle(ccd_path, []) + + outdir = tmp_path / "out" + hash_prot = _hash(protein) + + args = argparse.Namespace( + sequences=str(fasta_path), + ccd=str(ccd_path), + outdir=str(outdir), + mmseqs="mmseqs", + ) + + def _fake_run(cmd, **kwargs): + outdir.mkdir(parents=True, exist_ok=True) + _write_cluster_tsv( + outdir / "clust_prot_cluster.tsv", + [(hash_prot, hash_prot)], + ) + + with patch("cluster.subprocess.run", side_effect=_fake_run) as mock_run: + cluster_module.main(args) + + mock_run.assert_called_once() + + with (outdir / "clustering.json").open() as f: + clustering = json.load(f) + + hash_short = _hash(short) + hash_nucl = _hash(nucleotide) + assert clustering[hash_short] == hash_short, "Short sequence should get self-referential ID" + assert clustering[hash_nucl] == hash_nucl, "Nucleotide sequence should get self-referential ID" + assert hash_prot in clustering, "Protein should be in clustering" diff --git a/tests/scripts/test_run_evals.py b/tests/scripts/test_run_evals.py new file mode 100644 index 000000000..d154da4fe --- /dev/null +++ b/tests/scripts/test_run_evals.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + +"""Tests for scripts/eval/run_evals.py OS_CMD_INJECTION fix. + +Verifies that evaluate_structure uses subprocess argument lists (no shell=True) +and preserves the original docker command semantics. +""" + +import sys +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts" / "eval")) +import run_evals # noqa: E402 + + +class TestEvaluateStructureCalls: + """Verify subprocess invocations produced by evaluate_structure.""" + + def test_calls_docker_for_structure_and_ligand(self, tmp_path): + outdir = tmp_path / "results" + outdir.mkdir() + + with ( + patch("run_evals.subprocess.run") as mock_run, + patch("run_evals.os.getuid", return_value=1000), + patch("run_evals.os.getgid", return_value=2000), + ): + run_evals.evaluate_structure( + name="test_sample", + pred="/data/pred.cif", + reference="/data/ref.cif", + outdir=str(outdir), + mount="/data", + ) + + assert mock_run.call_count == 2, f"Expected 2 subprocess.run calls, got {mock_run.call_count}" + + # --- Structure comparison call --- + struct_args, struct_kwargs = mock_run.call_args_list[0] + struct_cmd = struct_args[0] + + assert "shell" not in struct_kwargs, "shell=True must not be passed" + assert struct_kwargs.get("check") is False + assert struct_kwargs.get("capture_output") is True + + assert struct_cmd[:3] == ["sudo", "docker", "run"] + assert "-u" in struct_cmd + u_idx = struct_cmd.index("-u") + assert struct_cmd[u_idx + 1] == "1000:2000" + assert "--rm" in struct_cmd + assert "--volume" in struct_cmd + vol_idx = struct_cmd.index("--volume") + assert struct_cmd[vol_idx + 1] == "/data:/data" + assert run_evals.IMAGE_NAME in struct_cmd + assert "compare-structures" in struct_cmd + + for flag in [ + "--lddt", + "--bb-lddt", + "--qs-score", + "--dockq", + "--ics", + "--ips", + "--rigid-scores", + "--patch-scores", + "--tm-score", + "--fault-tolerant", + ]: + assert flag in struct_cmd, f"Missing flag {flag} in structure command" + assert "-m" in struct_cmd + m_idx = struct_cmd.index("-m") + assert struct_cmd[m_idx + 1] == "/data/pred.cif" + assert "-r" in struct_cmd + r_idx = struct_cmd.index("-r") + assert struct_cmd[r_idx + 1] == "/data/ref.cif" + pep_idx = struct_cmd.index("--min-pep-length") + assert struct_cmd[pep_idx + 1] == "4" + nuc_idx = struct_cmd.index("--min-nuc-length") + assert struct_cmd[nuc_idx + 1] == "4" + + # --- Ligand comparison call --- + ligand_args, ligand_kwargs = mock_run.call_args_list[1] + ligand_cmd = ligand_args[0] + + assert "shell" not in ligand_kwargs, "shell=True must not be passed" + assert "compare-ligand-structures" in ligand_cmd + + for flag in ["--lddt-pli", "--rmsd", "--substructure-match", "--fault-tolerant"]: + assert flag in ligand_cmd, f"Missing flag {flag} in ligand command" + + expected_ligand_out = str(outdir / "test_sample_ligand.json") + o_idx = ligand_cmd.index("-o") + assert ligand_cmd[o_idx + 1] == expected_ligand_out + + def test_skips_existing_outputs(self, tmp_path): + outdir = tmp_path / "results" + outdir.mkdir() + (outdir / "test_sample.json").touch() + (outdir / "test_sample_ligand.json").touch() + + with patch("run_evals.subprocess.run") as mock_run: + run_evals.evaluate_structure( + name="test_sample", + pred="/data/pred.cif", + reference="/data/ref.cif", + outdir=str(outdir), + mount="/data", + ) + + mock_run.assert_not_called() + + def test_skips_only_structure_when_structure_exists(self, tmp_path): + outdir = tmp_path / "results" + outdir.mkdir() + (outdir / "test_sample.json").touch() + + with ( + patch("run_evals.subprocess.run") as mock_run, + patch("run_evals.os.getuid", return_value=0), + patch("run_evals.os.getgid", return_value=0), + ): + run_evals.evaluate_structure( + name="test_sample", + pred="/data/pred.cif", + reference="/data/ref.cif", + outdir=str(outdir), + mount="/data", + ) + + assert mock_run.call_count == 1 + cmd = mock_run.call_args_list[0][0][0] + assert "compare-ligand-structures" in cmd + + def test_uid_gid_values(self, tmp_path): + outdir = tmp_path / "results" + outdir.mkdir() + + with ( + patch("run_evals.subprocess.run") as mock_run, + patch("run_evals.os.getuid", return_value=12345), + patch("run_evals.os.getgid", return_value=67890), + ): + run_evals.evaluate_structure( + name="sample", + pred="/mnt/pred.cif", + reference="/mnt/ref.cif", + outdir=str(outdir), + mount="/mnt", + ) + + for c in mock_run.call_args_list: + cmd = c[0][0] + u_idx = cmd.index("-u") + assert cmd[u_idx + 1] == "12345:67890" diff --git a/tests/workflow/test_workflow_utils.py b/tests/workflow/test_workflow_utils.py new file mode 100644 index 000000000..679ca42cd --- /dev/null +++ b/tests/workflow/test_workflow_utils.py @@ -0,0 +1,708 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +"""Unit tests for workflow utilities.""" + +import pickle + +import pytest +import torch +import torch.nn as nn +from omegaconf import DictConfig, ListConfig, OmegaConf +from pytorch_lightning import LightningModule, Trainer +from torch.utils.data import DataLoader, TensorDataset + +from boltz.workflow.utils import CUDAMemoryProfile, convert_datasets_dict_to_list_config + +# ========== Fixtures ========== + + +@pytest.fixture +def simple_base_config(): + """Create a simple base ListConfig with one item.""" + return OmegaConf.create([{"target_dir": "/path1", "prob": 0.5}]) + + +@pytest.fixture +def multi_item_base_config(): + """Create a base ListConfig with multiple items.""" + return OmegaConf.create( + [ + {"target_dir": "/path1", "prob": 0.5, "msa_dir": "/msa1"}, + {"target_dir": "/path2", "prob": 0.3, "msa_dir": "/msa2"}, + {"target_dir": "/path3", "prob": 0.2, "msa_dir": "/msa3"}, + ] + ) + + +@pytest.fixture +def keys_set_simple(): + """Simple set of valid keys.""" + return {"target_dir", "prob"} + + +@pytest.fixture +def keys_set_extended(): + """Extended set of valid keys.""" + return {"target_dir", "prob", "msa_dir"} + + +# ========== Happy Path Tests ========== + + +def test_single_item_single_key_override(simple_base_config, keys_set_simple): + """Test overriding a single key in a single-item ListConfig.""" + override = OmegaConf.create({"0": {"target_dir": "/new/path"}}) + + result = convert_datasets_dict_to_list_config(simple_base_config, override, keys_set_simple) + + assert isinstance(result, ListConfig) + assert result[0].target_dir == override["0"]["target_dir"] + assert result[0].prob == simple_base_config[0].prob # Unchanged + + # Verify original is not modified + assert simple_base_config[0].target_dir == "/path1" + + +def test_single_item_multiple_keys_override(simple_base_config, keys_set_simple): + """Test overriding multiple keys in a single-item ListConfig.""" + override = OmegaConf.create({"0": {"target_dir": "/new/path", "prob": 0.9}}) + + result = convert_datasets_dict_to_list_config(simple_base_config, override, keys_set_simple) + + assert result[0].target_dir == override["0"]["target_dir"] + assert result[0].prob == override["0"]["prob"] + + +def test_multiple_items_single_override(multi_item_base_config, keys_set_extended): + """Test overriding a single item in a multi-item ListConfig.""" + override = OmegaConf.create({"1": {"target_dir": "/new/path"}}) + + result = convert_datasets_dict_to_list_config(multi_item_base_config, override, keys_set_extended) + + # First item unchanged + assert result[0].target_dir == multi_item_base_config[0].target_dir + assert result[0].prob == multi_item_base_config[0].prob + assert result[0].msa_dir == multi_item_base_config[0].msa_dir + + # Second item modified + assert result[1].target_dir == override["1"]["target_dir"] + assert result[1].prob == multi_item_base_config[1].prob # Unchanged + assert result[1].msa_dir == multi_item_base_config[1].msa_dir # Unchanged + + # Third item unchanged + assert result[2].target_dir == multi_item_base_config[2].target_dir + + +def test_multiple_items_multiple_overrides(multi_item_base_config, keys_set_extended): + """Test overriding multiple items in a multi-item ListConfig.""" + override = OmegaConf.create( + { + "0": {"target_dir": "/new/path0", "prob": 0.8}, + "2": {"msa_dir": "/new/msa2"}, + } + ) + + result = convert_datasets_dict_to_list_config(multi_item_base_config, override, keys_set_extended) + + # First item modified + assert result[0].target_dir == override["0"]["target_dir"] + assert result[0].prob == override["0"]["prob"] + assert result[0].msa_dir == multi_item_base_config[0].msa_dir # Unchanged + + # Second item unchanged + assert result[1].target_dir == multi_item_base_config[1].target_dir + assert result[1].prob == multi_item_base_config[1].prob + + # Third item modified + assert result[2].target_dir == multi_item_base_config[2].target_dir # Unchanged + assert result[2].msa_dir == override["2"]["msa_dir"] + + +def test_all_items_overridden(multi_item_base_config, keys_set_extended): + """Test overriding all items in a multi-item ListConfig.""" + override = OmegaConf.create( + { + "0": {"target_dir": "/new0"}, + "1": {"target_dir": "/new1"}, + "2": {"target_dir": "/new2"}, + } + ) + + result = convert_datasets_dict_to_list_config(multi_item_base_config, override, keys_set_extended) + + assert result[0].target_dir == override["0"]["target_dir"] + assert result[1].target_dir == override["1"]["target_dir"] + assert result[2].target_dir == override["2"]["target_dir"] + + +def test_deep_copy_behavior(simple_base_config, keys_set_simple): + """Test that the function returns a deep copy and doesn't mutate the original.""" + original_target_dir = simple_base_config[0].target_dir + override = OmegaConf.create({"0": {"target_dir": "/new/path"}}) + + result = convert_datasets_dict_to_list_config(simple_base_config, override, keys_set_simple) + + # Modify the result + modified_value = "/another/path" + result[0].target_dir = modified_value + + # Original should be unchanged + assert simple_base_config[0].target_dir == original_target_dir + assert result[0].target_dir == modified_value + + +@pytest.mark.parametrize( + "field_name,field_value,expected", + [ + ("str_field", "new_value", "new_value"), + ("int_field", 100, 100), + ("float_field", 2.71, 2.71), + ("bool_field", False, False), + ], +) +def test_override_with_different_value_types(field_name, field_value, expected): + """Test overriding with different value types (string, int, float, bool).""" + base = OmegaConf.create( + [ + { + "str_field": "value", + "int_field": 42, + "float_field": 3.14, + "bool_field": True, + } + ] + ) + keys = {"str_field", "int_field", "float_field", "bool_field"} + override = OmegaConf.create({"0": {field_name: field_value}}) + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert result[0][field_name] == override["0"][field_name] + assert result[0][field_name] == expected + + +def test_partial_key_override(multi_item_base_config, keys_set_extended): + """Test overriding only some keys, leaving others unchanged.""" + override = OmegaConf.create({"1": {"prob": 0.99}}) + + result = convert_datasets_dict_to_list_config(multi_item_base_config, override, keys_set_extended) + + assert result[1].target_dir == multi_item_base_config[1].target_dir # Unchanged + assert result[1].prob == override["1"]["prob"] # Changed + assert result[1].msa_dir == multi_item_base_config[1].msa_dir # Unchanged + + +# ========== Error Cases: Invalid base ========== + + +@pytest.mark.parametrize( + "base_input,error_match", + [ + (OmegaConf.create({"0": {"target_dir": "/path"}}), "base must be a ListConfig"), # DictConfig + ([{"target_dir": "/path"}], "base must be a ListConfig"), # Plain list + ], +) +def test_base_not_listconfig_raises(base_input, error_match): + """Test that passing non-ListConfig as base raises ValueError.""" + override = OmegaConf.create({"0": {"target_dir": "/new"}}) + keys = {"target_dir"} + + with pytest.raises(ValueError, match=error_match): + convert_datasets_dict_to_list_config(base_input, override, keys) + + +# ========== Error Cases: Invalid override ========== + + +@pytest.mark.parametrize( + "override_input,error_match", + [ + (OmegaConf.create([{"target_dir": "/new"}]), "override must be a DictConfig"), # ListConfig + ({"0": {"target_dir": "/new"}}, "override must be a DictConfig"), # Plain dict + ], +) +def test_override_not_dictconfig_raises(simple_base_config, keys_set_simple, override_input, error_match): + """Test that passing non-DictConfig as override raises ValueError.""" + with pytest.raises(ValueError, match=error_match): + convert_datasets_dict_to_list_config(simple_base_config, override_input, keys_set_simple) + + +def test_empty_override_raises(simple_base_config, keys_set_simple): + """Test that empty override DictConfig raises ValueError.""" + override = OmegaConf.create({}) + + with pytest.raises(ValueError, match="Input DictConfig override is empty"): + convert_datasets_dict_to_list_config(simple_base_config, override, keys_set_simple) + + +@pytest.mark.parametrize( + "invalid_key,base_len,error_match", + [ + ("3", 3, "Invalid keys in override"), # Out of range + ("-1", 1, "Invalid keys in override"), # Negative + ("abc", 1, "Invalid keys in override"), # Non-integer + ], +) +def test_override_with_invalid_index_raises(invalid_key, base_len, error_match): + """Test that override with invalid index raises ValueError.""" + base = OmegaConf.create([{"field": f"value{i}"} for i in range(base_len)]) + override = OmegaConf.create({invalid_key: {"field": "/new"}}) + keys = {"field"} + + with pytest.raises(ValueError, match=error_match): + convert_datasets_dict_to_list_config(base, override, keys) + + +def test_override_with_invalid_nested_keys_raises(simple_base_config, keys_set_simple): + """Test that override with keys not in keys_to_override raises ValueError.""" + override = OmegaConf.create({"0": {"target_dir": "/new", "invalid_key": "value"}}) + + with pytest.raises(ValueError, match="Invalid keys in override of item 0"): + convert_datasets_dict_to_list_config(simple_base_config, override, keys_set_simple) + + +def test_override_with_only_invalid_keys_raises(simple_base_config, keys_set_simple): + """Test that override with only invalid keys raises ValueError.""" + override = OmegaConf.create({"0": {"invalid_key": "value"}}) + + with pytest.raises(ValueError, match="Invalid keys in override of item 0"): + convert_datasets_dict_to_list_config(simple_base_config, override, keys_set_simple) + + +# ========== Null dataset removal ========== + + +def test_remove_null_dataset_single(multi_item_base_config, keys_set_extended): + """Base has 2 items; override sets item 1 to null with remove_null_datasets=True; result has 1 item.""" + base = OmegaConf.create( + [ + {"target_dir": "/path1", "prob": 0.5, "msa_dir": "/msa1"}, + {"target_dir": "/path2", "prob": 0.5, "msa_dir": "/msa2"}, + ] + ) + override = OmegaConf.create({"1": None}) + + result = convert_datasets_dict_to_list_config(base, override, keys_set_extended, remove_null_datasets=True) + + assert len(result) == 1 + assert result[0].target_dir == "/path1" + assert result[0].prob == 0.5 + assert result[0].msa_dir == "/msa1" + + +def test_remove_null_dataset_multiple(multi_item_base_config, keys_set_extended): + """Base has 3 items; override nullifies items 0 and 2; result has 1 item (former index 1).""" + override = OmegaConf.create({"0": None, "2": None}) + + result = convert_datasets_dict_to_list_config( + multi_item_base_config, override, keys_set_extended, remove_null_datasets=True + ) + + assert len(result) == 1 + assert result[0].target_dir == "/path2" + assert result[0].prob == 0.3 + assert result[0].msa_dir == "/msa2" + + +def test_null_override_without_flag_raises(multi_item_base_config, keys_set_extended): + """Override sets an item to null with remove_null_datasets=False (default); raises ValueError.""" + override = OmegaConf.create({"1": None}) + + with pytest.raises(ValueError, match="Override for item 1 is null but remove_null_datasets is False"): + convert_datasets_dict_to_list_config(multi_item_base_config, override, keys_set_extended) + + +def test_remove_null_with_partial_override(multi_item_base_config, keys_set_extended): + """Override nullifies item 1 and partially overrides item 0; result has one item with overrides applied.""" + override = OmegaConf.create({"0": {"target_dir": "/new/path0", "prob": 0.8}, "1": None}) + + result = convert_datasets_dict_to_list_config( + multi_item_base_config, override, keys_set_extended, remove_null_datasets=True + ) + + assert len(result) == 2 # item 0 (merged) and item 2 (unchanged) + assert result[0].target_dir == "/new/path0" + assert result[0].prob == 0.8 + assert result[0].msa_dir == "/msa1" + assert result[1].target_dir == "/path3" + assert result[1].prob == 0.2 + assert result[1].msa_dir == "/msa3" + + +# ========== Edge Cases ========== + + +def test_single_item_listconfig(): + """Test with a single-item ListConfig.""" + base = OmegaConf.create([{"field": "value"}]) + override = OmegaConf.create({"0": {"field": "new_value"}}) + keys = {"field"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert len(result) == len(base) + assert result[0].field == override["0"]["field"] + + +def test_large_listconfig(): + """Test with a large ListConfig (10 items).""" + base_size = 10 + base = OmegaConf.create([{"field": f"value{i}"} for i in range(base_size)]) + override = OmegaConf.create( + { + "0": {"field": "new0"}, + "5": {"field": "new5"}, + "9": {"field": "new9"}, + } + ) + keys = {"field"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert len(result) == base_size + assert result[0].field == override["0"]["field"] + assert result[1].field == base[1].field # Unchanged + assert result[5].field == override["5"]["field"] + assert result[9].field == override["9"]["field"] + + +@pytest.mark.parametrize( + "override_value,expected", + [ + (None, None), + ("", ""), + (0, 0), + ], +) +def test_override_with_special_values(override_value, expected): + """Test overriding with special values (None, empty string, zero).""" + base = OmegaConf.create([{"field": "value"}]) + override = OmegaConf.create({"0": {"field": override_value}}) + keys = {"field"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert result[0].field == override["0"]["field"] + assert result[0].field == expected + + +def test_empty_keys_set(): + """Test with empty keys_to_override set.""" + base = OmegaConf.create([{}]) + override = OmegaConf.create({"0": {}}) + keys = set() + + result = convert_datasets_dict_to_list_config(base, override, keys) + + # Should work since both base and override are empty + assert len(result) == len(base) + + +def test_nested_dictconfig_values(): + """Test with nested DictConfig as values.""" + base = OmegaConf.create([{"config": {"nested": "value"}}]) + override = OmegaConf.create({"0": {"config": {"nested": "new_value"}}}) + keys = {"config"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert result[0].config.nested == override["0"]["config"]["nested"] + + +def test_list_values_in_config(): + """Test with list values in the config.""" + base = OmegaConf.create([{"data_list": [1, 2, 3]}]) + override = OmegaConf.create({"0": {"data_list": [4, 5, 6]}}) + keys = {"data_list"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert result[0].data_list == override["0"]["data_list"] + + +# ========== Relaxed validation: base may have extra keys ========== + + +def test_base_with_extra_keys_allowed(): + """Base items may have keys beyond keys_to_override; override works and extra keys preserved.""" + base = OmegaConf.create([{"target_dir": "/path1", "prob": 0.5, "extra_key": "extra_value", "another_extra": 42}]) + override = OmegaConf.create({"0": {"target_dir": "/new/path"}}) + keys = {"target_dir", "prob"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert result[0].target_dir == "/new/path" + assert result[0].prob == 0.5 + assert result[0].extra_key == "extra_value" + assert result[0].another_extra == 42 + + +def test_base_items_with_different_keys(): + """Base items may have heterogeneous key sets; override of shared keys works.""" + # Mimics real YAML: dataset 0 has split/val_group, dataset 1 has override_method/override_bfactor + base = OmegaConf.create( + [ + {"target_dir": "/path1", "prob": 0.5, "split": "train", "val_group": "RCSB"}, + {"target_dir": "/path2", "prob": 0.5, "override_method": "AFDB", "override_bfactor": True}, + ] + ) + keys = {"target_dir", "prob", "split", "val_group", "override_method", "override_bfactor"} + override = OmegaConf.create({"0": {"target_dir": "/new0"}, "1": {"target_dir": "/new1"}}) + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert result[0].target_dir == "/new0" + assert result[0].split == "train" + assert result[1].target_dir == "/new1" + assert result[1].override_method == "AFDB" + + +def test_override_key_not_in_base_item_raises(): + """Override key in keys_to_override but absent from that base item raises ValueError.""" + base = OmegaConf.create( + [ + {"target_dir": "/path1", "prob": 0.5, "split": "train"}, + {"target_dir": "/path2", "prob": 0.5}, # no "split" key + ] + ) + keys = {"target_dir", "prob", "split"} + override = OmegaConf.create({"1": {"split": "val"}}) # override split for item 1 which has no split + + with pytest.raises(ValueError, match="contain keys not present in base item"): + convert_datasets_dict_to_list_config(base, override, keys) + + +# ========== Documentation Example Test ========== + + +def test_docstring_example(): + """Test the example from the function's docstring.""" + base = OmegaConf.create([{"target_dir": "/path1", "prob": 0.5}]) + override = OmegaConf.create({"0": {"target_dir": "/new/path"}}) + keys = {"target_dir", "prob"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + assert result[0].target_dir == override["0"]["target_dir"] + assert result[0].prob == base[0].prob + + +# ========== Real-world Scenario Tests ========== + + +def test_realistic_dataset_config(): + """Test with a realistic dataset configuration scenario.""" + base = OmegaConf.create( + [ + { + "_target_": "DatasetA", + "target_dir": "/data/train", + "msa_dir": "/msa/train", + "prob": 0.7, + "sampler": "uniform", + "cropper": "random", + "split": "train", + }, + { + "_target_": "DatasetB", + "target_dir": "/data/val", + "msa_dir": "/msa/val", + "prob": 0.3, + "sampler": "weighted", + "cropper": "center", + "split": "val", + }, + ] + ) + + # Command-line override: Change first dataset's directory and probability + override = OmegaConf.create( + { + "0": { + "target_dir": "/new/train/path", + "prob": 0.9, + } + } + ) + + keys = {"_target_", "target_dir", "msa_dir", "prob", "sampler", "cropper", "split"} + + result = convert_datasets_dict_to_list_config(base, override, keys) + + # First dataset modified + assert result[0]._target_ == base[0]._target_ # Unchanged + assert result[0].target_dir == override["0"]["target_dir"] # Changed + assert result[0].prob == override["0"]["prob"] # Changed + assert result[0].sampler == base[0].sampler # Unchanged + + # Second dataset unchanged + assert result[1].target_dir == base[1].target_dir + assert result[1].prob == base[1].prob + + +def test_string_index_consistency(multi_item_base_config, keys_set_extended): + """Test that string indices work correctly and consistently.""" + override = OmegaConf.create( + { + "0": {"target_dir": "/new0"}, + "1": {"target_dir": "/new1"}, + } + ) + + result = convert_datasets_dict_to_list_config(multi_item_base_config, override, keys_set_extended) + + assert result[0].target_dir == override["0"]["target_dir"] + assert result[1].target_dir == override["1"]["target_dir"] + assert result[2].target_dir == multi_item_base_config[2].target_dir # Unchanged + + +def test_preserves_omegaconf_metadata(simple_base_config, keys_set_simple): + """Test that OmegaConf metadata is preserved.""" + override = OmegaConf.create({"0": {"target_dir": "/new"}}) + + result = convert_datasets_dict_to_list_config(simple_base_config, override, keys_set_simple) + + assert isinstance(result, ListConfig) + assert isinstance(result[0], DictConfig) + assert OmegaConf.is_config(result) + assert OmegaConf.is_list(result) + + +# ========== CUDAMemoryProfile Tests ========== + + +class SimpleLightningModule(LightningModule): + """Simple Lightning module for testing CUDAMemoryProfile.""" + + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(100, 50) + self.layer2 = nn.Linear(50, 10) + + def forward(self, x): + x = torch.relu(self.layer1(x)) + return self.layer2(x) + + def training_step(self, batch, batch_idx): + x, y = batch + y_pred = self(x) + return nn.functional.mse_loss(y_pred, y) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_memory_profile_callback(tmp_path): + """Test CUDAMemoryProfile callback creates memory snapshot file.""" + output_file = tmp_path / "memory_snapshot.pickle" + memory_profiler = CUDAMemoryProfile(output_path=output_file, max_entries=100000) + + model = SimpleLightningModule() + x_data = torch.randn(32, 100) + y_data = torch.randn(32, 10) + dataset = TensorDataset(x_data, y_data) + dataloader = DataLoader(dataset, batch_size=8) + + trainer = Trainer( + max_epochs=1, + accelerator="gpu", + devices=1, + callbacks=[memory_profiler], + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + trainer.fit(model, dataloader) + + assert output_file.exists(), f"Memory snapshot file should be created at {output_file}" + + with open(output_file, "rb") as f: + snapshot_data = pickle.load(f) + + assert isinstance(snapshot_data, dict), "Snapshot should be a dictionary" + assert "segments" in snapshot_data or "device_traces" in snapshot_data, "Snapshot should contain memory data" + + if "device_traces" in snapshot_data and len(snapshot_data["device_traces"]) > 0: + device_0_traces = snapshot_data["device_traces"][0] + assert isinstance(device_0_traces, list), "Device traces should be a list" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_memory_profile_with_kwargs(tmp_path): + """Test CUDAMemoryProfile callback forwards kwargs correctly.""" + output_file = tmp_path / "memory_snapshot_kwargs.pickle" + max_entries = 50000 + memory_profiler = CUDAMemoryProfile(output_path=output_file, max_entries=max_entries) + + assert memory_profiler._kwargs == {"max_entries": max_entries} + assert memory_profiler._output_path == output_file + + model = SimpleLightningModule() + x_data = torch.randn(16, 100) + y_data = torch.randn(16, 10) + dataset = TensorDataset(x_data, y_data) + dataloader = DataLoader(dataset, batch_size=8) + + trainer = Trainer( + max_epochs=1, + accelerator="gpu", + devices=1, + callbacks=[memory_profiler], + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + trainer.fit(model, dataloader) + + assert output_file.exists(), "Memory snapshot file should be created" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_memory_profile_creates_parent_dirs(tmp_path): + """Test that CUDAMemoryProfile creates parent directories if they don't exist.""" + output_file = tmp_path / "nested" / "dirs" / "memory_snapshot.pickle" + assert not output_file.parent.exists() + + memory_profiler = CUDAMemoryProfile(output_path=output_file) + + assert output_file.parent.exists(), "Parent directories should be created" + assert output_file.parent.is_dir(), "Parent path should be a directory" + + model = SimpleLightningModule() + x_data = torch.randn(8, 100) + y_data = torch.randn(8, 10) + dataset = TensorDataset(x_data, y_data) + dataloader = DataLoader(dataset, batch_size=4) + + trainer = Trainer( + max_epochs=1, + accelerator="gpu", + devices=1, + callbacks=[memory_profiler], + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + trainer.fit(model, dataloader) + + assert output_file.exists(), "Memory snapshot should be created in nested directory"