diff --git a/openfold3/core/data/tools/colabfold_msa_server.py b/openfold3/core/data/tools/colabfold_msa_server.py index c75ddab6a..5ae0cdc88 100644 --- a/openfold3/core/data/tools/colabfold_msa_server.py +++ b/openfold3/core/data/tools/colabfold_msa_server.py @@ -24,7 +24,7 @@ from dataclasses import dataclass, field from enum import IntEnum from pathlib import Path -from typing import Literal, NamedTuple +from typing import Any, Literal, NamedTuple import numpy as np import pandas as pd @@ -60,6 +60,59 @@ def __str__(self) -> str: return self.name.lower() +def _parse_a3m_file_by_m(a3m_file_path: Path, m_filter: dict[int, Any] | None = None) -> dict[int, list[str]]: + """Parse an a3m file and extract MSA sections by M value. + + This function uses the same parsing logic as the main query_colabfold_msa_server + to extract individual MSA sections from batch a3m files. + + Real a3m files from ColabFold have null bytes (\x00) before new M value headers, + which is how the parser detects new M sections. + + Args: + a3m_file_path: Path to the a3m file to parse + m_filter: Optional dictionary mapping M values to filter by. If provided, + only M values in this dictionary will be included in the result. + If None, all M values will be included. + + Returns: + Dictionary mapping M values to lists of lines (the MSA section for that M) + """ + a3m_lines = {} + try: + with open(a3m_file_path) as f: + update_M, M = True, None + for line in f: + if len(line) > 0: + if "\x00" in line: + line = line.replace("\x00", "") + update_M = True + if line.startswith(">") and update_M: + try: + M = int(line[1:].rstrip()) + update_M = False + # If filter is provided, only include M values in the filter + if m_filter is None or M in m_filter: + if M not in a3m_lines: + a3m_lines[M] = [] + except ValueError: + # Not a pure integer (e.g., UniRef header), keep current M + pass + # Add line to current M section (only if M is set and in filter if provided) + if M is not None: + if m_filter is None or M in m_filter: + a3m_lines[M].append(line) + except FileNotFoundError: + # File doesn't exist - return empty dict without logging (expected in some cases) + return {} + except Exception as e: + # Only log warnings for unexpected errors + logger.warning(f"Error parsing a3m file {a3m_file_path}: {e}") + return {} + + return a3m_lines + + def query_colabfold_msa_server( x: list[str], prefix: Path, @@ -395,21 +448,15 @@ def download(ID, path): # Gather a3m lines a3m_lines = {} for a3m_file in a3m_files: - update_M, M = True, None - with open(a3m_file) as f: - for line in f: - if len(line) > 0: - if "\x00" in line: - line = line.replace("\x00", "") - update_M = True - if line.startswith(">") and update_M: - M = int(line[1:].rstrip()) - update_M = False - if M not in a3m_lines: - a3m_lines[M] = [] - a3m_lines[M].append(line) + file_a3m_lines = _parse_a3m_file_by_m(Path(a3m_file)) + # Merge into main dict + for M, lines in file_a3m_lines.items(): + if M not in a3m_lines: + a3m_lines[M] = [] + a3m_lines[M].extend(lines) - a3m_lines = ["".join(a3m_lines[n]) for n in Ms] + # Only include M values that exist in a3m_lines to avoid KeyError + a3m_lines = ["".join(a3m_lines[n]) for n in Ms if n in a3m_lines] return (a3m_lines, template_paths) if use_templates else a3m_lines @@ -750,6 +797,73 @@ def query_format_main(self): index=False, ) + # Copy raw a3m files to raw_colabfold_output directory after batch processing + self._organize_raw_main_outputs_by_query() + + def _organize_raw_main_outputs_by_query(self): + """Copy raw main MSA a3m files to raw_colabfold_output directory. + + This method extracts individual MSA sections from the batch a3m files + and saves them as {rep_id}/{filename}.a3m in raw_colabfold_output directory. + Only a3m files are copied, not tar.gz or m8 files. + """ + raw_main_dir = self.output_directory / "raw" / "main" + if not raw_main_dir.exists(): + return + + # Build M -> rep_id mapping + # M is the ColabFold/MMseqs2 internal sequence identifier (starts at 101) + # It's the number used in a3m file headers like >101, >102, etc. + m_to_rep_id = {m: rep_id for rep_id, m in self.colabfold_mapper.rep_id_to_m.items()} + + if not m_to_rep_id: + logger.warning(f"No M to rep_id mapping found. rep_id_to_m: {self.colabfold_mapper.rep_id_to_m}") + return + + logger.info(f"Found {len(m_to_rep_id)} M values in mapping: {list(m_to_rep_id.keys())}") + + # Create raw_colabfold_output directory + raw_colabfold_output_dir = self.output_directory / "raw_colabfold_output" + raw_colabfold_output_dir.mkdir(parents=True, exist_ok=True) + + # Find all .a3m files in raw/main directory dynamically + a3m_files = [f for f in raw_main_dir.iterdir() if f.is_file() and f.suffix == ".a3m"] + + if not a3m_files: + logger.warning(f"No .a3m files found in {raw_main_dir}") + return + + logger.info(f"Found {len(a3m_files)} a3m files to process: {[f.name for f in a3m_files]}") + + # Process each a3m file found in raw/main + for a3m_path in a3m_files: + a3m_file = a3m_path.name + + # Extract MSA sections by M value using the shared parsing function + # Filter by m_to_rep_id to only extract sections for sequences we care about + msa_sections = _parse_a3m_file_by_m(a3m_path, m_filter=m_to_rep_id) + + # Save each MSA section to rep_id-specific directory + logger.info(f"Found {len(msa_sections)} MSA sections in {a3m_file}: {list(msa_sections.keys())}") + for M, lines in msa_sections.items(): + if M not in m_to_rep_id: + logger.warning(f"M value {M} not found in m_to_rep_id mapping") + continue + rep_id = m_to_rep_id[M] + + # Create rep_id-specific directory + rep_dir = raw_colabfold_output_dir / str(rep_id) + rep_dir.mkdir(parents=True, exist_ok=True) + + # Save MSA section to rep_id directory + output_file = rep_dir / a3m_file + try: + with open(output_file, "w") as f: + f.writelines(lines) + logger.info(f"Saved MSA section for M={M}, rep_id={rep_id} to {output_file}") + except Exception as e: + logger.warning(f"Error writing {output_file}: {e}") + def query_format_paired(self): """Submits queries and formats the outputs for paired MSAs.""" paired_alignments_directory = self.output_directory / "paired" diff --git a/openfold3/tests/test_colabfold_msa.py b/openfold3/tests/test_colabfold_msa.py index 97185d383..d67481720 100644 --- a/openfold3/tests/test_colabfold_msa.py +++ b/openfold3/tests/test_colabfold_msa.py @@ -30,6 +30,7 @@ ColabFoldQueryRunner, ComplexGroup, MsaComputationSettings, + _parse_a3m_file_by_m, collect_colabfold_msa_data, get_sequence_hash, preprocess_colabfold_msas, @@ -317,3 +318,269 @@ def test_cli_output_dir_conflict_raises(self, tmp_path): assert "Output directory mismatch" in str(exc_info.value), ( "Expected ValueError on output directory conflict" ) + + +class TestParseA3mFileByM: + """Test suite for _parse_a3m_file_by_m function.""" + + def test_parse_single_m_value(self, tmp_path): + """Test parsing a3m file with a single M value.""" + a3m_content = textwrap.dedent( + """\ + >101 + SEQUENCE1 + >UniRef100_A0A123 + MATCH1 + MATCH1CONT + >UniRef100_B0B456 + MATCH2 + """ + ) + a3m_file = tmp_path / "test.a3m" + a3m_file.write_text(a3m_content) + + result = _parse_a3m_file_by_m(a3m_file) + + assert 101 in result, "Expected M value 101 in result" + assert len(result) == 1, "Expected only one M value" + lines = result[101] + assert len(lines) == 7, "Expected 7 lines for M=101 (M header, sequence, 2 UniRef headers, 3 match lines)" + assert lines[0] == ">101\n", "First line should be M header" + assert lines[1] == "SEQUENCE1\n", "Second line should be sequence" + assert ">UniRef100_A0A123\n" in lines, "Should contain UniRef header" + assert "MATCH1\n" in lines, "Should contain match sequence" + assert "MATCH1CONT\n" in lines, "Should contain continuation of match sequence" + assert ">UniRef100_B0B456\n" in lines, "Should contain second UniRef header" + assert "MATCH2\n" in lines, "Should contain second match sequence" + + def test_parse_multiple_m_values(self, tmp_path): + """Test parsing a3m file with multiple M values separated by null bytes.""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">101\nSEQUENCE1\n>UniRef100_A0A123\nMATCH1\n\x00>102\nSEQUENCE2\n>UniRef100_B0B456\nMATCH2\n\x00>103\nSEQUENCE3\n>UniRef100_C0C789\nMATCH3\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 3, "Expected 3 M values" + assert 101 in result, "Expected M value 101" + assert 102 in result, "Expected M value 102" + assert 103 in result, "Expected M value 103" + + # Check M=101 section + lines_101 = result[101] + assert lines_101[0] == ">101\n", "M=101 should start with >101" + assert "SEQUENCE1\n" in lines_101, "M=101 should contain SEQUENCE1" + assert ">UniRef100_A0A123\n" in lines_101, "M=101 should contain UniRef header" + assert ">102\n" not in lines_101, "M=101 should not contain >102" + + # Check M=102 section + lines_102 = result[102] + assert lines_102[0] == ">102\n", "M=102 should start with >102" + assert "SEQUENCE2\n" in lines_102, "M=102 should contain SEQUENCE2" + assert ">UniRef100_B0B456\n" in lines_102, "M=102 should contain UniRef header" + assert ">103\n" not in lines_102, "M=102 should not contain >103" + + # Check M=103 section + lines_103 = result[103] + assert lines_103[0] == ">103\n", "M=103 should start with >103" + assert "SEQUENCE3\n" in lines_103, "M=103 should contain SEQUENCE3" + + def test_parse_with_m_filter(self, tmp_path): + """Test parsing with M value filter.""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">101\nSEQUENCE1\n>UniRef100_A0A123\nMATCH1\n\x00>102\nSEQUENCE2\n>UniRef100_B0B456\nMATCH2\n\x00>103\nSEQUENCE3\n>UniRef100_C0C789\nMATCH3\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + # Filter to only include M=101 and M=103 + m_filter = {101: "rep1", 103: "rep3"} + result = _parse_a3m_file_by_m(a3m_file, m_filter=m_filter) + + assert len(result) == 2, "Expected 2 M values after filtering" + assert 101 in result, "Expected M value 101" + assert 103 in result, "Expected M value 103" + assert 102 not in result, "M=102 should be filtered out" + + # Check that filtered sections are correct + lines_101 = result[101] + assert "SEQUENCE1\n" in lines_101, "M=101 should contain SEQUENCE1" + assert "SEQUENCE2\n" not in lines_101, "M=101 should not contain SEQUENCE2" + + lines_103 = result[103] + assert "SEQUENCE3\n" in lines_103, "M=103 should contain SEQUENCE3" + assert "SEQUENCE2\n" not in lines_103, "M=103 should not contain SEQUENCE2" + + def test_parse_with_null_bytes(self, tmp_path): + """Test parsing a3m file with null bytes (\\x00) that trigger M updates.""" + a3m_content = ">101\nSEQUENCE1\n\x00>102\nSEQUENCE2\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 2, "Expected 2 M values" + assert 101 in result, "Expected M value 101" + assert 102 in result, "Expected M value 102" + + # Check that null bytes are removed + lines_101 = result[101] + assert "\x00" not in "".join(lines_101), "Null bytes should be removed from M=101" + assert "SEQUENCE1\n" in lines_101, "M=101 should contain SEQUENCE1" + + lines_102 = result[102] + assert "\x00" not in "".join(lines_102), "Null bytes should be removed from M=102" + assert "SEQUENCE2\n" in lines_102, "M=102 should contain SEQUENCE2" + + def test_parse_empty_file(self, tmp_path): + """Test parsing an empty a3m file.""" + a3m_file = tmp_path / "empty.a3m" + a3m_file.write_text("") + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 0, "Expected empty result for empty file" + + def test_parse_file_with_no_m_values(self, tmp_path): + """Test parsing a3m file with no M value headers.""" + a3m_content = textwrap.dedent( + """\ + >UniRef100_A0A123 + MATCH1 + >UniRef100_B0B456 + MATCH2 + """ + ) + a3m_file = tmp_path / "test.a3m" + a3m_file.write_text(a3m_content) + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 0, "Expected empty result when no M values present" + + def test_parse_file_with_m_value_less_than_101(self, tmp_path): + """Test parsing a3m file with M value less than 101 (should still parse).""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">100\nSEQUENCE1\n>UniRef100_A0A123\nMATCH1\n\x00>101\nSEQUENCE2\n>UniRef100_B0B456\nMATCH2\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + result = _parse_a3m_file_by_m(a3m_file) + + # Should parse both M=100 and M=101 (function doesn't filter by >= 101) + assert 100 in result, "Expected M value 100" + assert 101 in result, "Expected M value 101" + + def test_parse_with_complex_uniref_headers(self, tmp_path): + """Test parsing with complex UniRef headers that might look like M values.""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">101\nSEQUENCE1\n>UniRef100_10566|scaffold\nMATCH1\n>UniRef100_A0A208S2R6\nMATCH2\n\x00>102\nSEQUENCE2\n>UniRef100_12345\nMATCH3\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 2, "Expected 2 M values" + assert 101 in result, "Expected M value 101" + assert 102 in result, "Expected M value 102" + + # Check that complex UniRef headers are included in correct sections + lines_101 = result[101] + assert ">UniRef100_10566|scaffold\n" in lines_101, "Complex UniRef header should be in M=101" + assert ">UniRef100_A0A208S2R6\n" in lines_101, "UniRef header should be in M=101" + assert ">UniRef100_12345\n" not in lines_101, "UniRef header should not be in M=101" + + lines_102 = result[102] + assert ">UniRef100_12345\n" in lines_102, "UniRef header should be in M=102" + + def test_parse_with_whitespace_in_m_header(self, tmp_path): + """Test parsing M headers with whitespace.""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">101 \nSEQUENCE1\n\x00>102\t\nSEQUENCE2\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 2, "Expected 2 M values" + assert 101 in result, "Expected M value 101" + assert 102 in result, "Expected M value 102" + + def test_parse_with_filter_excluding_all(self, tmp_path): + """Test parsing with filter that excludes all M values.""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">101\nSEQUENCE1\n\x00>102\nSEQUENCE2\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + # Filter that doesn't include any M values + m_filter = {999: "dummy"} + result = _parse_a3m_file_by_m(a3m_file, m_filter=m_filter) + + assert len(result) == 0, "Expected empty result when all M values filtered out" + + def test_parse_nonexistent_file(self, tmp_path): + """Test parsing a non-existent file.""" + nonexistent_file = tmp_path / "nonexistent.a3m" + + result = _parse_a3m_file_by_m(nonexistent_file) + + assert len(result) == 0, "Expected empty result for non-existent file" + + def test_parse_with_multiline_sequences(self, tmp_path): + """Test parsing with sequences that span multiple lines.""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">101\nSEQUENCE1\nPART1\nPART2\n>UniRef100_A0A123\nMATCH1\nMATCH1CONT\n\x00>102\nSEQUENCE2\nPART3\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 2, "Expected 2 M values" + lines_101 = result[101] + assert "SEQUENCE1\n" in lines_101, "Should contain first part of sequence" + assert "PART1\n" in lines_101, "Should contain second part" + assert "PART2\n" in lines_101, "Should contain third part" + assert "PART3\n" not in lines_101, "Should not contain parts from M=102" + + def test_parse_with_uniref_header_with_metadata(self, tmp_path): + """Test parsing with UniRef headers that have metadata (e.g., >10648|Ga0318525_12252441_1|+3|11...).""" + # Real a3m files have null bytes (\x00) before new M value headers + a3m_content = ">101\nSEQUENCE1\n>UniRef100_A0A123\nMATCH1\n>10648|Ga0318525_12252441_1|+3|11\t57\t0.284\t3.822E-04\t143\t229\t747\t7\t91\t95\nMATCH2\n\x00>102\nSEQUENCE2\n>UniRef100_B0B456\nMATCH3\n>10566|scaffold|+1|5\t42\t0.500\t1.444E-49\t0\t121\t122\t42\t162\t163\nMATCH4\n" + a3m_file = tmp_path / "test.a3m" + a3m_file.write_bytes(a3m_content.encode("utf-8")) + + result = _parse_a3m_file_by_m(a3m_file) + + assert len(result) == 2, "Expected 2 M values" + assert 101 in result, "Expected M value 101" + assert 102 in result, "Expected M value 102" + + # Check M=101 section - should contain the UniRef header with metadata + lines_101 = result[101] + assert ">101\n" in lines_101, "M=101 should start with >101" + assert "SEQUENCE1\n" in lines_101, "M=101 should contain SEQUENCE1" + assert ">UniRef100_A0A123\n" in lines_101, "M=101 should contain simple UniRef header" + assert ">10648|Ga0318525_12252441_1|+3|11" in "".join(lines_101), ( + "M=101 should contain UniRef header with metadata (10648|...)" + ) + assert "MATCH2\n" in lines_101, "M=101 should contain match sequence after metadata header" + # Should NOT contain M=102 content + assert ">102\n" not in lines_101, "M=101 should not contain >102" + assert "SEQUENCE2\n" not in lines_101, "M=101 should not contain SEQUENCE2" + + # Check M=102 section - should contain the other UniRef header with metadata + lines_102 = result[102] + assert ">102\n" in lines_102, "M=102 should start with >102" + assert "SEQUENCE2\n" in lines_102, "M=102 should contain SEQUENCE2" + assert ">UniRef100_B0B456\n" in lines_102, "M=102 should contain simple UniRef header" + assert ">10566|scaffold|+1|5" in "".join(lines_102), ( + "M=102 should contain UniRef header with metadata (10566|...)" + ) + assert "MATCH4\n" in lines_102, "M=102 should contain match sequence after metadata header" + # Should NOT contain M=101 content + assert ">101\n" not in lines_102, "M=102 should not contain >101" + assert "SEQUENCE1\n" not in lines_102, "M=102 should not contain SEQUENCE1" + assert ">10648|Ga0318525_12252441_1|+3|11" not in "".join(lines_102), ( + "M=102 should not contain metadata header from M=101" + )