-
Notifications
You must be signed in to change notification settings - Fork 4
Fix path serialization bug and first-run RF3 MSA bug #176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6b8884a
e9b9930
c4bc7b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,85 @@ | |
| MAX_MSA_SEQS = 16384 | ||
|
|
||
|
|
||
def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None:
    """Validate the contents of the MSA cache.

    For each hash, there should be a set of files ``"{hash}_{n}.csv", "{hash}_{n}.a3m"``
    where ``n`` is the index of the sequence in the input data, 0, 1, ..., len(sequences) - 1.
    In most cases there is only one sequence, so ``n = 0`` only. The .a3m files should have a sequence
    on every other line. The .csv files have two columns. The sequences in the second column of the
    .csv file should be the same as the sequences in the a3m file.

    Parameters
    ----------
    msa_hash : str
        The hash key for the MSA files.
    msa_dir : Path
        The directory containing the MSA cache files.

    Raises
    ------
    FileNotFoundError
        If expected files are missing.
    ValueError
        If file contents don't match or are invalid.
    """

    def _sequence_index(path: Path) -> int:
        # The glob below can match files whose trailing "_{suffix}" is not a
        # numeric sequence index; fail with a clear message instead of an
        # uninformative "invalid literal for int()" from a bare int() call.
        suffix = path.stem.split("_")[-1]
        if not suffix.isdigit():
            raise ValueError(
                f"Unexpected MSA cache file {path}: suffix {suffix!r} is not a sequence index"
            )
        return int(suffix)

    # Find all files matching the hash pattern
    csv_files = sorted(msa_dir.glob(f"{msa_hash}_*.csv"))
    a3m_files = sorted(msa_dir.glob(f"{msa_hash}_*.a3m"))

    if not csv_files:
        raise FileNotFoundError(f"No CSV files found for hash {msa_hash} in {msa_dir}")
    if not a3m_files:
        raise FileNotFoundError(f"No A3M files found for hash {msa_hash} in {msa_dir}")

    # Validate that we have matching pairs
    csv_indices = {_sequence_index(f) for f in csv_files}
    a3m_indices = {_sequence_index(f) for f in a3m_files}

    if csv_indices != a3m_indices:
        raise ValueError(
            f"Mismatch between CSV and A3M file indices. "
            f"CSV: {sorted(csv_indices)}, A3M: {sorted(a3m_indices)}"
        )

    # Validate each pair of files
    for idx in sorted(csv_indices):
        csv_path = msa_dir / f"{msa_hash}_{idx}.csv"
        a3m_path = msa_dir / f"{msa_hash}_{idx}.a3m"

        # Read CSV sequences (skip header, take second column)
        with csv_path.open('r') as f:
            csv_lines = f.readlines()

        if not csv_lines or csv_lines[0].strip() != "key,sequence":
            raise ValueError(f"Invalid CSV header in {csv_path}")

        csv_sequences = []
        for line_no, raw_line in enumerate(csv_lines[1:], start=2):
            row = raw_line.strip()
            if not row:
                continue
            # A comma-less row would make split(',', 1)[1] raise IndexError;
            # raise the documented ValueError with file/line context instead.
            if ',' not in row:
                raise ValueError(
                    f"Malformed CSV row at {csv_path}:{line_no}: missing ',' separator"
                )
            csv_sequences.append(row.split(',', 1)[1])

        # Read A3M sequences (every other line, skipping headers)
        with a3m_path.open('r') as f:
            a3m_lines = f.readlines()

        # A3M format: header lines start with '>', sequences on alternating lines
        a3m_sequences = [line.strip() for i, line in enumerate(a3m_lines) if i % 2 == 1]

        # Validate that sequences match
        if len(csv_sequences) != len(a3m_sequences):
            raise ValueError(
                f"Sequence count mismatch for index {idx}: "
                f"CSV has {len(csv_sequences)}, A3M has {len(a3m_sequences)}"
            )

        for seq_idx, (csv_seq, a3m_seq) in enumerate(zip(csv_sequences, a3m_sequences)):
            if csv_seq != a3m_seq:
                raise ValueError(
                    f"Sequence mismatch at index {idx}, sequence {seq_idx}: "
                    f"CSV and A3M sequences differ"
                )

        logger.debug(f"Validated {len(csv_sequences)} sequences for hash {msa_hash}, index {idx}")
|
|
||
|
|
||
| # From https://github.com/jwohlwend/boltz/blob/main/src/boltz/main.py#L415 | ||
| # FIXME: this function is a big, long mess. We should clean it up. | ||
| # For the love of decent code, don't copy this and use it somewhere else and respect the | ||
|
|
@@ -182,6 +261,7 @@ def _compute_msa( | |
| etc... | ||
| """ | ||
|
|
||
| # TODO I don't think these are actually used anywhere. We can probably remove them. | ||
| msa_idx_idr = msa_dir / f"{idx}" | ||
| msa_idx_idr.mkdir(exist_ok=True, parents=True) | ||
| logger.info(f"Writing MSA for target {target_id} sequence {idx} to {msa_idx_idr.resolve()}") | ||
|
|
@@ -270,7 +350,9 @@ def get_msa( | |
|
|
||
| if not all([m.exists() for m in msa_path_dict.values()]): | ||
| # this will generate both a3m and csv files for us. | ||
| msa_path_dict = _compute_msa( | ||
| # do NOT capture this output, it will return paths to .csv files, | ||
| # which we don't necessarily want. | ||
| _ = _compute_msa( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we don't want to capture the output, should this be void? I think this is the only place it is used currently. Either that, or we return a dict of paths to all the generated files, something like that.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you are okay with it, I want to leave that for a bigger cleanup of that method. |
||
| data, | ||
| hash_key, # this is the "target_id" argument to compute_msa | ||
| self.msa_dir, | ||
|
|
@@ -285,9 +367,11 @@ def get_msa( | |
| else: | ||
| self._cache_hits += 1 | ||
|
|
||
| _validate_msa_cache_contents(hash_key, self.msa_dir) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Avoid full content validation on every cache hit. Line 370 forces a full CSV/A3M read-and-compare even on cache hits, which can turn the fast path into expensive I/O for large MSAs. Run deep validation only after recompute (or behind a debug/integrity flag). Suggested patch:
- if not all([m.exists() for m in msa_path_dict.values()]):
+ if not all(m.exists() for m in msa_path_dict.values()):
# this will generate both a3m and csv files for us.
# do NOT capture this output, it will return paths to .csv files,
# which we don't necessarily want.
_ = _compute_msa(
data,
hash_key, # this is the "target_id" argument to compute_msa
self.msa_dir,
self.msa_server_url,
msa_pairing_strategy,
msa_server_username=self.msa_server_username,
msa_server_password=self.msa_server_password,
api_key_header=self.api_key_header,
api_key_value=self.api_key_value,
)
self._api_calls += 1
+ _validate_msa_cache_contents(hash_key, self.msa_dir)
else:
self._cache_hits += 1
-
- _validate_msa_cache_contents(hash_key, self.msa_dir)🤖 Prompt for AI Agents |
||
|
|
||
| # Protenix needs special MSAs | ||
| # This is a kind of hacky way to handle Protenix; I'm not sure what to do easily but | ||
| # use their pipeline, and just have it put everything in our cache directory. | ||
| # use their pipeline and just have it put everything in our cache directory. | ||
| elif structure_predictor == StructurePredictor.PROTENIX: | ||
| if not PROTENIX_AVAILABLE: | ||
| raise RuntimeError("Protenix is not installed, cannot use Protenix MSA tools.") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.