Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5,149 changes: 2,646 additions & 2,503 deletions pixi.lock

Large diffs are not rendered by default.

20 changes: 14 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ requires = ["hatchling"]
[dependency-groups]
analysis = [
"ipython",
"joblib",
"marimo",
"matplotlib",
"mdtraj",
Expand All @@ -29,7 +30,8 @@ dependencies = [
"einx<0.4",
"hydra-core",
"loguru",
"omegaconf"
"omegaconf",
"sfcalculator-torch>=0.3.2"
]
name = "sampleworks"
requires-python = ">= 3.11, <3.14"
Expand Down Expand Up @@ -86,6 +88,12 @@ rc-foundry = {editable = true, extras = ["rf3"], git = "https://github.com/k-chr
[tool.pixi.pypi-dependencies]
sampleworks = {editable = true, path = "."}

# Workspace-level override: reciprocalspaceship (sfcalculator-torch transitive dep)
# caps pandas<=2.2.3 in its metadata, but runs fine on 2.3.1; protenix hard-pins
# pandas==2.3.1. Standardize all envs on 2.3.1 to avoid silent downgrades.
[tool.pixi.pypi-options.dependency-overrides]
pandas = "==2.3.1"

[tool.pixi.system-requirements]
cuda = "12"

Expand Down Expand Up @@ -188,11 +196,6 @@ include = ["src/sampleworks/eval/bond_angle_and_length_outlier_eval_script.py"]
possibly-missing-attribute = "ignore"

[tool.ty.rules]
# Pre-existing type issues across the codebase; warn instead of error
# so ty runs in CI without blocking PRs while the team fixes them.
unresolved-import = "ignore"
unknown-argument = "warn"
unresolved-attribute = "warn"
invalid-argument-type = "warn"
invalid-assignment = "warn"
invalid-method-override = "warn"
Expand All @@ -201,6 +204,11 @@ no-matching-overload = "warn"
not-iterable = "warn"
not-subscriptable = "warn"
too-many-positional-arguments = "warn"
unknown-argument = "warn"
unresolved-attribute = "warn"
# Pre-existing type issues across the codebase; warn instead of error
# so ty runs in CI without blocking PRs while the team fixes them.
unresolved-import = "ignore"
unsupported-operator = "warn"
unused-ignore-comment = "warn"
unused-type-ignore-comment = "warn"
2 changes: 1 addition & 1 deletion scripts/eval/classify_altloc_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def main(args: argparse.Namespace) -> None:
)
)

out_df = pd.DataFrame(all_rows, columns=OUTPUT_COLUMNS)
out_df = pd.DataFrame(all_rows, columns=pd.Index(OUTPUT_COLUMNS))
args.output_file.parent.mkdir(parents=True, exist_ok=True)
out_df.to_csv(args.output_file, index=False)
logger.info(f"Wrote {len(out_df)} classified spans to {args.output_file}")
Expand Down
148 changes: 16 additions & 132 deletions src/sampleworks/eval/generate_synthetic_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,14 @@
import traceback
from dataclasses import dataclass, field
from pathlib import Path
from typing import cast, ClassVar
from typing import ClassVar

import torch
from atomworks.io.transforms.atom_array import remove_waters
from biotite.structure import AtomArray, AtomArrayStack
from joblib import delayed, Parallel
from loguru import logger
from sampleworks.core.forward_models.xray.real_space_density import XMap_torch
from sampleworks.eval.structure_utils import apply_selection
from sampleworks.utils.atom_array_utils import (
AltlocInfo,
detect_altlocs,
keep_amino_acids,
keep_polymer,
load_structure_with_altlocs,
remove_hydrogens,
save_structure_to_cif,
)
from sampleworks.eval.synthetic_utils import load_structure_for_synthetic_reward
from sampleworks.utils.atom_array_utils import save_structure_to_cif
from sampleworks.utils.density_utils import compute_density_from_atomarray
from sampleworks.utils.torch_utils import try_gpu

Expand Down Expand Up @@ -97,79 +87,6 @@ def from_dict(cls, row: dict[str, str]) -> "BatchRow":
)


