Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 38 additions & 51 deletions src/sampleworks/core/samplers/edm.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,31 +133,13 @@ def __post_init__(self) -> None:
if self.sigma_data <= 0:
raise ValueError(f"sigma_data ({self.sigma_data}) must be positive")

def create_sampler(self) -> AF3EDMSampler:
"""Create EDM sampler instance from this config."""
return AF3EDMSampler(
sigma_data=self.sigma_data,
s_max=self.s_max,
s_min=self.s_min,
p=self.p,
gamma_min=self.gamma_min,
gamma_0=self.gamma_0,
noise_scale=self.noise_scale,
step_scale=self.step_scale,
augmentation=self.augmentation,
align_to_input=self.align_to_input,
alignment_reverse_diffusion=self.alignment_reverse_diffusion,
scale_guidance_to_diffusion=self.scale_guidance_to_diffusion,
device=self.device,
)


@dataclass
class AF3EDMSampler:
"""EDM-style sampler from AF3-like models.

All constants are configurable via constructor for model-specific values.
Default values match AF3 parameterization.
Initialized with a single :class:`EDMSamplerConfig` object that holds all
schedule hyperparameters and runtime options. Default values in the config
match the AF3 parameterization.

This sampler implements the EDM (Karras et al.) style sampling
approach as used in AlphaFold3 and related models, which is the Euler
Expand All @@ -172,19 +154,20 @@ class AF3EDMSampler:
https://www.nature.com/articles/s41586-024-07487-w
"""

sigma_data: float = 16.0 # assumed std dev of data distribution
s_max: float = 160.0 # upper noise schedule bound (in sigma_data units)
s_min: float = 4e-4 # lower noise schedule bound (in sigma_data units)
p: float = 7.0 # schedule exponent (rho in Karras et al.)
gamma_min: float = 0.2 # sigma threshold below which noise inflation is disabled
gamma_0: float = 0.8 # noise inflation factor (S_churn / num_steps)
noise_scale: float = 1.003 # stochastic noise multiplier (S_noise)
step_scale: float = 1.5 # Euler step size multiplier
augmentation: bool = True # random SO(3) rotation + small translation before denoising
align_to_input: bool = True # align to input reference frame
alignment_reverse_diffusion: bool = False # also align noisy state to denoised
scale_guidance_to_diffusion: bool = True # rescale guidance to match diffusion update magnitude
device: str | torch.device = "cpu"
def __init__(self, config: EDMSamplerConfig) -> None:
"""Initialize the sampler with a configuration object.

The configuration is stored as-is on ``self.config``; this constructor
does not copy it, validate it, or unpack its fields into separate
attributes on the sampler.

Parameters
----------
config : EDMSamplerConfig
Configuration object containing all schedule hyperparameters
(``sigma_data``, ``s_max``, ``s_min``, ``p``, ``gamma_min``,
``gamma_0``, ``noise_scale``, ``step_scale``) and runtime flags
(``augmentation``, ``align_to_input``,
``alignment_reverse_diffusion``, ``scale_guidance_to_diffusion``,
``device``).

Notes
-----
The sampler holds a reference to the same ``config`` object that was
passed in (no copy is made), so code that reads ``self.config.<field>``
sees that object's current field values.
"""
# Keep a reference to the caller's config; no copy is made.
self.config = config

def check_context(self, context: StepParams) -> None:
"""Validate that the provided StepParams is ready for step.
Expand Down Expand Up @@ -245,21 +228,25 @@ def compute_schedule(self, num_steps: int) -> EDMSchedule:
EDMSchedule
Schedule object with `sigma_tm`, `sigma_t`, `gamma`, `t_hat`, and `dt` arrays.
"""
t_values = torch.linspace(0, 1, num_steps + 1, device=self.device)
t_values = torch.linspace(0, 1, num_steps + 1, device=self.config.device)

sigmas = (
self.sigma_data
self.config.sigma_data
* (
self.s_max ** (1 / self.p)
+ t_values * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
self.config.s_max ** (1 / self.config.p)
+ t_values
* (
self.config.s_min ** (1 / self.config.p)
- self.config.s_max ** (1 / self.config.p)
)
)
** self.p
** self.config.p
)

gammas = torch.where(
sigmas > self.gamma_min,
torch.tensor(self.gamma_0, device=self.device),
torch.tensor(0.0, device=self.device),
sigmas > self.config.gamma_min,
torch.tensor(self.config.gamma_0, device=self.config.device),
torch.tensor(0.0, device=self.config.device),
)

sigma_tm = sigmas[:-1]
Expand Down Expand Up @@ -298,7 +285,7 @@ def get_context_for_step(self, step_index: int, schedule: SamplerSchedule) -> St
t_hat = schedule.t_hat[step_index] # ty: ignore[unresolved-attribute] (accessible after check_schedule)
dt = schedule.dt[step_index] # ty: ignore[unresolved-attribute]
sigma_tm = schedule.sigma_tm[step_index] # ty: ignore[unresolved-attribute]
eps_scale = self.noise_scale * torch.sqrt(t_hat**2 - sigma_tm**2)
eps_scale = self.config.noise_scale * torch.sqrt(t_hat**2 - sigma_tm**2)

