diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..917eb400 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,96 @@ +name: CI + +on: + push: + branches: [main] + paths: + - 'src/**' + - 'tests/**' + - 'pyproject.toml' + - 'pixi.lock' + - '.github/workflows/ci.yml' + - '.pre-commit-config.yaml' + pull_request: + branches: [main] + paths: + - 'src/**' + - 'tests/**' + - 'pyproject.toml' + - 'pixi.lock' + - '.github/workflows/ci.yml' + - '.pre-commit-config.yaml' + workflow_dispatch: + +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + timeout-minutes: 10 + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install pixi + uses: prefix-dev/setup-pixi@v0.8.8 + with: + environments: boltz-dev + + - name: Ruff lint + run: pixi run -e boltz-dev ruff check . + + - name: Ruff format check + run: pixi run -e boltz-dev ruff format --check . + + typecheck: + runs-on: ubuntu-latest + timeout-minutes: 15 + permissions: + contents: read + strategy: + fail-fast: false + matrix: + environment: [boltz-dev, protenix-dev, rf3-dev] + + name: typecheck (${{ matrix.environment }}) + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install pixi + uses: prefix-dev/setup-pixi@v0.8.8 + with: + environments: ${{ matrix.environment }} + + - name: Run ty + run: pixi run -e ${{ matrix.environment }} ty check + + cpu-tests: + runs-on: ubuntu-latest + timeout-minutes: 20 + permissions: + contents: read + strategy: + fail-fast: false + matrix: + environment: [boltz-dev, protenix-dev, rf3-dev] + + name: tests (${{ matrix.environment }}) + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install pixi + uses: prefix-dev/setup-pixi@v0.8.8 + with: + environments: ${{ matrix.environment }} + + - name: Run CPU tests + run: pixi run -e ${{ matrix.environment }} cpu-tests diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml index 9533796b..85c0d740 100644 --- a/.github/workflows/gpu-tests.yml +++ b/.github/workflows/gpu-tests.yml @@ -43,6 +43,8 @@ jobs: - name: Install pixi uses: prefix-dev/setup-pixi@19eac09b398e3d0c747adc7921926a6d802df4da # v0.8.8 + with: + cache: false # NFS-backed cache on self-hosted runner handles this - name: Build CUDA extensions run: pixi run -e ${{ matrix.environment }} python3 -c "from sampleworks.core.forward_models.xray.real_space_density_deps.ops.csrc import dilate_points_cuda" diff --git a/pixi.lock b/pixi.lock index 12ed24ce..40b60194 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9719,8 +9719,8 @@ packages: timestamp: 1753407970803 - pypi: ./ name: sampleworks - version: 0.4.0 - sha256: 5db03aab50df2b70618c97837dfeb5cda94af7877f98bbc922f7438fafc86e77 + version: 0.4.1 + sha256: a9fd317c84677c0bc7f17597cca9d82a6d153d11c296a58fa5a4b1f7b31dc11b requires_dist: - atomworks[ml]==2.1.1 - python-dotenv diff --git a/pyproject.toml b/pyproject.toml index cbe25112..a7b47a7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,4 +177,19 @@ include = ["src/sampleworks/eval/bond_angle_and_length_outlier_eval_script.py"] possibly-missing-attribute = "ignore" [tool.ty.rules] +# Pre-existing type issues across the codebase; warn instead of error +# so ty runs in CI without blocking PRs while the team fixes them. unresolved-import = "ignore" +unknown-argument = "warn" +unresolved-attribute = "warn" +invalid-argument-type = "warn" +invalid-assignment = "warn" +invalid-method-override = "warn" +invalid-parameter-default = "warn" +no-matching-overload = "warn" +not-iterable = "warn" +not-subscriptable = "warn" +too-many-positional-arguments = "warn" +unsupported-operator = "warn" +unused-ignore-comment = "warn" +unused-type-ignore-comment = "warn" diff --git a/scripts/eval/bond_geometry_eval.py b/scripts/eval/bond_geometry_eval.py index d4984361..982d0457 100644 --- a/scripts/eval/bond_geometry_eval.py +++ b/scripts/eval/bond_geometry_eval.py @@ -40,7 +40,7 @@ def bond_length_violations(pose: AtomArray, tolerance: float = 0.1) -> tuple[flo """ try: bounds = check_pose_and_get_bounds(pose) - except (ValueError, BadStructureError) as e: + except (ValueError, BadStructureError): return np.nan, pd.DataFrame() bond_indices = np.sort(pose.bonds.as_array()[:, :2], axis=1) @@ -97,13 +97,12 @@ def check_pose_and_get_bounds(pose: AtomArray): "`biotite.structure.io.pdbx.get_structure(..., include_bonds=True)`" ) raise ValueError("The structure does not have bonds.") - + # this fetches values from RDKit, raises BadStructureError if the structure is bad bounds = get_distance_bounds(pose) return bounds - def bond_angle_violations(pose: AtomArray, tolerance: float = 0.1) -> tuple[float, pd.DataFrame]: """ Calculate the percentage of bonds that are outside acceptable ranges. diff --git a/scripts/eval/run_and_process_phenix_clashscore.py b/scripts/eval/run_and_process_phenix_clashscore.py index 72c90c1b..a294d0f9 100644 --- a/scripts/eval/run_and_process_phenix_clashscore.py +++ b/scripts/eval/run_and_process_phenix_clashscore.py @@ -37,9 +37,7 @@ def main(args) -> None: return clashscore_df = pd.concat(clashscore_metrics, ignore_index=True) - clashscore_df.to_csv( - args.grid_search_results_path / "clashscore_metrics.csv", index=False - ) + clashscore_df.to_csv(args.grid_search_results_path / "clashscore_metrics.csv", index=False) def process_one_trial(trial: Trial) -> pd.DataFrame: diff --git a/scripts/eval/run_and_process_tortoize.py b/scripts/eval/run_and_process_tortoize.py index 1fbbb849..0b8aa678 100644 --- a/scripts/eval/run_and_process_tortoize.py +++ b/scripts/eval/run_and_process_tortoize.py @@ -8,7 +8,6 @@ import pandas as pd from loguru import logger from pandas import DataFrame - from sampleworks.eval.grid_search_eval_utils import parse_eval_args, setup_evaluation_parameters @@ -27,9 +26,7 @@ def main(args: argparse.Namespace) -> None: try: subprocess.call("tortoize", stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) except FileNotFoundError: - raise RuntimeError( - "tortoize is not available, make sure you have installed it." - ) from None + raise RuntimeError("tortoize is not available, make sure you have installed it.") from None # The dropped variable is a list of ProteinConfigs, not used yet in this script all_trials, _ = setup_evaluation_parameters(args) @@ -122,13 +119,15 @@ def get_protein_level_z_scores(tortoize_json: dict[str, Any]) -> pd.DataFrame: out: list[dict[str, Any]] = [] model_block = tortoize_json.get("model", {}) for model_id, model_data in model_block.items(): - out.append({ - "model": str(model_id), - "ramachandran_z_score": model_data.get("ramachandran-z", None), - "ramachandran_jackknife_sd": model_data.get("ramachandran-jackknife-sd", None), - "torsion_z_score": model_data.get("torsion-z", None), - "torsion_jackknife_sd": model_data.get("torsion-jackknife-sd", None) - }) + out.append( + { + "model": str(model_id), + "ramachandran_z_score": model_data.get("ramachandran-z", None), + "ramachandran_jackknife_sd": model_data.get("ramachandran-jackknife-sd", None), + "torsion_z_score": model_data.get("torsion-z", None), + "torsion_jackknife_sd": model_data.get("torsion-jackknife-sd", None), + } + ) return pd.DataFrame(out) diff --git a/src/sampleworks/core/forward_models/xray/real_space_density_deps/ops/csrc/__init__.py b/src/sampleworks/core/forward_models/xray/real_space_density_deps/ops/csrc/__init__.py index c1eccb2f..43e26f69 100644 --- a/src/sampleworks/core/forward_models/xray/real_space_density_deps/ops/csrc/__init__.py +++ b/src/sampleworks/core/forward_models/xray/real_space_density_deps/ops/csrc/__init__.py @@ -45,4 +45,5 @@ def _ensure_toolchain_env() -> None: CUDA_AVAILABLE = True except Exception as e: print(f"CUDA extension loading failed: {e}") + dilate_points_cuda = None CUDA_AVAILABLE = False diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index 6bed7059..aa294e50 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -11,7 +11,7 @@ from loguru import logger from sampleworks.eval.constants import OCCUPANCY_LEVELS -from sampleworks.eval.eval_dataclasses import Trial, TrialList, ProteinConfig +from sampleworks.eval.eval_dataclasses import ProteinConfig, Trial, TrialList from sampleworks.eval.occupancy_utils import extract_protein_and_occupancy from sampleworks.utils.guidance_constants import StructurePredictor @@ -175,7 +175,7 @@ def parse_eval_args(description: str | None = None): type=Path, required=True, help="Path to the top-level grid search results directory, usu. called " - "``grid_search_results``", + "``grid_search_results``", ) # not technically used everywhere yet, but requiring it future-proofs. parser.add_argument( @@ -183,14 +183,14 @@ def parse_eval_args(description: str | None = None): type=Path, required=True, help="Path to the directory containing the grid search inputs, in particular " - "the protein configuration CSV file, maps, and reference structures.", + "the protein configuration CSV file, maps, and reference structures.", default=None, ) parser.add_argument( "--protein-configs-csv", type=Path, help="Path to the CSV file containing protein configurations, like " - "``${HOME}/configs.csv``. Defaults to sampleworks/data/protein_configs.csv", + "``${HOME}/configs.csv``. Defaults to sampleworks/data/protein_configs.csv", default=files("sampleworks.data") / "protein_configs.csv", ) parser.add_argument( @@ -215,7 +215,7 @@ def parse_eval_args(description: str | None = None): def setup_evaluation_parameters( - args: argparse.Namespace + args: argparse.Namespace, ) -> tuple[TrialList, dict[str, ProteinConfig]]: grid_search_dir = Path(args.grid_search_results_path) @@ -227,9 +227,7 @@ def setup_evaluation_parameters( logger.info(f"Proteins configured: {list(protein_configs.keys())}") # Scan for experiments (look for refined.cif files) - all_trials = scan_grid_search_results( - grid_search_dir, target_filename=args.target_filename - ) + all_trials = scan_grid_search_results(grid_search_dir, target_filename=args.target_filename) logger.info(f"Found {len(all_trials)} experiments with refined.cif files") if all_trials: diff --git a/src/sampleworks/utils/msa.py b/src/sampleworks/utils/msa.py index 42f7dfbe..e5b948ec 100644 --- a/src/sampleworks/utils/msa.py +++ b/src/sampleworks/utils/msa.py @@ -52,8 +52,8 @@ def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None: raise FileNotFoundError(f"No A3M files found for hash {msa_hash} in {msa_dir}") # Validate that we have matching pairs - csv_indices = {int(f.stem.split('_')[-1]) for f in csv_files} - a3m_indices = {int(f.stem.split('_')[-1]) for f in a3m_files} + csv_indices = {int(f.stem.split("_")[-1]) for f in csv_files} + a3m_indices = {int(f.stem.split("_")[-1]) for f in a3m_files} if csv_indices != a3m_indices: raise ValueError( @@ -67,16 +67,16 @@ def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None: a3m_path = msa_dir / f"{msa_hash}_{idx}.a3m" # Read CSV sequences (skip header, take second column) - with csv_path.open('r') as f: + with csv_path.open("r") as f: csv_lines = f.readlines() if not csv_lines or csv_lines[0].strip() != "key,sequence": raise ValueError(f"Invalid CSV header in {csv_path}") - csv_sequences = [line.strip().split(',', 1)[1] for line in csv_lines[1:] if line.strip()] + csv_sequences = [line.strip().split(",", 1)[1] for line in csv_lines[1:] if line.strip()] # Read A3M sequences (every other line, skipping headers) - with a3m_path.open('r') as f: + with a3m_path.open("r") as f: a3m_lines = f.readlines() # A3M format: header lines start with '>', sequences on alternating lines diff --git a/tests/eval/test_structure_utils.py b/tests/eval/test_structure_utils.py index 4f43bf7b..4f0de785 100644 --- a/tests/eval/test_structure_utils.py +++ b/tests/eval/test_structure_utils.py @@ -23,7 +23,9 @@ def mock_protein_config(tmp_path: Path) -> ProteinConfig: return ProteinConfig( protein="test", base_map_dir=tmp_path, - selection=["chain A and resi 1-10", ], + selection=[ + "chain A and resi 1-10", + ], resolution=2.0, map_pattern="{occ_str}.ccp4", structure_pattern="{occ_str}.cif", @@ -256,7 +258,9 @@ def test_converts_atomarray_to_stack(self, tmp_path, basic_atom_array_multichain config = ProteinConfig( protein="test", base_map_dir=tmp_path, - selection=["chain A", ], + selection=[ + "chain A", + ], resolution=2.0, map_pattern="{occ_str}.ccp4", structure_pattern="{occ_str}.cif", @@ -272,7 +276,9 @@ def test_with_real_structure(self, resources_dir): config = ProteinConfig( protein="6b8x", base_map_dir=resources_dir / "6b8x", - selection=["chain A", ], + selection=[ + "chain A", + ], resolution=1.74, map_pattern="{occ_str}.ccp4", structure_pattern="6b8x_final.pdb", @@ -299,7 +305,9 @@ def test_handles_exceptions_gracefully(self, tmp_path): config = ProteinConfig( protein="test", base_map_dir=tmp_path, - selection=["chain Z and resi 999", ], + selection=[ + "chain Z and resi 999", + ], resolution=2.0, map_pattern="{occ_str}.ccp4", structure_pattern="{occ_str}.cif", @@ -314,7 +322,9 @@ def test_with_real_structure(self, resources_dir): config = ProteinConfig( protein="6b8x", base_map_dir=resources_dir / "6b8x", - selection=[selection_string, ], + selection=[ + selection_string, + ], resolution=1.74, map_pattern="{occ_str}.ccp4", structure_pattern="6b8x_final.pdb", diff --git a/tests/models/protenix/test_ccd_expansion.py b/tests/models/protenix/test_ccd_expansion.py index c64e84c6..328ed065 100644 --- a/tests/models/protenix/test_ccd_expansion.py +++ b/tests/models/protenix/test_ccd_expansion.py @@ -20,7 +20,11 @@ class TestExpandTildeCCDCode: def test_unique_match_expands(self): """~QS should expand uniquely to A1AQS.""" - result = _expand_tilde_ccd_code("~QS") + fake_codes = ["A1AQS", "GLY", "ALA"] + _build_ccd_suffix_map.cache_clear() + with patch("protenix.data.ccd.get_all_ccd_code", return_value=fake_codes): + result = _expand_tilde_ccd_code("~QS") + _build_ccd_suffix_map.cache_clear() assert result == "A1AQS" def test_ambiguous_match_raises(self): @@ -37,7 +41,11 @@ def test_ambiguous_match_raises(self): def test_no_match_returns_original(self): """When no code matches the suffix, return the truncated code.""" - result = _expand_tilde_ccd_code("~$$") + fake_codes = ["GLY", "ALA"] + _build_ccd_suffix_map.cache_clear() + with patch("protenix.data.ccd.get_all_ccd_code", return_value=fake_codes): + result = _expand_tilde_ccd_code("~$$") + _build_ccd_suffix_map.cache_clear() assert result == "~$$" @@ -46,7 +54,11 @@ class TestStructureToProtenixJsonCCDExpansion: def test_9bn8_ligand_expanded(self, structure_9bn8): """9BN8 structure with ~QS ligand should produce CCD_A1AQS in JSON.""" - json_dict = structure_to_protenix_json(structure_9bn8) + _build_ccd_suffix_map.cache_clear() + fake_codes = ["A1AQS", "GLY", "ALA"] + with patch("protenix.data.ccd.get_all_ccd_code", return_value=fake_codes): + json_dict = structure_to_protenix_json(structure_9bn8) + _build_ccd_suffix_map.cache_clear() ligand_entries = [ entry["ligand"]["ligand"] for entry in json_dict["sequences"] if "ligand" in entry diff --git a/tests/utils/test_atom_array_utils.py b/tests/utils/test_atom_array_utils.py index 1e36aed3..76f0f2b1 100644 --- a/tests/utils/test_atom_array_utils.py +++ b/tests/utils/test_atom_array_utils.py @@ -930,7 +930,7 @@ def test_empty_atom_array(self): with pytest.raises(ValueError, match="Cannot remove atoms from empty AtomArray\|Stack"): remove_atoms_with_any_nan_coords(atom_array) - + def test_empty_atom_array_stack(self): """Test with empty AtomArrayStack.""" atom_array = AtomArray(0) diff --git a/tests/utils/test_guidance_script_arguments.py b/tests/utils/test_guidance_script_arguments.py index f7e954ce..c8b301b2 100644 --- a/tests/utils/test_guidance_script_arguments.py +++ b/tests/utils/test_guidance_script_arguments.py @@ -4,6 +4,7 @@ from argparse import Namespace from pathlib import Path +from unittest.mock import patch import pytest from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor @@ -63,8 +64,12 @@ def _build_job(model: StructurePredictor) -> JobConfig: ) -def test_populate_config_preserves_default_checkpoint_when_none_provided(model_wrapper_type): - """populate_config_for_guidance_type should keep model defaults if no checkpoint arg exists.""" +@patch( + "sampleworks.utils.guidance_script_arguments._resolve_checkpoint", + return_value="/checkpoints/mock.ckpt", +) +def test_populate_config_resolves_checkpoint_when_none_provided(_mock_resolve, model_wrapper_type): + """populate_config_for_guidance_type should auto-resolve checkpoint if no arg exists.""" config = GuidanceConfig( protein="protein", structure="/tmp/structure.cif", @@ -73,29 +78,38 @@ def test_populate_config_preserves_default_checkpoint_when_none_provided(model_w guidance_type=GuidanceType.PURE_GUIDANCE, log_path="/tmp/output/run.log", ) - default_checkpoint = config.model_checkpoint config.populate_config_for_guidance_type( _build_job(model_wrapper_type), - Namespace(use_tweedie=False), + Namespace(use_tweedie=False, step_scaler_type="noisespace"), ) - assert config.model_checkpoint == default_checkpoint + assert config.model_checkpoint == "/checkpoints/mock.ckpt" def test_populate_config_uses_model_checkpoint_argument(model_wrapper_type): """populate_config_for_guidance_type should read the model_checkpoint arg.""" - config = GuidanceConfig( - protein="protein", - structure="/tmp/structure.cif", - density="/tmp/density.mrc", - model=model_wrapper_type, - guidance_type=GuidanceType.PURE_GUIDANCE, - log_path="/tmp/output/run.log", - ) - - args = Namespace(model_checkpoint="/tmp/custom.ckpt", use_tweedie=False) - config.populate_config_for_guidance_type(_build_job(model_wrapper_type), args) + with patch( + "sampleworks.utils.guidance_script_arguments._resolve_checkpoint", + return_value="/checkpoints/mock.ckpt", + ) as mock_resolve: + config = GuidanceConfig( + protein="protein", + structure="/tmp/structure.cif", + density="/tmp/density.mrc", + model=model_wrapper_type, + guidance_type=GuidanceType.PURE_GUIDANCE, + log_path="/tmp/output/run.log", + ) + mock_resolve.reset_mock() + + args = Namespace( + model_checkpoint="/tmp/custom.ckpt", + use_tweedie=False, + step_scaler_type="noisespace", + ) + config.populate_config_for_guidance_type(_build_job(model_wrapper_type), args) + mock_resolve.assert_not_called() assert config.model_checkpoint == "/tmp/custom.ckpt" @@ -105,9 +119,16 @@ def test_populate_config_uses_model_checkpoint_argument(model_wrapper_type): # ============================================================================ -def test_validate_model_checkpoint_requires_non_empty_value(model_wrapper_type): - """Validation should fail fast when checkpoint is missing.""" - with pytest.raises(ValueError, match="Missing checkpoint"): +@patch( + "sampleworks.utils.guidance_script_arguments._resolve_checkpoint", + side_effect=ValueError( + "Running guidance requires a model checkpoint for 'model'. " + "Provide --model-checkpoint or bake checkpoints into /checkpoints/." + ), +) +def test_validate_model_checkpoint_requires_non_empty_value(_mock_resolve, model_wrapper_type): + """Validation should fail fast when checkpoint is missing and can't be auto-resolved.""" + with pytest.raises(ValueError, match="requires a model checkpoint"): validate_model_checkpoint(model_wrapper_type, "")