Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions tests/utils/test_msa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Tests for MSAManager and the MSA cache helpers."""

from __future__ import annotations

from pathlib import Path
from unittest.mock import patch

import pytest
from sampleworks.utils import msa as msa_module
from sampleworks.utils.guidance_constants import StructurePredictor
from sampleworks.utils.msa import _compute_msa, MSAManager


# ============================================================================
# _hash_arguments (test 5)
# ============================================================================


def test_hash_arguments_is_deterministic_and_input_sensitive():
    """Check ``MSAManager._hash_arguments`` is repeatable and reacts to its inputs.

    The same sequences and pairing strategy must hash identically on a second
    call, while changing a sequence, the pairing strategy, or the order of the
    mapping's values must each yield a different hash.
    """
    strategy = "greedy"
    data = {"A": "MKTAY", "B": "GGGAA"}
    reference = MSAManager._hash_arguments(data, strategy)

    # An equal but separately built mapping must reproduce the reference hash.
    assert reference == MSAManager._hash_arguments({"A": "MKTAY", "B": "GGGAA"}, strategy)

    differing_inputs = [
        # Different sequence -> different hash.
        ({"A": "MKTAY", "B": "GGGAC"}, "greedy"),
        # Different pairing strategy -> different hash.
        (data, "complete"),
        # Different value order -> different hash (hash is over tuple(values)).
        ({"B": "GGGAA", "A": "MKTAY"}, "greedy"),
    ]
    for variant_data, variant_strategy in differing_inputs:
        assert reference != MSAManager._hash_arguments(variant_data, variant_strategy)


# ============================================================================
# __init__ cache directory handling (test 4)
# ============================================================================


def test_cache_dir_explicit_is_used(tmp_path: Path):
    """An explicit ``msa_cache_dir`` Path is exposed as ``msa_dir`` and exists."""
    cache_dir = MSAManager(msa_cache_dir=tmp_path).msa_dir

    assert cache_dir == tmp_path
    assert cache_dir.exists()


def test_cache_dir_explicit_string_converted_to_path(tmp_path: Path):
    """A string ``msa_cache_dir`` is stored on ``msa_dir`` as an equal ``Path``."""
    cache_dir = MSAManager(msa_cache_dir=str(tmp_path)).msa_dir

    assert isinstance(cache_dir, Path)
    assert cache_dir == tmp_path


def test_cache_dir_default_created_under_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """With ``msa_cache_dir=None`` the cache lands in ``<home>/.sampleworks/msa``.

    ``Path.home`` is redirected to ``tmp_path`` so the test never touches the
    real home directory.
    """

    def fake_home(cls):
        # Stand-in for Path.home(): always report the temporary directory.
        return tmp_path

    monkeypatch.setattr(Path, "home", classmethod(fake_home))
    default_dir = tmp_path / ".sampleworks" / "msa"

    manager = MSAManager(msa_cache_dir=None)

    assert manager.msa_dir == default_dir
    assert default_dir.exists()


# ============================================================================
# get_msa -> _compute_msa wiring (test 1)
# ============================================================================


def _write_valid_cache_files(msa_dir: Path, hash_key: str, sequences: list[str]) -> None:
    """Write a consistent ``<hash_key>_0.csv`` / ``<hash_key>_0.a3m`` pair.

    The two files carry the same sequences in the same order, which is what the
    tests rely on for ``_validate_msa_cache_contents`` to accept the cache.

    Args:
        msa_dir: Existing directory that receives both files.
        hash_key: Prefix used in the file names and in the a3m record headers.
        sequences: Sequences to write, one csv row / one a3m record each.
    """
    # CSV: a header row followed by one "-1,<sequence>" row per sequence.
    csv_rows = ["key,sequence"]
    csv_rows.extend(f"-1,{seq}" for seq in sequences)
    csv_file = msa_dir / f"{hash_key}_0.csv"
    csv_file.write_text("\n".join(csv_rows))

    # A3M: alternating ">header" / sequence lines, terminated by a newline.
    a3m_rows = [
        row
        for idx, seq in enumerate(sequences)
        for row in (f">{hash_key}_0_{idx}", seq)
    ]
    a3m_file = msa_dir / f"{hash_key}_0.a3m"
    a3m_file.write_text("\n".join(a3m_rows) + "\n")


def test_get_msa_calls_compute_when_cache_missing_and_skips_when_present(tmp_path: Path):
    """Exercise ``get_msa`` on a cold cache and then on a warm one.

    ``msa_module._compute_msa`` is replaced by a fake that writes a valid
    csv/a3m pair. The first ``get_msa`` call must invoke it exactly once with
    ``(data, hash, cache_dir)`` as its leading positional arguments and count
    as an API call; the second call must not invoke it and must count as a
    cache hit.
    """
    manager = MSAManager(msa_cache_dir=tmp_path)
    data = {"A": "MKTAY"}
    pairing = "greedy"
    hash_key = MSAManager._hash_arguments(data, pairing)

    def fake_compute(data_arg, target_id, msa_dir, *args, **kwargs):
        # Populate the cache directory the way a real computation would.
        _write_valid_cache_files(msa_dir, target_id, ["MKTAY", "MKTAC"])
        outputs = {}
        for chain_name in data_arg:
            outputs[chain_name] = msa_dir / f"{target_id}_0.csv"
        return outputs

    with patch.object(msa_module, "_compute_msa", side_effect=fake_compute) as mock_compute:
        first_result = manager.get_msa(
            data, pairing, structure_predictor=StructurePredictor.BOLTZ_2
        )

        assert mock_compute.call_count == 1
        # Positional order: data first, then the hash used as target_id, then the dir.
        positional = mock_compute.call_args.args
        assert positional[0] == data
        assert positional[1] == hash_key
        assert positional[2] == tmp_path

        assert first_result == {"A": tmp_path / f"{hash_key}_0.csv"}
        assert manager._api_calls == 1
        assert manager._cache_hits == 0

        # Warm cache: the same request must be served without calling compute.
        mock_compute.reset_mock()
        manager.get_msa(data, pairing, structure_predictor=StructurePredictor.BOLTZ_2)
        mock_compute.assert_not_called()
        assert manager._cache_hits == 1


# ============================================================================
# _compute_msa argument forwarding to run_mmseqs2 (test 2)
# ============================================================================


def test_compute_msa_calls_run_mmseqs2_with_correct_arguments(tmp_path: Path):
    """Verify what ``_compute_msa`` forwards to ``run_mmseqs2`` for two chains.

    ``run_mmseqs2`` is mocked. With two input sequences the test expects two
    calls, paired first and unpaired second, which differ only in their
    temporary output prefix and ``use_pairing`` flag and share the same
    sequences, server URL, pairing strategy, ``use_env`` flag and auth headers.
    """
    data = {"A": "AAA", "B": "BBB"}
    target_id = "tgt"
    fake_return = [">q\nAAA\n", ">q\nBBB\n"]

    with patch.object(msa_module, "run_mmseqs2", return_value=fake_return) as mock_run:
        _compute_msa(
            data=data,
            target_id=target_id,
            msa_dir=tmp_path,
            msa_server_url="https://example.org/api",
            msa_pairing_strategy="greedy",
            api_key_header="X-Custom-Key",
            api_key_value="secret",
        )

    # With len(data) == 2, both the paired and unpaired branches should fire.
    assert mock_run.call_count == 2

    paired_call, unpaired_call = mock_run.call_args_list

    # (recorded call, expected temp output prefix, expected use_pairing flag)
    expectations = [
        (paired_call, str(tmp_path / f"{target_id}_paired_tmp"), True),
        (unpaired_call, str(tmp_path / f"{target_id}_unpaired_tmp"), False),
    ]
    for recorded, expected_prefix, expected_pairing in expectations:
        forwarded = recorded.kwargs
        assert recorded.args[0] == ["AAA", "BBB"]
        assert recorded.args[1] == expected_prefix
        assert forwarded["use_pairing"] is expected_pairing
        # The env/auth settings must be identical for both calls.
        assert forwarded["use_env"] is True
        assert forwarded["host_url"] == "https://example.org/api"
        assert forwarded["pairing_strategy"] == "greedy"
        assert forwarded["auth_headers"] == {
            "Content-Type": "application/json",
            "X-Custom-Key": "secret",
        }


# ============================================================================
# _compute_msa file output consistency (test 3)
# ============================================================================


def test_compute_msa_writes_matching_csv_and_a3m(tmp_path: Path):
    """Check ``_compute_msa`` writes csv and a3m files holding the same sequences.

    ``run_mmseqs2`` is mocked to return one FASTA-style record set for a single
    chain. Afterwards ``<target_id>_0.csv`` and ``<target_id>_0.a3m`` must both
    exist and list the query plus both hits, in order.
    """
    data = {"chainA": "MKTAY"}
    target_id = "thehash"
    expected_sequences = ["MKTAY", "MKTAC", "MKTAG"]
    # FASTA-style: alternating header / sequence lines. _compute_msa parses [1::2].
    fake_return = [">query\nMKTAY\n>hit1\nMKTAC\n>hit2\nMKTAG\n"]

    with patch.object(msa_module, "run_mmseqs2", return_value=fake_return):
        _compute_msa(
            data=data,
            target_id=target_id,
            msa_dir=tmp_path,
            msa_server_url="https://example.org/api",
            msa_pairing_strategy="greedy",
        )

    csv_path = tmp_path / f"{target_id}_0.csv"
    a3m_path = tmp_path / f"{target_id}_0.a3m"
    assert csv_path.exists()
    assert a3m_path.exists()

    # CSV: header row, then "<key>,<sequence>" rows; keep only the sequence column.
    csv_rows = csv_path.read_text().splitlines()
    assert csv_rows[0] == "key,sequence"
    csv_sequences = [row.split(",", 1)[1] for row in csv_rows[1:]]

    # A3M: sequences sit on every second line, right after their ">" header.
    a3m_sequences = a3m_path.read_text().splitlines()[1::2]

    assert csv_sequences == a3m_sequences
    assert a3m_sequences == expected_sequences
Loading