total_steps = len(schedule.sigma_t) # ty: ignore[unresolved-attribute] (this will be accessible due to the check above)

Expand Down Expand Up @@ -357,7 +344,7 @@ def _apply_scaler_guidance(
guidance_direction, align_transform, rotation_only=True
)

if self.scale_guidance_to_diffusion:
if self.config.scale_guidance_to_diffusion:
delta_norm = torch.linalg.norm(delta, dim=(-1, -2), keepdim=True)
# scaler handles any adjustment/clipping of guidance direction, but we have diffusion
# update magnitude here, so can optionally scale to match
Expand All @@ -367,7 +354,7 @@ def _apply_scaler_guidance(
einx.multiply("b, b n c -> b n c", guidance_weight, guidance_direction)
/ context.t_effective
)
proposal_shift = self.step_scale * context.dt * scaled_delta_contribution # ty: ignore[unsupported-operator] (dt will be Array if check_context didn't raise)
proposal_shift = self.config.step_scale * context.dt * scaled_delta_contribution # ty: ignore[unsupported-operator] (dt will be Array if check_context didn't raise)

result = delta + scaled_delta_contribution
return torch.as_tensor(result), loss, torch.as_tensor(proposal_shift)
Expand Down Expand Up @@ -415,7 +402,7 @@ def step(

transform = (
create_random_transform(state_centered, center_before_rotation=False)
if self.augmentation
if self.config.augmentation
else None
)

Expand Down Expand Up @@ -460,14 +447,14 @@ def step(
target_batch_size=x_hat_0.shape[0],
)

if self.align_to_input and alignment_reference is None:
if self.config.align_to_input and alignment_reference is None:
logger.warning(
"align_to_input is True but no alignment_reference provided; "
"skipping alignment. Set alignment_reference on StepParams via "
"with_reconciler() to enable alignment."
)

if self.align_to_input and alignment_reference is not None:
if self.config.align_to_input and alignment_reference is not None:
if reconciler is not None:
x_hat_0_working_frame, align_transform = reconciler.align(
torch.as_tensor(x_hat_0),
Expand All @@ -485,7 +472,7 @@ def step(
torch.as_tensor(maybe_augmented_state), torch.as_tensor(eps), align_transform
)

if self.alignment_reverse_diffusion:
if self.config.alignment_reverse_diffusion:
noisy_state_working_frame = weighted_rigid_align_differentiable(
torch.as_tensor(noisy_state_working_frame),
torch.as_tensor(x_hat_0_working_frame), # <-- this is what is being aligned to
Expand Down Expand Up @@ -527,11 +514,11 @@ def step(

# Euler step: x_{t-1} = x_t + step_scale * dt * delta
# ty sees dt as float | None, but it will be float if check_context didn't raise
next_state = noisy_state_working_frame_t + self.step_scale * dt * delta # ty: ignore[unsupported-operator]
next_state = noisy_state_working_frame_t + self.config.step_scale * dt * delta # ty: ignore[unsupported-operator]

return SamplerStepOutput(
state=next_state,
denoised=x_hat_0_working_frame_t,
loss=loss,
log_proposal_correction=log_proposal_correction,
)
)
7 changes: 5 additions & 2 deletions src/sampleworks/utils/guidance_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
RealSpaceRewardFunction,
setup_scattering_params,
)
from sampleworks.core.samplers.edm import AF3EDMSampler
from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig
from sampleworks.core.scalers.fk_steering import FKSteering
from sampleworks.core.scalers.pure_guidance import PureGuidance
from sampleworks.core.scalers.step_scalers import (
Expand Down Expand Up @@ -423,12 +423,15 @@ def _run_guidance(
use_alignment_for_reverse_diffusion = is_boltz

# Create sampler with model-appropriate settings
sampler = AF3EDMSampler(
sampler_config = EDMSamplerConfig(
device=str(device),
augmentation=args.augmentation,
align_to_input=args.align_to_input,
alignment_reverse_diffusion=use_alignment_for_reverse_diffusion,
)
sampler = AF3EDMSampler(
config=sampler_config,
)

# Create step scaler for gradient-based guidance
use_tweedie = getattr(args, "use_tweedie", False)
Expand Down
28 changes: 20 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from atomworks.io.parser import parse
from atomworks.io.utils.io_utils import load_any
from biotite.structure import AtomArray, AtomArrayStack, stack
from sampleworks.core.samplers.edm import AF3EDMSampler
from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig
from sampleworks.core.samplers.protocol import StepParams
from sampleworks.eval.structure_utils import SampleworksProcessedStructure
from sampleworks.utils.atom_reconciler import AtomReconciler
Expand Down Expand Up @@ -94,6 +94,11 @@ class ComponentInfo:
For scalers, whether it requires a reward function.
default_kwargs
Default kwargs for instantiation in tests.
config_class_path
For components (e.g. AF3EDMSampler) whose constructor takes a config object,
the path to the config class. When set, default_kwargs are passed to the
config constructor, and the resulting config is then passed to the component
constructor.
annotate_fn_path
For wrappers, fully qualified path to the annotate function.
conditioning_type_path
Expand All @@ -108,6 +113,7 @@ class ComponentInfo:
is_trajectory_sampler: bool = False
requires_reward: bool = False
default_kwargs: tuple[tuple[str, Any], ...] = ()
config_class_path: str = ""
annotate_fn_path: str = ""
conditioning_type_path: str = ""
requires_out_dir: bool = True
Expand Down Expand Up @@ -153,8 +159,8 @@ class ComponentInfo:
TrajectorySamplers.AF3EDM: ComponentInfo(
name="af3edm",
module_path="sampleworks.core.samplers.edm.AF3EDMSampler",
config_class_path="sampleworks.core.samplers.edm.EDMSamplerConfig",
is_trajectory_sampler=True,
default_kwargs=(("augmentation", True), ("align_to_input", True)),
),
}

Expand Down Expand Up @@ -240,6 +246,15 @@ def create_sampler_from_type(
) -> Any:
"""Create sampler from TrajectorySamplers enum."""
info = SAMPLER_REGISTRY[sampler_type]
if info.config_class_path:
config_cls = _import_from_path(info.config_class_path)
config_kwargs = dict(info.default_kwargs)
config_kwargs.update(extra_kwargs)
if device is not None:
config_kwargs["device"] = device
config = config_cls(**config_kwargs)
cls = _import_from_path(info.module_path)
return cls(config)
return create_component_from_info(info, device=device, **extra_kwargs)


Expand Down Expand Up @@ -981,11 +996,8 @@ def converging_mock_wrapper(device: torch.device) -> MockFlowModelWrapper:
@pytest.fixture
def edm_sampler(device: torch.device) -> AF3EDMSampler:
"""AF3EDMSampler configured for testing."""
return AF3EDMSampler(
device=device,
augmentation=False,
align_to_input=False,
)
config = EDMSamplerConfig(device=device, augmentation=False, align_to_input=False)
return AF3EDMSampler(config)


@pytest.fixture
Expand Down Expand Up @@ -1056,4 +1068,4 @@ def perturbed_coords(
torch.manual_seed(42)
base = converging_mock_wrapper.target
perturbation = torch.randn_like(base) * 0.1 # ty: ignore[invalid-argument-type]
return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator]
return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator]
11 changes: 7 additions & 4 deletions tests/integration/test_mismatch_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from biotite.structure import AtomArray
from sampleworks.core.rewards.protocol import RewardInputs
from sampleworks.core.samplers.edm import AF3EDMSampler
from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig
from sampleworks.core.samplers.protocol import StepParams
from sampleworks.core.scalers.fk_steering import FKSteering
from sampleworks.core.scalers.pure_guidance import PureGuidance
Expand Down Expand Up @@ -720,13 +720,14 @@ class TestSamplerStep:
@pytest.fixture
def sampler(self) -> AF3EDMSampler:
"""Sampler configured for deterministic mismatch tests."""
return AF3EDMSampler(
config = EDMSamplerConfig(
augmentation=False,
align_to_input=True,
alignment_reverse_diffusion=False,
scale_guidance_to_diffusion=True,
device="cpu",
)
return AF3EDMSampler(config)

def _context_with_reference(
self,
Expand Down Expand Up @@ -782,13 +783,14 @@ def test_alignment_reduces_rmsd(self, mismatch_case: MismatchCase, sampler: AF3E
state = torch.randn(1, mismatch_case.n_model, 3)
context = self._context_with_reference(reconciler, reference)

sampler_no_align = AF3EDMSampler(
config_no_align = EDMSamplerConfig(
augmentation=False,
align_to_input=False,
alignment_reverse_diffusion=False,
scale_guidance_to_diffusion=True,
device="cpu",
)
sampler_no_align = AF3EDMSampler(config_no_align)

torch.manual_seed(42)
output_aligned = sampler.step(state.clone(), wrapper, context, features=features)
Expand Down Expand Up @@ -877,7 +879,8 @@ def _run_scaler(self, case: MismatchCase, scaler_type: str, reward) -> Any:
"asym_unit": case.struct_atom_array.copy(),
"metadata": {"id": case.id},
}
sampler = AF3EDMSampler(augmentation=False, align_to_input=True, device="cpu")
config = EDMSamplerConfig(augmentation=False, align_to_input=True, device="cpu")
sampler = AF3EDMSampler(config)
step_scaler = DataSpaceDPSScaler(step_size=0.01)

if scaler_type == "pure_guidance":
Expand Down
Loading
Loading