6 changes: 6 additions & 0 deletions src/sampleworks/utils/guidance_script_arguments.py
@@ -188,6 +188,12 @@ def populate_config_for_guidance_type(self, job: JobConfig, args: argparse.Names
self.step_scaler_type = args.step_scaler_type
self.ensemble_size = job.ensemble_size

def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the guidance config, converting Path to strings"""
output = self.__dict__.copy()
output["density"] = str(self.density)
output["structure"] = str(self.structure)
return output

def add_generic_args(parser: argparse.ArgumentParser | GuidanceConfig):
parser.add_argument("--structure", type=str, required=True, help="Input structure")
3 changes: 2 additions & 1 deletion src/sampleworks/utils/guidance_script_utils.py
@@ -594,7 +594,8 @@ def run_guidance_job_queue(job_queue_path: str) -> list[JobResult]:
job_result = run_guidance(job, job.guidance_type, model_wrapper, device)
# write out the job parameters to a JSON file in the same directory as the refined.cif file
with open(Path(job_result.output_dir) / "job_metadata.json", "w") as fp:
json.dump(job.__dict__, fp)
# use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects
json.dump(job.as_dict(), fp)

job_results.append(job_result)
torch.cuda.empty_cache() # just in case
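For context on the change above: `json.dump` cannot serialize `pathlib.Path` values, so dumping `job.__dict__` fails once fields like `density` and `structure` hold `Path` objects, which is what the new `as_dict()` helper works around. A minimal sketch of the failure mode and the string-conversion fix (the field names here are illustrative, not the real `GuidanceConfig`):

import json
from pathlib import Path

# Illustrative stand-in for a config whose fields include Path values.
config = {"structure": Path("inputs/target.cif"), "ensemble_size": 8}

try:
    json.dumps(config)
except TypeError as err:
    # json has no encoder for PosixPath/WindowsPath objects
    print(f"direct dump fails: {err}")

# Converting Path fields to strings first, as as_dict() does, lets the dump succeed.
serializable = {k: str(v) if isinstance(v, Path) else v for k, v in config.items()}
print(json.dumps(serializable))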
88 changes: 86 additions & 2 deletions src/sampleworks/utils/msa.py
@@ -20,6 +20,85 @@
MAX_MSA_SEQS = 16384


def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None:
"""Validate the contents of the MSA cache.
For each hash, there should be a set of files `"{hash}_{n}.csv", "{hash}_{n}.a3m"`
where `n` is the index of the sequence in the input data, 0, 1, ..., len(sequences) - 1.
In most cases there is only one sequence, so `n = 0` only. The .a3m files should have a sequence
on every other line. The .csv files have two columns. The sequences in the second column of the
.csv file should be the same as the sequences in the a3m file.

Parameters
----------
msa_hash : str
The hash key for the MSA files.
msa_dir : Path
The directory containing the MSA cache files.

Raises
------
FileNotFoundError
If expected files are missing.
ValueError
If file contents don't match or are invalid.
"""
# Find all files matching the hash pattern
csv_files = sorted(msa_dir.glob(f"{msa_hash}_*.csv"))
a3m_files = sorted(msa_dir.glob(f"{msa_hash}_*.a3m"))

if not csv_files:
raise FileNotFoundError(f"No CSV files found for hash {msa_hash} in {msa_dir}")
if not a3m_files:
raise FileNotFoundError(f"No A3M files found for hash {msa_hash} in {msa_dir}")

# Validate that we have matching pairs
csv_indices = {int(f.stem.split('_')[-1]) for f in csv_files}
a3m_indices = {int(f.stem.split('_')[-1]) for f in a3m_files}

if csv_indices != a3m_indices:
raise ValueError(
f"Mismatch between CSV and A3M file indices. "
f"CSV: {sorted(csv_indices)}, A3M: {sorted(a3m_indices)}"
)

# Validate each pair of files
for idx in sorted(csv_indices):
csv_path = msa_dir / f"{msa_hash}_{idx}.csv"
a3m_path = msa_dir / f"{msa_hash}_{idx}.a3m"

# Read CSV sequences (skip header, take second column)
with csv_path.open('r') as f:
csv_lines = f.readlines()

if not csv_lines or csv_lines[0].strip() != "key,sequence":
raise ValueError(f"Invalid CSV header in {csv_path}")

csv_sequences = [line.strip().split(',', 1)[1] for line in csv_lines[1:] if line.strip()]

# Read A3M sequences (every other line, skipping headers)
with a3m_path.open('r') as f:
a3m_lines = f.readlines()

# A3M format: header lines start with '>', sequences on alternating lines
a3m_sequences = [line.strip() for i, line in enumerate(a3m_lines) if i % 2 == 1]

# Validate that sequences match
if len(csv_sequences) != len(a3m_sequences):
raise ValueError(
f"Sequence count mismatch for index {idx}: "
f"CSV has {len(csv_sequences)}, A3M has {len(a3m_sequences)}"
)

for seq_idx, (csv_seq, a3m_seq) in enumerate(zip(csv_sequences, a3m_sequences)):
if csv_seq != a3m_seq:
raise ValueError(
f"Sequence mismatch at index {idx}, sequence {seq_idx}: "
f"CSV and A3M sequences differ"
)

logger.debug(f"Validated {len(csv_sequences)} sequences for hash {msa_hash}, index {idx}")

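To make the expected cache layout concrete, here is a hedged sketch (hypothetical hash key and sequences, assuming `_validate_msa_cache_contents` is in scope) that writes one matching `{hash}_0.csv` / `{hash}_0.a3m` pair into a temporary directory and runs the validator over it:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    msa_dir = Path(tmp)
    msa_hash = "deadbeef"  # hypothetical hash key

    # CSV side: "key,sequence" header, one row per MSA entry.
    (msa_dir / f"{msa_hash}_0.csv").write_text("key,sequence\n0,MKTAYIAKQR\n1,MKTA-IAKQR\n")
    # A3M side: '>' header line followed by its sequence, alternating.
    (msa_dir / f"{msa_hash}_0.a3m").write_text(">query\nMKTAYIAKQR\n>hit_1\nMKTA-IAKQR\n")

    # Passes silently; a missing file raises FileNotFoundError, a mismatch raises ValueError.
    _validate_msa_cache_contents(msa_hash, msa_dir)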

# From https://github.com/jwohlwend/boltz/blob/main/src/boltz/main.py#L415
# FIXME: this function is a big, long mess. We should clean it up.
# For the love of decent code, don't copy this and use it somewhere else and respect the
@@ -182,6 +261,7 @@ def _compute_msa(
etc...
"""

# TODO I don't think these are actually used anywhere. We can probably remove them.
msa_idx_idr = msa_dir / f"{idx}"
msa_idx_idr.mkdir(exist_ok=True, parents=True)
logger.info(f"Writing MSA for target {target_id} sequence {idx} to {msa_idx_idr.resolve()}")
@@ -270,7 +350,9 @@ def get_msa(

if not all([m.exists() for m in msa_path_dict.values()]):
# this will generate both a3m and csv files for us.
msa_path_dict = _compute_msa(
# do NOT capture this output, it will return paths to .csv files,
# which we don't necessarily want.
_ = _compute_msa(

Collaborator:
If we don't want to capture the output, should this be void? I think this is the only place it is used currently. Either that or we return a dict of paths to all the generated files, something like that.

Collaborator Author:
If you are okay with it, I want to leave that for a bigger cleanup of that method.

data,
hash_key, # this is the "target_id" argument to compute_msa
self.msa_dir,
@@ -285,9 +367,11 @@
else:
self._cache_hits += 1

_validate_msa_cache_contents(hash_key, self.msa_dir)

Contributor:
⚠️ Potential issue | 🟠 Major

Avoid full content validation on every cache hit.

Line 370 forces full CSV/A3M read+compare even on cache hits, which can turn the fast path into expensive I/O for large MSAs. Run deep validation only after recompute (or behind a debug/integrity flag).

Suggested patch
-            if not all([m.exists() for m in msa_path_dict.values()]):
+            if not all(m.exists() for m in msa_path_dict.values()):
                 # this will generate both a3m and csv files for us.
                 # do NOT capture this output, it will return paths to .csv files,
                 # which we don't necessarily want.
                 _ = _compute_msa(
                     data,
                     hash_key,  # this is the "target_id" argument to compute_msa
                     self.msa_dir,
                     self.msa_server_url,
                     msa_pairing_strategy,
                     msa_server_username=self.msa_server_username,
                     msa_server_password=self.msa_server_password,
                     api_key_header=self.api_key_header,
                     api_key_value=self.api_key_value,
                 )
                 self._api_calls += 1
+                _validate_msa_cache_contents(hash_key, self.msa_dir)
             else:
                 self._cache_hits += 1
-
-            _validate_msa_cache_contents(hash_key, self.msa_dir)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/sampleworks/utils/msa.py` at line 370, The call to
_validate_msa_cache_contents(hash_key, self.msa_dir) forces expensive
full-CSV/A3M validation on every cache hit; change the logic in the function (or
the caller that uses hash_key and self.msa_dir) so deep validation runs only
when recomputing or when an explicit debug/integrity flag is set (e.g., add a
parameter like perform_deep_validation=False or check self.debug_integrity
before invoking _validate_msa_cache_contents), and ensure the fast path for
cache hits skips the full read/compare while preserving an option to trigger the
heavy validation during recompute or when the integrity flag is true.
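One way to act on this suggestion, sketched here under assumptions rather than as the merged fix (the `validate_cache` flag and `_expected_msa_paths` helper are hypothetical), is to validate unconditionally only after a recompute and make the cache-hit check opt-in:

def get_msa(self, data, hash_key, msa_pairing_strategy, validate_cache: bool = False):
    # Hypothetical helper returning the expected {hash}_{n}.csv/.a3m paths.
    msa_path_dict = self._expected_msa_paths(hash_key)

    if not all(m.exists() for m in msa_path_dict.values()):
        _ = _compute_msa(data, hash_key, self.msa_dir, self.msa_server_url, msa_pairing_strategy)
        self._api_calls += 1
        # Freshly written files are always checked once.
        _validate_msa_cache_contents(hash_key, self.msa_dir)
    else:
        self._cache_hits += 1
        # On cache hits, skip the expensive read/compare unless explicitly requested.
        if validate_cache:
            _validate_msa_cache_contents(hash_key, self.msa_dir)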


# Protenix needs special MSAs
# This is a kind of hacky way to handle Protenix; I'm not sure what to do easily but
# use their pipeline, and just have it put everything in our cache directory.
# use their pipeline and just have it put everything in our cache directory.
elif structure_predictor == StructurePredictor.PROTENIX:
if not PROTENIX_AVAILABLE:
raise RuntimeError("Protenix is not installed, cannot use Protenix MSA tools.")