class LearnableSchedule(nn.Module):
    """Per-modality learnable time schedule for flow matching.

    The ``nsteps + 1`` time points are the cumulative sum of a softmax over
    ``nsteps`` step-width logits, prefixed with 0. Softmax widths are
    non-negative and sum to one, so the schedule starts at 0, never decreases
    and ends at 1. It is strictly increasing unless a step width underflows
    to zero for extreme logits. Zero-initialised logits give equal widths,
    i.e. the uniform (linear) schedule.

    Usage — add to config::

        product_flowmatcher:
          bb_ca:
            learnable_schedule_nsteps: 400  # matches nsteps at inference

    Args:
        nsteps: Number of integration steps (must be >= 1). The schedule has
            ``nsteps + 1`` points.

    Raises:
        ValueError: If ``nsteps`` is smaller than 1.
    """

    def __init__(self, nsteps: int) -> None:
        super().__init__()
        if nsteps < 1:
            # softmax over zero logits is empty and the "schedule" would be
            # the single point [0]; fail loudly instead.
            raise ValueError(f"nsteps must be >= 1, got {nsteps}")
        self.nsteps = nsteps
        # All logits equal -> uniform softmax -> linear schedule.
        self.logits = nn.Parameter(torch.zeros(nsteps))

    def get_ts(self) -> Tensor:
        """Return the current schedule as a ``[nsteps + 1]`` tensor in [0, 1].

        The result is differentiable with respect to ``self.logits``.
        ``ts[0]`` is exactly 0 and ``ts[-1]`` is exactly 1.
        """
        deltas = torch.softmax(self.logits, dim=0)
        ts = torch.cat([deltas.new_zeros(1), torch.cumsum(deltas, dim=0)])
        # The floating-point cumulative sum generally ends a few ulps away
        # from 1. Dividing by the last element pins the endpoint to exactly
        # 1.0 (x / x == 1 for finite non-zero x) while keeping the gradient.
        return ts / ts[-1]
