diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py
index 509c64ae..7e71449c 100644
--- a/src/sampleworks/utils/guidance_script_arguments.py
+++ b/src/sampleworks/utils/guidance_script_arguments.py
@@ -188,6 +188,20 @@ def populate_config_for_guidance_type(self, job: JobConfig, args: argparse.Names
         self.step_scaler_type = args.step_scaler_type
         self.ensemble_size = job.ensemble_size
 
+    def as_dict(self) -> dict[str, Any]:
+        """Return a shallow, JSON-friendly copy of this config's attributes.
+
+        Any ``pathlib`` path value (e.g. ``density``, ``structure``) is converted to ``str``;
+        all other values, including ``None``, are passed through unchanged.
+        """
+        # Imported locally so this method does not depend on the module's import block.
+        from pathlib import PurePath
+
+        return {
+            key: str(value) if isinstance(value, PurePath) else value
+            for key, value in self.__dict__.items()
+        }
+
 
 def add_generic_args(parser: argparse.ArgumentParser | GuidanceConfig):
     parser.add_argument("--structure", type=str, required=True, help="Input structure")
diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py
index bbefb327..b6b07ce5 100644
--- a/src/sampleworks/utils/guidance_script_utils.py
+++ b/src/sampleworks/utils/guidance_script_utils.py
@@ -594,7 +594,8 @@ def run_guidance_job_queue(job_queue_path: str) -> list[JobResult]:
         job_result = run_guidance(job, job.guidance_type, model_wrapper, device)
         # write out the job parameters to a JSON file in the same directory as the refined.cif file
         with open(Path(job_result.output_dir) / "job_metadata.json", "w") as fp:
-            json.dump(job.__dict__, fp)
+            # use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects
+            json.dump(job.as_dict(), fp)
         job_results.append(job_result)
         torch.cuda.empty_cache()  # just in case
 
diff --git a/src/sampleworks/utils/msa.py b/src/sampleworks/utils/msa.py
index 7647fbec..42f7dfbe 100644
--- a/src/sampleworks/utils/msa.py
+++ b/src/sampleworks/utils/msa.py
@@ -20,6 +20,103 @@
 MAX_MSA_SEQS = 16384
 
 
+def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None:
+    """Validate the contents of the MSA cache.
+
+    For each hash, there should be a set of files `"{hash}_{n}.csv", "{hash}_{n}.a3m"`
+    where `n` is the index of the sequence in the input data, 0, 1, ..., len(sequences) - 1.
+    In most cases there is only one sequence, so `n = 0` only. The .a3m files should have a sequence
+    on every other line. The .csv files have two columns. The sequences in the second column of the
+    .csv file should be the same as the sequences in the a3m file.
+
+    Parameters
+    ----------
+    msa_hash : str
+        The hash key for the MSA files.
+    msa_dir : Path
+        The directory containing the MSA cache files.
+
+    Raises
+    ------
+    FileNotFoundError
+        If expected files are missing.
+    ValueError
+        If file names or contents don't match or are invalid.
+    """
+
+    def _file_index(path: Path) -> int:
+        # Everything after "{msa_hash}_" in the stem must be the plain integer sequence index.
+        suffix = path.stem[len(msa_hash) + 1 :]
+        if not (suffix.isascii() and suffix.isdigit()):
+            raise ValueError(f"Unexpected MSA cache file name {path.name} for hash {msa_hash}")
+        return int(suffix)
+
+    # Find all files matching the hash pattern
+    csv_files = sorted(msa_dir.glob(f"{msa_hash}_*.csv"))
+    a3m_files = sorted(msa_dir.glob(f"{msa_hash}_*.a3m"))
+
+    if not csv_files:
+        raise FileNotFoundError(f"No CSV files found for hash {msa_hash} in {msa_dir}")
+    if not a3m_files:
+        raise FileNotFoundError(f"No A3M files found for hash {msa_hash} in {msa_dir}")
+
+    # Validate that we have matching pairs
+    csv_indices = {_file_index(f) for f in csv_files}
+    a3m_indices = {_file_index(f) for f in a3m_files}
+
+    if csv_indices != a3m_indices:
+        raise ValueError(
+            f"Mismatch between CSV and A3M file indices. "
+            f"CSV: {sorted(csv_indices)}, A3M: {sorted(a3m_indices)}"
+        )
+
+    # Validate each pair of files
+    for idx in sorted(csv_indices):
+        csv_path = msa_dir / f"{msa_hash}_{idx}.csv"
+        a3m_path = msa_dir / f"{msa_hash}_{idx}.a3m"
+
+        # Read CSV sequences (skip header, take second column)
+        with csv_path.open("r") as f:
+            csv_lines = f.readlines()
+
+        if not csv_lines or csv_lines[0].strip() != "key,sequence":
+            raise ValueError(f"Invalid CSV header in {csv_path}")
+
+        csv_sequences = []
+        for row_number, line in enumerate(csv_lines[1:], start=2):
+            row = line.strip()
+            if not row:
+                continue
+            # A row without a comma is malformed; report it instead of raising IndexError.
+            _key, separator, sequence = row.partition(",")
+            if not separator:
+                raise ValueError(f"Malformed row at line {row_number} of {csv_path}")
+            csv_sequences.append(sequence)
+
+        # Read A3M sequences (every other line, skipping headers)
+        with a3m_path.open("r") as f:
+            a3m_lines = f.readlines()
+
+        # A3M format: header lines start with '>', sequences on alternating lines
+        a3m_sequences = [line.strip() for i, line in enumerate(a3m_lines) if i % 2 == 1]
+
+        # Validate that sequences match
+        if len(csv_sequences) != len(a3m_sequences):
+            raise ValueError(
+                f"Sequence count mismatch for index {idx}: "
+                f"CSV has {len(csv_sequences)}, A3M has {len(a3m_sequences)}"
+            )
+
+        for seq_idx, (csv_seq, a3m_seq) in enumerate(zip(csv_sequences, a3m_sequences)):
+            if csv_seq != a3m_seq:
+                raise ValueError(
+                    f"Sequence mismatch at index {idx}, sequence {seq_idx}: "
+                    f"CSV and A3M sequences differ"
+                )
+
+        logger.debug(f"Validated {len(csv_sequences)} sequences for hash {msa_hash}, index {idx}")
+
+
 # From https://github.com/jwohlwend/boltz/blob/main/src/boltz/main.py#L415
 # FIXME: this function is a big, long mess. We should clean it up.
 # For the love of decent code, don't copy this and use it somewhere else and respect the
@@ -182,6 +279,7 @@ def _compute_msa(
     etc...
     """
 
+    # TODO I don't think these are actually used anywhere. We can probably remove them.
     msa_idx_idr = msa_dir / f"{idx}"
     msa_idx_idr.mkdir(exist_ok=True, parents=True)
     logger.info(f"Writing MSA for target {target_id} sequence {idx} to {msa_idx_idr.resolve()}")
@@ -270,7 +368,9 @@ def get_msa(
 
             if not all([m.exists() for m in msa_path_dict.values()]):
                 # this will generate both a3m and csv files for us.
-                msa_path_dict = _compute_msa(
+                # Do NOT capture the return value: it holds paths to the .csv files,
+                # which we don't necessarily want.
+                _ = _compute_msa(
                     data,
                     hash_key,  # this is the "target_id" argument to compute_msa
                     self.msa_dir,
@@ -285,9 +385,11 @@ def get_msa(
             else:
                 self._cache_hits += 1
 
+            _validate_msa_cache_contents(hash_key, self.msa_dir)
+
             # Protenix needs special MSAs
             # This is a kind of hacky way to handle Protenix; I'm not sure what to do easily but
-            # use their pipeline, and just have it put everything in our cache directory.
+            # use their pipeline and just have it put everything in our cache directory.
         elif structure_predictor == StructurePredictor.PROTENIX:
             if not PROTENIX_AVAILABLE:
                 raise RuntimeError("Protenix is not installed, cannot use Protenix MSA tools.")