def assign_occupancies(
    atom_array: AtomArray | AtomArrayStack,
    altloc_info: AltlocInfo,
    mode: str,
    occ_values: list[float] | None = None,
) -> AtomArray | AtomArrayStack:
    """Assign occupancy values to atoms based on their altloc membership.

    The input structure is never modified: when occupancies are changed, the
    change is applied to a copy, which is returned. When nothing has to change
    ('default' mode, or 'uniform' mode on a structure without altlocs) the
    input object itself is returned.

    Parameters
    ----------
    atom_array
        Structure to modify
    altloc_info
        Detected altloc information from detect_altlocs(). Only its
        ``altloc_ids`` and ``atom_masks`` attributes are read here.
    mode
        Assignment mode: 'default' (no change), 'uniform' (1/n_altlocs each),
        or 'custom' (user-specified values)
    occ_values
        For 'custom' mode: list of occupancy values [0.0-1.0] assigned to altlocs
        in sorted order (e.g., [0.3, 0.7] assigns 0.3 to altloc 'A', 0.7 to 'B').
        If fewer values than altlocs are given, the remaining altlocs get
        occupancy 0. If more values than altlocs are given, the extra values
        are ignored (they are still range-checked). A warning is logged in
        both cases.

    Returns
    -------
    AtomArray | AtomArrayStack
        Structure with updated occupancies (a copy), or the unchanged input
        when no assignment is performed

    Raises
    ------
    ValueError
        If ``mode`` is not one of 'default', 'uniform' or 'custom', if 'custom'
        mode is requested but no altlocs exist, if occ_values is None in
        custom mode, or if any occupancy value is outside [0.0, 1.0]
    """
    # Reject unknown modes up front instead of silently returning an
    # unmodified copy, which would hide typos in the caller's configuration.
    if mode not in ("default", "uniform", "custom"):
        raise ValueError(f"Invalid occupancy mode '{mode}'")

    if mode == "default":
        return atom_array

    altloc_ids = altloc_info.altloc_ids
    if not altloc_ids:
        if mode == "custom":
            raise ValueError(
                "Custom occupancy mode was requested, but the structure has no altlocs."
            )
        logger.warning("No altlocs detected, using default occupancies")
        return atom_array

    n_altlocs = len(altloc_ids)

    # Resolve the altloc -> occupancy mapping first, so that every validation
    # error is raised before the (potentially large) structure is copied.
    if mode == "uniform":
        occ_by_altloc = dict.fromkeys(altloc_ids, 1.0 / n_altlocs)
    else:  # mode == "custom"
        if occ_values is None:
            raise ValueError("occ_values required for custom mode")
        for i, v in enumerate(occ_values):
            if not 0.0 <= v <= 1.0:
                raise ValueError(f"Occupancy value {v} at index {i} is out of range [0.0, 1.0]")

        n_values = len(occ_values)
        if n_values < n_altlocs:
            logger.warning(
                f"Expected {n_altlocs} occupancy values, got {n_values}. "
                "The missing values are automatically set to 0."
            )
        elif n_values > n_altlocs:
            # Previously this case logged the "missing values" message even
            # though the surplus values were simply dropped.
            logger.warning(
                f"Expected {n_altlocs} occupancy values, got {n_values}. "
                "The extra values are ignored."
            )

        # Build a new list: truncate surplus values, zero-pad missing ones.
        # The caller's occ_values list is left untouched.
        padded = list(occ_values[:n_altlocs]) + [0.0] * (n_altlocs - n_values)
        occ_by_altloc = dict(zip(sorted(altloc_ids), padded))

    result = atom_array.copy()
    # The assignments below write into the copy's occupancy array in place.
    occupancy = result.occupancy
    for altloc, occ in occ_by_altloc.items():
        occupancy[altloc_info.atom_masks[altloc]] = occ

    # String form keeps the cast accurate for both input kinds (the previous
    # cast claimed AtomArray even for stacks) without any runtime evaluation.
    return cast("AtomArray | AtomArrayStack", result)


def save_density(density: torch.Tensor, xmap_torch: XMap_torch, output_path: Path) -> None:
"""Save a density map to disk in CCP4 format.

Expand Down Expand Up @@ -261,54 +178,18 @@ def _process_single_row(
If True, save the processed structure to a CIF file in the input directory. Default is True.
"""
structure_path = base_dir / row.filename
if not structure_path.exists():
logger.error(f"Structure not found: {structure_path}")
return

try:
atom_array = load_structure_with_altlocs(structure_path)
except Exception as e:
logger.error(
f"Failed to load {row.filename} ({type(e).__name__}): {e}\n"
f"{''.join(traceback.format_tb(e.__traceback__))}"
)
return

try:
atom_array = apply_selection(atom_array, row.selection)
except ValueError as e:
logger.error(f"Selection error for {row.filename}: {e}")
atom_array = load_structure_for_synthetic_reward(
structure_path,
occupancy_mode=occ_mode,
occupancy_values=row.occ_values,
strip_hydrogens=strip_hydrogens,
strip_waters=strip_waters,
strip_ligands=strip_ligands,
selection=row.selection,
)
if atom_array is None:
return

atom_array = remove_hydrogens(atom_array) if strip_hydrogens else atom_array
atom_array = remove_waters(atom_array) if strip_waters else atom_array
# This is currently a sort of hacky way to remove ligands by keeping only polymer atoms
# TODO: there's probably a more robust way to do this
atom_array = keep_polymer(keep_amino_acids(atom_array)) if strip_ligands else atom_array

altloc_info = detect_altlocs(atom_array) # ty: ignore[invalid-argument-type]
if row.occ_values:
if occ_mode != "custom":
logger.warning(
f"Custom occupancy values provided for {row.filename}, "
f"but occ_mode is '{occ_mode}'. Using 'custom' mode."
)
occ_mode = "custom"
try:
atom_array = assign_occupancies(atom_array, altloc_info, "custom", row.occ_values)
except ValueError as e:
logger.error(f"Occupancy assignment error for {row.filename}: {e}")
raise
elif occ_mode in {"uniform", "default"}:
try:
atom_array = assign_occupancies(atom_array, altloc_info, occ_mode)
except ValueError as e:
logger.error(f"Occupancy assignment error for {row.filename}: {e}")
raise
else:
logger.error(f"Invalid occupancy mode '{occ_mode}' for {row.filename}")
raise ValueError(f"Invalid occupancy mode '{occ_mode}'")

try:
density, xmap_torch = compute_density_from_atomarray(
atom_array, resolution=resolution, em_mode=em_mode, device=device
Expand Down Expand Up @@ -395,6 +276,9 @@ def process_batch(
rows = load_batch_csv(csv_path)
logger.info(f"Processing {len(rows)} structures from {csv_path} using {n_jobs} jobs")

# TODO(`#242`): When device is CUDA and n_jobs > 1, each loky worker gets its own
# CUDA context, risking GPU memory contention or OOM errors. Consider explicit
# per-worker device assignment and/or n_jobs capping. Same issue in SF script.
Parallel(n_jobs=n_jobs, backend="loky")(
delayed(_process_single_row)(
row=row,
Expand Down
Loading
Loading