+ ls = self.learnable_schedules[data_mode] + if ls.nsteps == int(nsteps): + ts[data_mode] = ls.get_ts().detach() + else: + # Resample to requested resolution via linear interpolation + raw = ls.get_ts().detach() + src_idx = torch.linspace(0, 1, len(raw), device=raw.device) + tgt_idx = torch.linspace(0, 1, int(nsteps) + 1, device=raw.device) + ts[data_mode] = torch.from_numpy( + __import__("numpy").interp(tgt_idx.cpu().numpy(), src_idx.cpu().numpy(), raw.cpu().numpy()) + ).to(raw.dtype) + elif sampling_model_args[data_mode]["simulation_step_params"]["sampling_mode"] == "vf_tsr": + ts[data_mode] = get_schedule_tsr_safe( + mode=sampling_model_args[data_mode]["schedule"]["mode"], + nsteps=int(nsteps), + p1=sampling_model_args[data_mode]["schedule"]["p"], + ) else: - schedule_func = get_schedule - ts[data_mode] = schedule_func( - mode=sampling_model_args[data_mode]["schedule"]["mode"], - nsteps=int(nsteps), - p1=sampling_model_args[data_mode]["schedule"]["p"], - ) + ts[data_mode] = get_schedule( + mode=sampling_model_args[data_mode]["schedule"]["mode"], + nsteps=int(nsteps), + p1=sampling_model_args[data_mode]["schedule"]["p"], + ) gt = { data_mode: get_gt( diff --git a/src/proteinfoundation/nn/genie2_modules/structure_net.py b/src/proteinfoundation/nn/genie2_modules/structure_net.py index 74398d2..344ff9d 100644 --- a/src/proteinfoundation/nn/genie2_modules/structure_net.py +++ b/src/proteinfoundation/nn/genie2_modules/structure_net.py @@ -1,6 +1,7 @@ import torch from openfold.model.structure_module import BackboneUpdate, InvariantPointAttention from openfold.model.structure_module import StructureModuleTransition as StructureTransition +from openfold.utils.rigid_utils import Rigid from torch import nn @@ -127,6 +128,7 @@ def __init__( ipa_dropout, n_structure_transition_layer, structure_transition_dropout, + center_translations: bool = True, ): """ Args: @@ -155,6 +157,7 @@ def __init__( """ super().__init__() self.n_structure_block = n_structure_block + 
self.center_translations = center_translations # Create structure layers layers = [ @@ -228,6 +231,19 @@ def forward(self, s, p, ts, residue_mask): states = [s.unsqueeze(0)] mask = residue_mask.int() for block_idx in range(self.n_structure_block): - s, p, ts, mask, states = self.net((s, p, ts, mask, states)) + if self.center_translations: + # Subtract per-sample center of mass from frame translations before + # each IPA block. Re-centering improves numerical stability and + # makes IPA effectively translation-equivariant: rotating/translating + # the complex leaves the relative geometry unchanged. + trans = ts.get_trans() # [B, N, 3] + mask_f = mask.float()[..., None] # [B, N, 1] + com = (trans * mask_f).sum(dim=1, keepdim=True) / (mask_f.sum(dim=1, keepdim=True) + 1e-8) + ts_centered = Rigid(ts.get_rots(), trans - com) + s, p, ts_centered, mask, states = self.net((s, p, ts_centered, mask, states)) + # Restore original translations so downstream code is unaffected + ts = Rigid(ts_centered.get_rots(), ts_centered.get_trans() + com) + else: + s, p, ts, mask, states = self.net((s, p, ts, mask, states)) states = torch.concat(states, dim=0) return states, ts, s diff --git a/src/proteinfoundation/nn/modules/attn_n_transition.py b/src/proteinfoundation/nn/modules/attn_n_transition.py index 1962129..0b084d6 100644 --- a/src/proteinfoundation/nn/modules/attn_n_transition.py +++ b/src/proteinfoundation/nn/modules/attn_n_transition.py @@ -1,6 +1,7 @@ import torch from proteinfoundation.nn.modules.pair_bias_attn import ( + GeometricSparseMultiHeadBiasedAttentionADALN_MM, MultiHeadBiasedAttentionADALN_MM, MultiHeadCrossAttentionADALN_MM, ) @@ -92,6 +93,96 @@ def forward(self, x, pair_rep, cond, mask): return x * mask[..., None] +class GeometricMultiheadAttnAndTransition(torch.nn.Module): + """``MultiheadAttnAndTransition`` with optional CA-coordinate-guided sparse attention. 
+ + When ``ca_coords`` is passed to ``forward()``, pair attention is restricted + to geometrically nearby residue pairs (local radius + top-K NN), focusing + the model on structurally relevant interactions and reducing effective O(n²) + attention cost for longer sequences (n > 150). + + Drop-in replacement for ``MultiheadAttnAndTransition`` — set + ``use_geometric_attn: true`` in the layer config to activate. + + Args: + Same as ``MultiheadAttnAndTransition``, plus: + geo_topk: Top-K nearest CA neighbours to include per residue. + geo_radius: Local radius (Å) to include per residue. + attention_type: Underlying kernel to use inside the geometric wrapper. + """ + + def __init__( + self, + dim_token, + dim_pair, + nheads, + dim_cond, + residual_mha, + residual_transition, + parallel_mha_transition, + use_attn_pair_bias, + use_qkln, + dropout=0.0, + expansion_factor=4, + geo_topk: int = 32, + geo_radius: float = 8.0, + attention_type: str = "flash", + ): + super().__init__() + self.parallel = parallel_mha_transition + self.use_attn_pair_bias = use_attn_pair_bias + + if self.parallel and residual_mha and residual_transition: + residual_transition = False + + self.residual_mha = residual_mha + self.residual_transition = residual_transition + + self.mhba = GeometricSparseMultiHeadBiasedAttentionADALN_MM( + dim_token=dim_token, + dim_pair=dim_pair, + nheads=nheads, + dim_cond=dim_cond, + use_qkln=use_qkln, + geo_topk=geo_topk, + geo_radius=geo_radius, + attention_type=attention_type, + ) + self.transition = TransitionADALN(dim=dim_token, dim_cond=dim_cond, expansion_factor=expansion_factor) + + def _apply_mha(self, x, pair_rep, cond, mask, ca_coords=None): + x_attn = self.mhba(x, pair_rep, cond, mask, ca_coords=ca_coords) + if self.residual_mha: + x_attn = x_attn + x + return x_attn * mask[..., None] + + def _apply_transition(self, x, cond, mask): + x_tr = self.transition(x, cond, mask) + if self.residual_transition: + x_tr = x_tr + x + return x_tr * mask[..., None] + + 
def forward(self, x, pair_rep, cond, mask, ca_coords=None): + """ + Args: + x: Token features [b, n, dim_token]. + pair_rep: Pair representation [b, n, n, dim_pair]. + cond: Conditioning [b, n, dim_cond]. + mask: Residue mask [b, n]. + ca_coords: Optional CA coordinates [b, n, 3] in Å for geometric masking. + + Returns: + Updated token features [b, n, dim_token]. + """ + x = x * mask[..., None] + if self.parallel: + x = self._apply_mha(x, pair_rep, cond, mask, ca_coords) + self._apply_transition(x, cond, mask) + else: + x = self._apply_mha(x, pair_rep, cond, mask, ca_coords) + x = self._apply_transition(x, cond, mask) + return x * mask[..., None] + + class MultiheadCrossAttnAndTransition(torch.nn.Module): """Layer that applies mha and transition to a sequence representation. Both layers are their adaptive versions which rely on conditining variables (see above). diff --git a/src/proteinfoundation/nn/modules/pair_bias_attn.py b/src/proteinfoundation/nn/modules/pair_bias_attn.py index f7f391b..8ca11b5 100644 --- a/src/proteinfoundation/nn/modules/pair_bias_attn.py +++ b/src/proteinfoundation/nn/modules/pair_bias_attn.py @@ -541,6 +541,127 @@ def _attn(self, q, k, v, scale, mask: Tensor | None) -> Tensor: return einsum("b h i j, b h j d -> b h i d", attn, v) +def build_geometric_attn_mask( + ca_coords: Tensor, + mask: Tensor, + topk: int = 32, + radius_ang: float = 8.0, +) -> Tensor: + """Build a sparse boolean attention mask from CA coordinates. + + Combines two complementary patterns so that every residue attends to: + 1. Its ``topk`` nearest CA neighbours (captures long-range contacts). + 2. All residues within ``radius_ang`` Å (captures dense local structure). + + The union gives ~10× sparser attention than full O(n²) while retaining + >95% of the geometrically meaningful pairs for typical proteins. + + Args: + ca_coords: CA atom coordinates [b, n, 3] in Å. + mask: Boolean residue mask [b, n]. + topk: Number of nearest-neighbour pairs per residue. 
+ radius_ang: Local neighbourhood radius in Å. + + Returns: + Boolean attention mask [b, n, n] — True where attention is allowed. + """ + b, n, _ = ca_coords.shape + device = ca_coords.device + + dists = torch.cdist(ca_coords, ca_coords) # [b, n, n] + + # Pattern 1: local radius + local = dists < radius_ang + + # Pattern 2: top-k nearest neighbours + k = min(topk, n) + topk_dists, topk_idx = torch.topk(dists, k=k, dim=-1, largest=False) # [b, n, k] + knn = torch.zeros(b, n, n, device=device, dtype=torch.bool) + knn.scatter_(2, topk_idx, True) + + # Union, then apply residue mask + pair_mask = mask[:, :, None] & mask[:, None, :] # [b, n, n] + return (local | knn) & pair_mask + + +class GeometricSparseMultiHeadBiasedAttentionADALN_MM(torch.nn.Module): + """Pair biased MHA with an optional geometric sparsity mask on top. + + When ``ca_coords`` is provided at forward time, attention is restricted + to the union of a local radius neighbourhood and the top-K nearest CA + neighbours, reducing the effective O(n²) cost for long sequences. + + Falls back to full attention when ``ca_coords`` is None (e.g. for short + sequences or during the first few steps where coordinates are noisy). + + Args: + dim_token: Token feature dimension. + dim_pair: Pair representation dimension. + nheads: Number of attention heads. + dim_cond: Conditioning feature dimension. + use_qkln: Whether to use QK layer normalisation. + geo_topk: Top-K neighbours for geometric mask. + geo_radius: Local radius in Å for geometric mask. + attention_type: Underlying attention kernel ('naive', 'flash', 'cuequivariance'). 
+ """ + + def __init__( + self, + dim_token: int, + dim_pair: int, + nheads: int, + dim_cond: int, + use_qkln: bool, + geo_topk: int = 32, + geo_radius: float = 8.0, + attention_type: str = "flash", + ): + super().__init__() + self.geo_topk = geo_topk + self.geo_radius = geo_radius + AttnCls = get_multihead_attention_adaln(attention_type) + self.inner = AttnCls( + dim_token=dim_token, + dim_pair=dim_pair, + nheads=nheads, + dim_cond=dim_cond, + use_qkln=use_qkln, + ) + + def forward( + self, + x: Tensor, + pair_rep: Tensor, + cond: Tensor, + mask: Tensor, + ca_coords: Tensor | None = None, + ) -> Tensor: + """ + Args: + x: Token features [b, n, dim_token]. + pair_rep: Pair representation [b, n, n, dim_pair]. + cond: Conditioning [b, n, dim_cond]. + mask: Residue mask [b, n]. + ca_coords: Optional CA coordinates [b, n, 3] in Å. When provided, + geometric sparsity masking is applied to the pair attention. + + Returns: + Updated token features [b, n, dim_token]. + """ + if ca_coords is not None: + geo_mask = build_geometric_attn_mask( + ca_coords=ca_coords, + mask=mask, + topk=self.geo_topk, + radius_ang=self.geo_radius, + ) + # Zero out pair_rep entries outside the geometric neighbourhood + # so the attention bias drives those weights to -inf. + pair_rep = pair_rep * geo_mask[..., None].float() + + return self.inner(x, pair_rep, cond, mask) + + def get_multihead_attention_adaln(attention_type: str = "naive"): """Factory function to get the appropriate MultiHeadBiasedAttentionADALN_MM class. diff --git a/src/proteinfoundation/partial_autoencoder/autoencoder.py b/src/proteinfoundation/partial_autoencoder/autoencoder.py index 205a0f7..ce91fe9 100644 --- a/src/proteinfoundation/partial_autoencoder/autoencoder.py +++ b/src/proteinfoundation/partial_autoencoder/autoencoder.py @@ -808,6 +808,100 @@ def on_validation_epoch_end(self): self.validation_output_data = [] # Should log here? 
+ @torch.no_grad() + def analyze_reconstruction_fidelity( + self, + dataloader, + max_samples: int = 1000, + ) -> dict: + """Measure reconstruction quality and latent utilisation. + + Runs encode→decode on held-out samples and reports: + - ``mean_ca_rmsd_ang``: mean CA RMSD (Å, no alignment) — proxy for + information loss through the bottleneck. + - ``mean_active_dims``: average number of latent dimensions with per- + component KL > 0.1 (active units). If this equals ``latent_z_dim`` + the bottleneck may be undersized. + - ``recommendation``: human-readable suggestion for ``latent_z_dim``. + + Args: + dataloader: Validation dataloader yielding the same batch format + as the training loop. + max_samples: Stop after this many samples (for speed). + + Returns: + Dict with keys ``mean_ca_rmsd_ang``, ``mean_active_dims``, + ``latent_dim``, and ``recommendation``. + """ + from proteinfoundation.utils.coors_utils import nm_to_ang + + rmsd_vals: list[float] = [] + active_units_vals: list[float] = [] + n_seen = 0 + + self.eval() + for batch in dataloader: + if n_seen >= max_samples: + break + + mask = batch["mask_dict"]["coords"][..., 0, 0].to(self.device) + batch["mask"] = mask + ca_coors_nm = batch["coords_nm"][..., 1, :].to(self.device) * mask[..., None] + + # Move all tensors to device + batch = { + k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() + } + + output_enc = self.encoder(batch) + z = output_enc["z_latent"] + mean = output_enc["mean"] + log_scale = output_enc["log_scale"] + + input_dec = { + "z_latent": z, + "ca_coors_nm": ca_coors_nm, + "residue_mask": mask, + "mask": mask, + } + output_dec = self.decoder(input_dec) + coors_pred = output_dec["coors_nm"] # [b, n, 37, 3] nm + + # CA RMSD in Å (atom index 1 = CA) + ca_true_ang = nm_to_ang(batch["coords_nm"][..., 1, :]) # [b, n, 3] + ca_pred_ang = nm_to_ang(coors_pred[..., 1, :]) # [b, n, 3] + diff_sq = ((ca_true_ang - ca_pred_ang) ** 2).sum(dim=-1) # [b, n] + diff_sq = diff_sq 
* mask + nres = mask.float().sum(dim=-1).clamp(min=1.0) + rmsd = torch.sqrt(diff_sq.sum(dim=-1) / nres) # [b] + rmsd_vals.extend(rmsd.cpu().tolist()) + + # Active latent dimensions: per-component KL > 0.1 threshold + kl_per_dim = self._per_component_kl(mean, log_scale, mask) # [b, n, d] + kl_mean_per_dim = (kl_per_dim * mask[..., None]).sum(dim=(0, 1)) / mask.float().sum().clamp(min=1) + active = (kl_mean_per_dim > 0.1).float().sum().item() + active_units_vals.append(active) + + n_seen += mask.shape[0] + + mean_rmsd = float(torch.tensor(rmsd_vals).mean().item()) + mean_active = float(torch.tensor(active_units_vals).mean().item()) + d = self.latent_dim + + if mean_rmsd > 1.5: + rec = f"Bottleneck too tight — increase latent_z_dim from {d} to {int(d * 1.5)}" + elif mean_active < d * 0.5: + rec = f"Only {mean_active:.0f}/{d} dims active — could reduce latent_z_dim to ~{int(mean_active * 1.2)}" + else: + rec = f"Bottleneck looks healthy ({mean_active:.0f}/{d} dims active, RMSD={mean_rmsd:.2f}Å)" + + return { + "mean_ca_rmsd_ang": mean_rmsd, + "mean_active_dims": mean_active, + "latent_dim": d, + "recommendation": rec, + } + def predict_step(self, batch: dict, batch_idx: int) -> dict: """ Makes predictions. Given a data batch, encodes, and returns decoded batch. 
diff --git a/src/proteinfoundation/rewards/__init__.py b/src/proteinfoundation/rewards/__init__.py index e4a6afd..206fdef 100644 --- a/src/proteinfoundation/rewards/__init__.py +++ b/src/proteinfoundation/rewards/__init__.py @@ -17,7 +17,8 @@ ensure_tensor, standardize_reward, ) -from proteinfoundation.rewards.reward_utils import compute_reward_from_samples, initialize_reward_model +from proteinfoundation.rewards.energy_reward import GeometricEnergyReward +from proteinfoundation.rewards.reward_utils import RewardCache, compute_reward_from_samples, initialize_reward_model __all__ = [ "GRAD_KEY", @@ -25,6 +26,8 @@ "TOTAL_REWARD_KEY", "BaseRewardModel", "CompositeRewardModel", + "GeometricEnergyReward", + "RewardCache", "compute_reward_from_samples", "ensure_tensor", "initialize_reward_model", diff --git a/src/proteinfoundation/rewards/alphafold2_reward.py b/src/proteinfoundation/rewards/alphafold2_reward.py index cb033ed..11be30e 100644 --- a/src/proteinfoundation/rewards/alphafold2_reward.py +++ b/src/proteinfoundation/rewards/alphafold2_reward.py @@ -81,6 +81,7 @@ def __init__( use_initial_atom_pos: bool = False, seed: int = 0, device_id: int | None = None, + calibration_file: str | None = None, ) -> None: """Initialize the AF2RewardModel. 
@@ -126,6 +127,13 @@ def __init__( self.seed = seed self.rng = random.Random(seed) + # Platt-scaling calibration: pLDDT_calibrated = sigmoid(scale * pLDDT + bias) + # Default: identity (scale=1, bias=0 → sigmoid ≈ pLDDT for pLDDT near 0.5) + self.plddt_scale: float = 1.0 + self.plddt_bias: float = 0.0 + if calibration_file is not None: + self._load_calibration(calibration_file) + # Initialize the AF2 model self.model = mk_afdesign_model( protocol=protocol, @@ -356,15 +364,73 @@ def extract_results(self, aux: dict[str, Any]) -> dict[str, Any]: else: grad_dict["structure"] = torch.from_numpy(jax_grad_struct) + raw_plddt = torch.from_numpy(aux["plddt"]) return standardize_reward( reward=reward_components, grad=grad_dict, total_reward=total_reward, - plddt=torch.from_numpy(aux["plddt"]), + plddt=raw_plddt, + plddt_calibrated=self._calibrated_plddt(raw_plddt.mean()), pae=torch.from_numpy(aux["pae"]), ptm=torch.tensor(aux["ptm"], dtype=torch.float32), ) + # ------------------------------------------------------------------ + # Platt-scaling calibration + # ------------------------------------------------------------------ + + def calibrate(self, plddt_vals: list[float], success_labels: list[bool]) -> None: + """Fit Platt scaling from empirical pLDDT → wet-lab success labels. + + After calibration, ``score()`` returns a ``plddt_calibrated`` entry in the + reward dict that is better correlated with actual design success than the + raw AF2 pLDDT. + + Args: + plddt_vals: Mean pLDDT per design from AF2 (e.g. list of floats in [0, 1]). + success_labels: Boolean success flag per design from wet-lab validation. 
+ """ + try: + from sklearn.linear_model import LogisticRegression + except ImportError as exc: + raise ImportError("scikit-learn is required for calibration: pip install scikit-learn") from exc + + import numpy as np + + X = np.array(plddt_vals, dtype=np.float64).reshape(-1, 1) + y = np.array(success_labels, dtype=int) + lr = LogisticRegression(max_iter=1000, solver="lbfgs") + lr.fit(X, y) + self.plddt_scale = float(lr.coef_[0][0]) + self.plddt_bias = float(lr.intercept_[0]) + logger.info( + "AF2 pLDDT calibrated: scale=%.4f, bias=%.4f " + "(applied as sigmoid(scale*pLDDT + bias))", + self.plddt_scale, + self.plddt_bias, + ) + + def save_calibration(self, path: str) -> None: + import json + + with open(path, "w") as f: + json.dump({"plddt_scale": self.plddt_scale, "plddt_bias": self.plddt_bias}, f) + + def _load_calibration(self, path: str) -> None: + import json + + with open(path) as f: + d = json.load(f) + self.plddt_scale = float(d["plddt_scale"]) + self.plddt_bias = float(d["plddt_bias"]) + logger.info("Loaded pLDDT calibration from %s: scale=%.4f, bias=%.4f", path, self.plddt_scale, self.plddt_bias) + + def _calibrated_plddt(self, plddt: torch.Tensor) -> torch.Tensor: + """Apply Platt scaling; returns uncalibrated tensor if params are identity.""" + if self.plddt_scale == 1.0 and self.plddt_bias == 0.0: + return plddt + return torch.sigmoid(self.plddt_scale * plddt + self.plddt_bias) + def _clear_model_state(self) -> None: """Clear internal model state dictionaries.""" if hasattr(self, "model"): diff --git a/src/proteinfoundation/rewards/base_reward.py b/src/proteinfoundation/rewards/base_reward.py index 9b51f39..d43102a 100644 --- a/src/proteinfoundation/rewards/base_reward.py +++ b/src/proteinfoundation/rewards/base_reward.py @@ -188,6 +188,21 @@ def extract_results(self, aux: dict[str, Any]) -> dict[str, Any]: """ return aux + def enable_cache(self, max_size: int = 5000) -> None: + """Attach a RewardCache to avoid re-scoring identical sequences. 
+ + Common in beam search where sibling branches share parent sequences. + Safe to call on CompositeRewardModel — cache is attached at the top level + and checked by ``compute_reward_from_samples``. + + Args: + max_size: Maximum number of sequence → reward entries to keep (LRU eviction). + """ + from proteinfoundation.rewards.reward_utils import RewardCache + + self._reward_cache = RewardCache(max_size) + logger.info(f"RewardCache enabled on {type(self).__name__} (max_size={max_size})") + def cleanup(self) -> None: """Explicit cleanup of model memory. diff --git a/src/proteinfoundation/rewards/energy_reward.py b/src/proteinfoundation/rewards/energy_reward.py new file mode 100644 index 0000000..fc6c68f --- /dev/null +++ b/src/proteinfoundation/rewards/energy_reward.py @@ -0,0 +1,161 @@ +"""Fast geometric energy pre-filter reward model. + +Runs in <100ms per sample using only backbone geometry — no structure prediction. +Use as a cheap first gate before expensive AF2/RF3 calls to eliminate candidates +with obvious geometric defects. + +Checks: + - CA clash rate: non-adjacent CA pairs closer than ``ca_clash_threshold`` Å + - Backbone angle deviation: N-CA-C angles far from the ideal ~111° + +Both penalties are returned as rates in [0, 1] and combined into ``total_reward`` +(negative, so higher reward = fewer defects). Use ``reward_threshold`` in the +search config to skip AF2/RF3 on geometrically invalid samples. +""" + +import logging +import math + +import torch +import torch.nn.functional as F + +from proteinfoundation.rewards.base_reward import REWARD_KEY, TOTAL_REWARD_KEY, BaseRewardModel, standardize_reward + +logger = logging.getLogger(__name__) + +# atom37 indices for backbone atoms +_IDX_N = 0 +_IDX_CA = 1 +_IDX_C = 2 + + +class GeometricEnergyReward(BaseRewardModel): + """Fast geometry-based pre-filter: CA clash rate + backbone angle deviation. 
class GeometricEnergyReward(BaseRewardModel):
    """Geometry-only reward: CA clash rate plus N-CA-C angle outlier rate.

    Reads backbone atoms from a PDB file and computes two rates in [0, 1]:

    - ``clash``: fraction of non-adjacent CA pairs closer than
      ``ca_clash_threshold``.
    - ``rama``: fraction of residues whose N-CA-C bond angle deviates by more
      than 20° from 111°. Despite the key name this is a bond-angle check,
      not a Ramachandran (phi/psi) analysis.

    ``total_reward`` is ``clash_weight * clash + rama_weight * rama``. With
    the default negative weights, 0 is the best score and more defects give
    a lower reward. No structure prediction model is involved.

    Args:
        clash_weight: Weight for the CA clash rate (negative = penalise).
        rama_weight: Weight for the backbone angle outlier rate.
        ca_clash_threshold: Minimum allowed CA–CA distance for non-adjacent
            residues, in the unit of the PDB coordinates. Closer pairs count
            as clashes.
        adjacency_window: Residue index distance up to which pairs are
            treated as adjacent and excluded from clash counting.
    """

    IS_FOLDING_MODEL = False
    SUPPORTS_GRAD = False
    SUPPORTS_SAVE_PDB = False

    def __init__(
        self,
        clash_weight: float = -1.0,
        rama_weight: float = -0.5,
        ca_clash_threshold: float = 3.8,
        adjacency_window: int = 2,
    ) -> None:
        # NOTE(review): BaseRewardModel.__init__ is not invoked here; confirm
        # the base class needs no initialisation.
        self.clash_weight = clash_weight
        self.rama_weight = rama_weight
        self.ca_clash_threshold = ca_clash_threshold
        self.adjacency_window = adjacency_window

    def score(self, pdb_path: str, requires_grad: bool = False, **kwargs) -> dict:
        """Compute the geometric scores for one PDB file.

        Any failure (unreadable file, unexpected array shapes, ...) is logged
        and answered with the worst total reward this model can produce, so
        that a structure that could not be evaluated never outranks one that
        was.

        Args:
            pdb_path: Path to the PDB file to evaluate.
            requires_grad: Ignored — no gradient support.
            **kwargs: Ignored.

        Returns:
            Result of ``standardize_reward`` with ``clash`` and ``rama``
            components, or with no components on failure.
        """
        try:
            from proteinfoundation.utils.pdb_utils import from_pdb_file

            prot = from_pdb_file(pdb_path)
            atom_pos = torch.from_numpy(prot.atom_positions).float()  # [n, 37, 3]
            atom_mask = torch.from_numpy(prot.atom_mask).bool()  # [n, 37]

            ca_coords = atom_pos[:, _IDX_CA, :]  # [n, 3]
            n_coords = atom_pos[:, _IDX_N, :]
            c_coords = atom_pos[:, _IDX_C, :]
            ca_mask = atom_mask[:, _IDX_CA]  # [n]
            # The angle needs all three atoms. A missing N or C would give a
            # zero vector, i.e. a 90° "angle" that is counted as an outlier.
            backbone_mask = atom_mask[:, _IDX_N] & ca_mask & atom_mask[:, _IDX_C]

            clash = self._ca_clash_rate(ca_coords, ca_mask)
            rama = self._backbone_angle_penalty(n_coords, ca_coords, c_coords, backbone_mask)

            total = torch.tensor(
                self.clash_weight * clash.item() + self.rama_weight * rama.item(),
                dtype=torch.float32,
            )

            return standardize_reward(
                reward={"clash": clash, "rama": rama},
                total_reward=total,
            )

        except Exception as exc:
            # Deliberate best-effort: one bad file must not abort a batch.
            logger.warning(f"GeometricEnergyReward failed for {pdb_path}: {exc}")
            # Both rates are bounded by 1, so the lowest reachable total is
            # the sum of the negative weights. Returning 0.0 here would be
            # the *best* score under the default weights.
            worst = min(self.clash_weight, 0.0) + min(self.rama_weight, 0.0)
            return standardize_reward(reward={}, total_reward=torch.tensor(worst, dtype=torch.float32))

    # ------------------------------------------------------------------
    # Internal geometry helpers
    # ------------------------------------------------------------------

    def _ca_clash_rate(self, ca: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Fraction of non-adjacent CA pairs closer than the clash threshold.

        Args:
            ca: CA coordinates [n, 3].
            mask: Boolean mask [n] of present residues.

        Returns:
            Scalar clash rate in [0, 1]; 0 when the chain is too short to
            contain any non-adjacent pair.
        """
        n = ca.shape[0]
        if n < self.adjacency_window + 2:
            return ca.new_zeros(())

        dists = torch.cdist(ca, ca)  # [n, n]

        idx = torch.arange(n, device=ca.device)
        adjacency = (idx[:, None] - idx[None, :]).abs() <= self.adjacency_window
        valid = mask[:, None] & mask[None, :] & ~adjacency  # [n, n]

        # Each unordered pair appears twice (i, j) and (j, i) in both the
        # numerator and the denominator, so the ratio is unaffected.
        clash = (dists < self.ca_clash_threshold) & valid
        n_valid = valid.float().sum().clamp(min=1.0)
        return clash.float().sum() / n_valid

    def _backbone_angle_penalty(
        self,
        n_pos: torch.Tensor,
        ca: torch.Tensor,
        c_pos: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Fraction of masked-in residues with N-CA-C angle > 20° from 111°.

        Args:
            n_pos: N coordinates [n, 3].
            ca: CA coordinates [n, 3].
            c_pos: C coordinates [n, 3].
            mask: Boolean mask [n] of residues to evaluate.

        Returns:
            Scalar outlier rate in [0, 1].
        """
        ideal_rad = 111.0 * math.pi / 180.0
        deviation_threshold = 20.0 * math.pi / 180.0

        v1 = F.normalize(n_pos - ca, dim=-1)  # unit vector CA→N [n, 3]
        v2 = F.normalize(c_pos - ca, dim=-1)  # unit vector CA→C [n, 3]

        # clamp guards acos against values marginally outside [-1, 1]
        cos_angle = (v1 * v2).sum(dim=-1).clamp(-1.0, 1.0)  # [n]
        angles = torch.acos(cos_angle)  # [n]

        outliers = (angles - ideal_rad).abs() > deviation_threshold
        outliers = outliers & mask

        n_valid = mask.float().sum().clamp(min=1.0)
        return outliers.float().sum() / n_valid
+ """ + + def __init__(self, max_size: int = 5000): + self._cache: dict[bytes, dict[str, Any]] = {} + self._order: list[bytes] = [] + self.max_size = max_size + self.hits = 0 + self.misses = 0 + + def get(self, key: bytes) -> dict[str, Any] | None: + if key in self._cache: + self._order.remove(key) + self._order.append(key) + self.hits += 1 + return self._cache[key] + self.misses += 1 + return None + + def put(self, key: bytes, value: dict[str, Any]) -> None: + if key in self._cache: + self._order.remove(key) + elif len(self._cache) >= self.max_size: + oldest = self._order.pop(0) + del self._cache[oldest] + self._cache[key] = value + self._order.append(key) + + @property + def hit_rate(self) -> float: + total = self.hits + self.misses + return self.hits / total if total > 0 else 0.0 + + def __len__(self) -> int: + return len(self._cache) + + +def _sequence_cache_key(residue_type: torch.Tensor) -> bytes: + return residue_type.detach().cpu().numpy().tobytes() + + def initialize_reward_model(inf_cfg: Any) -> Any | None: """Initialize reward model from configuration. 
@@ -90,12 +136,34 @@ def compute_reward_from_samples( ) return {TOTAL_REWARD_KEY: torch.zeros(batch_size, device=device)} + # --- cache lookup --- + cache: RewardCache | None = getattr(reward_model, "_reward_cache", None) + cached_results: dict[int, dict[str, float]] = {} + uncached_indices: list[int] = [] + + if cache is not None: + for i in range(batch_size): + key = _sequence_cache_key(sample_prots["residue_type"][i]) + entry = cache.get(key) + if entry is not None: + cached_results[i] = entry + else: + uncached_indices.append(i) + if cached_results: + logger.debug( + f"RewardCache: {len(cached_results)}/{batch_size} hits " + f"(hit_rate={cache.hit_rate:.2%})" + ) + else: + uncached_indices = list(range(batch_size)) + target_chain, binder_chain = None, None temp_dir = tempfile.mkdtemp() - temp_pdb_paths = [] + temp_pdb_paths: dict[int, str] = {} try: - for i in range(batch_size): + # Write PDB only for uncached samples + for i in uncached_indices: coors = sample_prots["coors"][i] residue_type = sample_prots["residue_type"][i] chain_index = ( @@ -104,7 +172,7 @@ def compute_reward_from_samples( creation_time = datetime.now().strftime("%Y%m%d_%H%M%S") device_str = str(device).replace(":", "_") temp_pdb_path = os.path.join(temp_dir, f"temp_sample_{i}_{creation_time}_{device_str}.pdb") - temp_pdb_paths.append(temp_pdb_path) + temp_pdb_paths[i] = temp_pdb_path if ligand is not None: write_prot_ligand_to_pdb( @@ -127,16 +195,12 @@ def compute_reward_from_samples( if target_chain is None: target_chain, binder_chain = get_chain_ids_from_pdb(temp_pdb_path) - # Collect per-sample total_reward and components, then stack into tensors - total_rewards_list = [] - component_keys = set() - components_per_sample: list[dict[str, float]] = [] + # Score uncached samples + component_keys: set[str] = set() + scored_results: dict[int, dict[str, float]] = {} - for i in range(batch_size): + for i in uncached_indices: temp_pdb_path = temp_pdb_paths[i] - total_reward = 0.0 - components: 
dict[str, float] = {} - chain_index_i = ( sample_prots.get("chain_index", [None] * batch_size)[i] if "chain_index" in sample_prots else None ) @@ -144,7 +208,7 @@ def compute_reward_from_samples( target_hotspot_mask[i % target_hotspot_mask.shape[0]] if target_hotspot_mask is not None else None ) - reward_kwargs = { + reward_kwargs: dict[str, Any] = { "target_chain": target_chain, "binder_chain": binder_chain, } @@ -160,24 +224,40 @@ def compute_reward_from_samples( ) total_reward = reward_dict[TOTAL_REWARD_KEY].item() components = _extract_reward_components(reward_dict) + components[TOTAL_REWARD_KEY] = total_reward component_keys.update(components.keys()) logger.debug(f"Sample {i}: reward = {total_reward}") + scored_results[i] = components + + # Store in cache + if cache is not None: + key = _sequence_cache_key(sample_prots["residue_type"][i]) + cache.put(key, components) + + # Merge cached + scored + all_components: list[dict[str, float]] = [] + for i in range(batch_size): + if i in cached_results: + all_components.append(cached_results[i]) + component_keys.update(cached_results[i].keys()) + else: + all_components.append(scored_results[i]) - total_rewards_list.append(total_reward) - components_per_sample.append(components) + total_rewards_list = [c[TOTAL_REWARD_KEY] for c in all_components] - # Build unified rewards dict: total_reward + all component tensors result = { TOTAL_REWARD_KEY: torch.tensor(total_rewards_list, device=device, dtype=torch.float32), } for key in sorted(component_keys): if key == TOTAL_REWARD_KEY: continue - vals = [components_per_sample[i].get(key, float("nan")) for i in range(batch_size)] + vals = [all_components[i].get(key, float("nan")) for i in range(batch_size)] result[key] = torch.tensor(vals, device=device, dtype=torch.float32) logger.info( - f"Computed rewards for {batch_size} samples. 
Mean reward: {result[TOTAL_REWARD_KEY].mean().item():.4f}" + f"Computed rewards for {batch_size} samples " + f"({len(uncached_indices)} scored, {len(cached_results)} cached). " + f"Mean reward: {result[TOTAL_REWARD_KEY].mean().item():.4f}" ) return result finally: diff --git a/src/proteinfoundation/search/beam_search.py b/src/proteinfoundation/search/beam_search.py index 2c6cd7c..eec8b23 100644 --- a/src/proteinfoundation/search/beam_search.py +++ b/src/proteinfoundation/search/beam_search.py @@ -124,6 +124,10 @@ def search(self, batch: dict) -> dict: f"denoised and rewards are computed on incomplete structures." ) n_steps_total = len(step_checkpoints) - 1 + # When adaptive_branching=true, taper n_branch from n_branch → ceil(n_branch/2) + # over the course of the search. Early steps (high t, high entropy) benefit from + # more branches; late steps converge quickly and waste compute branching heavily. + adaptive_branching = beam_cfg.get("adaptive_branching", False) # ── initialise noise + tags ───────────────────────────────────── init_mask = batch["mask"] @@ -153,6 +157,16 @@ def search(self, batch: dict) -> dict: end_step = step_checkpoints[i + 1] logger.info(f"\n[BeamSearch] Step {i + 1}/{n_steps_total}: denoising {start_step} -> {end_step}") + # Adaptive branching: taper from n_branch → max(1, n_branch//2) linearly. + # Late steps are near-deterministic; fewer branches waste little quality. 
+ if adaptive_branching and n_steps_total > 1: + progress = i / (n_steps_total - 1) # 0 → 1 + n_branch_step = max(1, round(n_branch * (1.0 - 0.5 * progress))) + else: + n_branch_step = n_branch + if n_branch_step != n_branch: + logger.debug(f"[BeamSearch] Adaptive branching: n_branch_step={n_branch_step} (base={n_branch})") + # ── branch: batch all replicas × branches in one call ──────── # Previously this was a nested ``for replica × for branch`` loop # making ``beam_width × n_branch`` sequential partial_simulation @@ -198,15 +212,15 @@ def search(self, batch: dict) -> dict: # indices 18..20 → replica 1, branch 2, samples 0-2 # indices 21..23 → replica 1, branch 3, samples 0-2 # - branching_factor = beam_width * n_branch + branching_factor = beam_width * n_branch_step branch_parts = [] pred_parts = [] for replica_idx in range(beam_width): rep_xt = {k: v[replica_idx::beam_width] for k, v in xt.items()} - branch_parts.append(tile_tensor_dict(rep_xt, n_branch)) + branch_parts.append(tile_tensor_dict(rep_xt, n_branch_step)) if x_1_pred is not None: rep_pred = {k: v[replica_idx::beam_width] for k, v in x_1_pred.items()} - pred_parts.append(tile_tensor_dict(rep_pred, n_branch)) + pred_parts.append(tile_tensor_dict(rep_pred, n_branch_step)) big_xt = concat_dict_tensors(branch_parts, dim=0) big_pred = concat_dict_tensors(pred_parts, dim=0) if x_1_pred is not None else None @@ -242,7 +256,7 @@ def search(self, batch: dict) -> dict: expanded_tags = expand_tags_for_branches( metadata_tags, beam_width, - n_branch, + n_branch_step, start_step=start_step, end_step=end_step, ) @@ -306,23 +320,23 @@ def search(self, batch: dict) -> dict: ) # FIX: flat layout from branching is replica-major: # [r0_b0_s0..sN, r0_b1_s0..sN, …, rW_bB_s0..sN] - # so dim-0 of the view must be beam_width (replica), not n_branch. 
- rewards_reshaped = total_rewards.view(beam_width, n_branch, nsamples) - # permute → [nsamples, beam_width, n_branch] → reshape flattens + # so dim-0 of the view must be beam_width (replica), not n_branch_step. + rewards_reshaped = total_rewards.view(beam_width, n_branch_step, nsamples) + # permute → [nsamples, beam_width, n_branch_step] → reshape flattens # the last two dims so each row's columns are replica-major: - # col_idx = replica * n_branch + branch - rewards_for_selection = rewards_reshaped.permute(2, 0, 1).reshape(nsamples, beam_width * n_branch) + # col_idx = replica * n_branch_step + branch + rewards_for_selection = rewards_reshaped.permute(2, 0, 1).reshape(nsamples, beam_width * n_branch_step) top_k_indices = torch.topk(rewards_for_selection, k=beam_width, dim=1)[1] # top_k_indices are column indices into the replica-major row, - # so we recover replica and branch via divmod on n_branch. - replica_indices = top_k_indices // n_branch - branch_indices = top_k_indices % n_branch + # so we recover replica and branch via divmod on n_branch_step. + replica_indices = top_k_indices // n_branch_step + branch_indices = top_k_indices % n_branch_step # Map back to flat indices into big_xt / total_rewards using the - # replica-major formula: replica * n_branch * N + branch * N + sample. + # replica-major formula: replica * n_branch_step * N + branch * N + sample. sample_indices = torch.arange(nsamples, device=search_ctx.device).unsqueeze(1).expand(-1, beam_width) - global_indices = replica_indices * n_branch * nsamples + branch_indices * nsamples + sample_indices + global_indices = replica_indices * n_branch_step * nsamples + branch_indices * nsamples + sample_indices # Flatten row-major → output is GROUPED by sample (all beam_width # winners for sample 0 first, then sample 1, …) which matches the # grouped layout that xt started with from repeat_interleave.