diff --git a/compass/plugin/__init__.py b/compass/plugin/__init__.py
index 5f50ee3f..42e00488 100644
--- a/compass/plugin/__init__.py
+++ b/compass/plugin/__init__.py
@@ -9,6 +9,7 @@ from .ordinance import (
     BaseTextExtractor,
     BaseParser,
+    DocSelectionMethod,
     KeywordBasedHeuristic,
     PromptBasedTextCollector,
     PromptBasedTextExtractor,
diff --git a/compass/plugin/one_shot/base.py b/compass/plugin/one_shot/base.py
index e5254a73..d7ca83ee 100644
--- a/compass/plugin/one_shot/base.py
+++ b/compass/plugin/one_shot/base.py
@@ -11,6 +11,7 @@
     NoOpHeuristic,
     NoOpTextCollector,
     NoOpTextExtractor,
+    DocSelectionMethod,
     PromptBasedTextCollector,
     PromptBasedTextExtractor,
     OrdinanceExtractionPlugin,
@@ -123,11 +124,26 @@ def create_schema_based_one_shot_extraction_plugin(config, tech):  # noqa: C901
           may provide a custom system prompt if you want to
           provide more specific instructions to the LLM for the
           structured data extraction step.
-        - `allow_multi_doc_extraction`: Boolean flag indicating
-          whether to allow multiple documents to be used for the
-          extraction context simultaneously. By default, ``False``,
-          which means the first document that returns some extracted
-          data will be marked as the source.
+        - `doc_selection_method`: String defining the multi-doc
+          selection option. Specifically, if multiple documents pass
+          the filter, this method determines how the documents are
+          submitted to the extraction context. Allowed options are:
+
+          - "single doc": Use the first document that returns some
+            extracted data as the source document for the
+            extraction context.
+          - "multi doc context": Submit text from multiple
+            documents to the extraction context simultaneously.
+          - "multi doc all": Each document is extracted separately
+            and the results are concatenated. This may give
+            duplicated feature results if the same feature is
+            mentioned in multiple documents.
+          - "multi doc mixed": Each document is extracted
+            separately and the results are merged together at the
+            end. In this approach, each feature is reported at
+            most once.
+
+          By default, ``"single doc"``.
 
     tech : str
         Technology identifier to use for the plugin (e.g., "wind",
@@ -161,10 +177,27 @@ class SchemaBasedExtractionPlugin(OrdinanceExtractionPlugin):
         SCHEMA = config["schema"]
         """dict: Schema for the output of the text extraction step"""
 
-        ALLOW_MULTI_DOC_EXTRACTION = config.get(
-            "allow_multi_doc_extraction", False
+        DOC_SELECTION_METHOD = DocSelectionMethod.normalize(
+            config.get("doc_selection_method", "single doc")
         )
-        """bool: Whether to allow extraction over multiple documents"""
+        """str: Method for selecting documents for extraction context
+
+        Allowed options:
+
+        - "single doc": Use the first document that returns some
+          extracted data as the source document for the extraction
+          context.
+        - "multi doc context": Submit text from multiple documents
+          to the extraction context simultaneously.
+        - "multi doc all": Each document is extracted separately
+          and the results are concatenated. This may give duplicated
+          feature results if the same feature is mentioned in
+          multiple documents.
+        - "multi doc mixed": Each document is extracted separately
+          and the results are merged together at the end. In this
+          approach, each feature is reported at most once.
+
+        """
 
         IDENTIFIER = tech
         """str: Identifier for extraction task"""
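For orientation, a minimal sketch of how the new config key flows through the one-shot factory. The schema payload and tech value are placeholders, the import paths are inferred from the file layout in this diff, and the factory is assumed to return the generated plugin class::

    from compass.plugin import DocSelectionMethod
    from compass.plugin.one_shot.base import (
        create_schema_based_one_shot_extraction_plugin,
    )

    config = {
        "schema": {"type": "object"},  # placeholder schema payload
        "doc_selection_method": "multi doc mixed",
    }
    # assumes the factory returns the generated plugin class
    Plugin = create_schema_based_one_shot_extraction_plugin(config, "wind")
    assert Plugin.DOC_SELECTION_METHOD == DocSelectionMethod.MULTI_DOC_MIXED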
+ + """ IDENTIFIER = tech """str: Identifier for extraction task """ diff --git a/compass/plugin/ordinance.py b/compass/plugin/ordinance.py index 998bbbc7..8e53180b 100644 --- a/compass/plugin/ordinance.py +++ b/compass/plugin/ordinance.py @@ -2,6 +2,8 @@ import asyncio import logging +import operator +from enum import StrEnum from warnings import warn from textwrap import dedent from itertools import chain @@ -61,6 +63,63 @@ } +class DocSelectionMethod(StrEnum): + """Document selection modes for structured extraction""" + + SINGLE_DOC = "single_doc" + """Evaluate candidate documents one at a time until data is found""" + MULTI_DOC_CONTEXT = "multi_doc_context" + """Combine multiple documents into one extraction context""" + MULTI_DOC_ALL = "multi_doc_all" + """Parse each document separately and keep all extracted rows""" + MULTI_DOC_MIXED = "multi_doc_mixed" + """Parse separately and merge rows so each feature appears once""" + + @classmethod + def normalize(cls, value): + """Normalize a config value into a selection mode + + Parameters + ---------- + value : str or DocSelectionMethod + Input selection mode from plugin configuration or an + existing enum value. + + Returns + ------- + DocSelectionMethod + Normalized document selection mode. + + Raises + ------ + COMPASSPluginConfigurationError + Raised if ``value`` is not a string or enum member, or if + it does not map to a supported selection mode. + """ + if isinstance(value, cls): + return value + + if not isinstance(value, str): + msg = ( + "doc_selection_method must be a string or " + f"{cls.__name__} value." + ) + raise COMPASSPluginConfigurationError(msg) + + normalized = ( + value.replace(" ", "_").replace("-", "_").strip().casefold() + ) + try: + return cls(normalized) + except ValueError as err: + msg = ( + f"Invalid doc_selection_method: {value!r}. " + "Allowed options are: " + f"{sorted(method.value for method in cls)}." + ) + raise COMPASSPluginConfigurationError(msg) from err + + class BaseTextExtractor(BaseLLMCaller, ABC): """Extract succinct extraction text from input""" @@ -596,8 +655,8 @@ class OrdinanceExtractionPlugin(FilteredExtractionPlugin): methods as needed. """ - ALLOW_MULTI_DOC_EXTRACTION = False - """bool: Whether to allow extraction over multiple documents""" + DOC_SELECTION_METHOD = DocSelectionMethod.SINGLE_DOC + """str: Only allow one document to be output""" @property @abstractmethod @@ -701,19 +760,99 @@ async def parse_docs_for_structured_data(self, extraction_context): Context with extracted data/information stored in the ``.attrs`` dictionary, or ``None`` if no data was extracted. """ - if self.ALLOW_MULTI_DOC_EXTRACTION: - return await self.parse_multi_doc_context_for_structured_data( - extraction_context - ) - return await self.parse_single_doc_for_structured_data( - extraction_context + match DocSelectionMethod.normalize(self.DOC_SELECTION_METHOD): + case DocSelectionMethod.SINGLE_DOC: + return await self.parse_single_doc_for_structured_data( + extraction_context + ) + + case DocSelectionMethod.MULTI_DOC_CONTEXT: + return await self.parse_multi_doc_context_for_structured_data( + extraction_context + ) + + case DocSelectionMethod.MULTI_DOC_ALL: + return await self.parse_multi_doc_concat(extraction_context) + + case DocSelectionMethod.MULTI_DOC_MIXED: + return await self.parse_multi_doc_merge(extraction_context) + + case _: + msg = ( + "Invalid DOC_SELECTION_METHOD: " + f"{self.DOC_SELECTION_METHOD!r}. " + "Supported methods are: " + f"{sorted(method.value for method in DocSelectionMethod)}." 
@@ -596,8 +655,8 @@ class OrdinanceExtractionPlugin(FilteredExtractionPlugin):
     methods as needed.
     """
 
-    ALLOW_MULTI_DOC_EXTRACTION = False
-    """bool: Whether to allow extraction over multiple documents"""
+    DOC_SELECTION_METHOD = DocSelectionMethod.SINGLE_DOC
+    """str: Document selection method; defaults to single-doc extraction"""
 
     @property
     @abstractmethod
@@ -701,19 +760,99 @@ async def parse_docs_for_structured_data(self, extraction_context):
             Context with extracted data/information stored in the
             ``.attrs`` dictionary, or ``None`` if no data was extracted.
         """
-        if self.ALLOW_MULTI_DOC_EXTRACTION:
-            return await self.parse_multi_doc_context_for_structured_data(
-                extraction_context
-            )
-        return await self.parse_single_doc_for_structured_data(
-            extraction_context
+        match DocSelectionMethod.normalize(self.DOC_SELECTION_METHOD):
+            case DocSelectionMethod.SINGLE_DOC:
+                return await self.parse_single_doc_for_structured_data(
+                    extraction_context
+                )
+
+            case DocSelectionMethod.MULTI_DOC_CONTEXT:
+                return await self.parse_multi_doc_context_for_structured_data(
+                    extraction_context
+                )
+
+            case DocSelectionMethod.MULTI_DOC_ALL:
+                return await self.parse_multi_doc_concat(extraction_context)
+
+            case DocSelectionMethod.MULTI_DOC_MIXED:
+                return await self.parse_multi_doc_merge(extraction_context)
+
+            case _:
+                msg = (
+                    "Invalid DOC_SELECTION_METHOD: "
+                    f"{self.DOC_SELECTION_METHOD!r}. "
+                    "Supported methods are: "
+                    f"{sorted(method.value for method in DocSelectionMethod)}."
+                )
+                raise COMPASSPluginConfigurationError(msg)
+
+    async def parse_single_doc_for_structured_data(self, extraction_context):
+        """Parse documents one at a time to extract structured data
+
+        This mode evaluates candidate documents in sequence and stops
+        at the first document that produces ordinance data. Once a
+        usable source is found, later candidate documents are not used
+        to supplement, compare, or override that result. This is the
+        simplest selection strategy and is best suited to workflows
+        where one document is expected to contain the authoritative
+        ordinance language on its own.
+
+        Documents are expected to arrive sorted by priority, with the
+        most likely source of ordinance language appearing first in
+        the `extraction_context`.
+
+        Parameters
+        ----------
+        extraction_context : ExtractionContext
+            Context containing candidate documents to parse.
+
+        Returns
+        -------
+        ExtractionContext or None
+            Context with extracted data/information stored in the
+            ``.attrs`` dictionary, or ``None`` if no data was extracted.
+        """
+        for doc_for_extraction in extraction_context:
+            data_df = await self.parse_for_structured_data(doc_for_extraction)
+            row_count = self.get_structured_data_row_count(data_df)
+            if row_count > 0:
+                data_df["source"] = doc_for_extraction.attrs.get("source")
+                data_df["year"] = extract_year_from_doc_attrs(
+                    doc_for_extraction.attrs
+                )
+                await extraction_context.mark_doc_as_data_source(
+                    doc_for_extraction, out_fn_stem=self.jurisdiction.full_name
+                )
+                extraction_context.attrs["structured_data"] = data_df
+                logger.info(
+                    "%d ordinance value(s) found for %s from doc:\n%s",
+                    num_ordinances_dataframe(data_df),
+                    self.jurisdiction.full_name,
+                    doc_for_extraction,
+                )
+                return extraction_context
+
+        logger.debug(
+            "No ordinances found; searched %d docs",
+            extraction_context.num_documents,
         )
+        return None
 
     async def parse_multi_doc_context_for_structured_data(
         self, extraction_context
     ):
         """Parse all documents to extract structured data/information
 
+        This mode combines the relevant text from all candidate
+        documents into one shared extraction context before structured
+        data are parsed. It is useful when the information needed for a
+        single ordinance feature may be split across multiple sources
+        and should be interpreted together rather than compared as
+        separate document-level outputs. When source references can be
+        recovered from the extracted rows, each row is mapped back to
+        its originating document; otherwise the result falls back to
+        reporting the full document set as the source context.
+
         Parameters
         ----------
         extraction_context : ExtractionContext
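Plugins opt into a mode either through the factory config shown earlier or, for hand-written plugins, by overriding the class attribute directly. A minimal sketch; the plugin name is hypothetical and the collector/extractor/parser wiring is elided since it is unchanged by this diff::

    from compass.plugin import DocSelectionMethod
    from compass.plugin.ordinance import OrdinanceExtractionPlugin

    class WaterRightsPlugin(OrdinanceExtractionPlugin):
        """Hypothetical plugin that merges rows across documents"""

        DOC_SELECTION_METHOD = DocSelectionMethod.MULTI_DOC_MIXED
        # collectors/extractors/parsers and other required attributes
        # are unchanged by this diff and elided here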
@@ -748,18 +887,23 @@ async def parse_multi_doc_context_for_structured_data(
         extraction_context.attrs["structured_data"] = data_df
 
         logger.info(
-            "%d ordinance value(s) found in %d docs for %s. ",
+            "%d ordinance value(s) found for %s in %d docs. ",
             num_ordinances_dataframe(data_df),
-            extraction_context.num_documents,
             self.jurisdiction.full_name,
+            extraction_context.num_documents,
         )
         return extraction_context
 
-    async def parse_single_doc_for_structured_data(self, extraction_context):
-        """Parse documents one at a time to extract structured data
+    async def parse_multi_doc_concat(self, extraction_context):
+        """Parse all documents and concatenate extracted data
 
-        The first document to return some extracted data will be marked
-        as the source and will be returned from this method.
+        This mode keeps all extracted ordinance rows from every
+        candidate document that produced structured data. Unlike the
+        merge mode, it does not try to choose a single best row for a
+        feature or resolve conflicts between sources. If the same
+        feature is extracted from multiple ordinances, each version is
+        preserved in the output with its own source and year so users
+        can compare the results directly.
 
         Parameters
         ----------
@@ -772,31 +916,121 @@ async def parse_single_doc_for_structured_data(self, extraction_context):
             Context with extracted data/information stored in the
             ``.attrs`` dictionary, or ``None`` if no data was extracted.
         """
-        for doc_for_extraction in extraction_context:
-            data_df = await self.parse_for_structured_data(doc_for_extraction)
+        tasks = [
+            asyncio.create_task(
+                self.parse_for_structured_data(doc_for_extraction),
+                name=self.jurisdiction.full_name,
+            )
+            for doc_for_extraction in extraction_context
+        ]
+        data_dfs = await asyncio.gather(*tasks)
+
+        all_data = []
+        for doc_ind, (data_df, doc) in enumerate(
+            zip(data_dfs, extraction_context, strict=True), start=1
+        ):
             row_count = self.get_structured_data_row_count(data_df)
-            if row_count > 0:
-                data_df["source"] = doc_for_extraction.attrs.get("source")
-                data_df["year"] = extract_year_from_doc_attrs(
-                    doc_for_extraction.attrs
-                )
-                await extraction_context.mark_doc_as_data_source(
-                    doc_for_extraction, out_fn_stem=self.jurisdiction.full_name
-                )
-                extraction_context.attrs["structured_data"] = data_df
-                logger.info(
-                    "%d ordinance value(s) found in doc from %s for %s. ",
-                    num_ordinances_dataframe(data_df),
-                    doc_for_extraction.attrs.get("source", "unknown source"),
-                    self.jurisdiction.full_name,
-                )
-                return extraction_context
+            if row_count == 0:
+                continue
 
-        logger.debug(
-            "No ordinances found; searched %d docs",
-            extraction_context.num_documents,
+            data_df["source"] = doc.attrs.get("source")
+            data_df["year"] = extract_year_from_doc_attrs(doc.attrs)
+            await extraction_context.mark_doc_as_data_source(
+                doc, out_fn_stem=f"{self.jurisdiction.full_name}_{doc_ind}"
+            )
+            logger.info(
+                "%d ordinance value(s) found for %s from doc:\n%s",
+                num_ordinances_dataframe(data_df),
+                self.jurisdiction.full_name,
+                doc,
+            )
+            all_data.append(data_df)
+
+        if not all_data:
+            logger.debug(
+                "No ordinances found; searched %d docs",
+                extraction_context.num_documents,
+            )
+            return None
+
+        extraction_context.attrs["structured_data"] = pd.concat(
+            all_data, ignore_index=True
         )
-        return None
+        return extraction_context
+
+    async def parse_multi_doc_merge(self, extraction_context):
+        """Parse all documents and merge the extracted data
+
+        This mode keeps at most one row per extracted feature across
+        all candidate documents. When every document with extracted
+        data has a known ordinance year, newer ordinances take
+        precedence and older ordinances are only used to fill in
+        features that are missing from the newer sources. If any
+        candidate document has an unknown year, documents are instead
+        prioritized by how many ordinance features they contain.
+
+        Documents with extracted prohibitions are treated specially.
+        If any candidate document contains a prohibition, only
+        prohibition-bearing documents are considered for the final
+        merged output. The returned rows keep the source and year of
+        the document they came from so downstream consumers can still
+        trace each retained feature back to its originating ordinance.
+
+        Parameters
+        ----------
+        extraction_context : ExtractionContext
+            Context containing candidate documents to parse.
+
+        Returns
+        -------
+        ExtractionContext or None
+            Context with extracted data/information stored in the
+            ``.attrs`` dictionary, or ``None`` if no data was extracted.
+        """
+        tasks = [
+            asyncio.create_task(
+                self.parse_for_structured_data(doc_for_extraction),
+                name=self.jurisdiction.full_name,
+            )
+            for doc_for_extraction in extraction_context
+        ]
+        data_dfs = await asyncio.gather(*tasks)
+
+        candidates = []
+        for doc_ind, (data_df, doc) in enumerate(
+            zip(data_dfs, extraction_context, strict=True), start=1
+        ):
+            row_count = self.get_structured_data_row_count(data_df)
+            if row_count == 0:
+                continue
+
+            data_df["source"] = doc.attrs.get("source")
+            data_df["year"] = year = extract_year_from_doc_attrs(doc.attrs)
+            candidates.append(
+                {
+                    "data_df": data_df,
+                    "doc": doc,
+                    "doc_ind": doc_ind,
+                    "row_count": row_count,
+                    "year": year,
+                }
+            )
+
+        if not candidates:
+            logger.debug(
+                "No ordinances found; searched %d docs",
+                extraction_context.num_documents,
+            )
+            return None
+
+        candidates = _filter_to_prohibition_cands_if_needed(candidates)
+        candidates = _prioritize_candidates(candidates)
+        merged = await _merge_candidates(
+            candidates, extraction_context, self.jurisdiction.full_name
+        )
+        if merged is None:
+            return None
+
+        extraction_context.attrs["structured_data"] = merged
+        return extraction_context
 
     async def parse_for_structured_data(self, source):
         """Extract all possible structured data from a document
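A toy illustration of how the two multi-doc output shapes differ. This is not the implementation above (which also normalizes feature keys, honors prohibition filtering, and sorts by year/row count first); the frames and values are invented::

    import pandas as pd

    newer = pd.DataFrame(
        [{"feature": "setback", "value": 150, "source": "2024 ordinance"}]
    )
    older = pd.DataFrame(
        [
            {"feature": "setback", "value": 100, "source": "2021 ordinance"},
            {"feature": "height", "value": 80, "source": "2021 ordinance"},
        ]
    )

    # "multi doc all" keeps every row, duplicates included (3 rows)
    concat = pd.concat([newer, older], ignore_index=True)

    # "multi doc mixed" keeps the first row per feature after priority
    # sorting, so the 2024 setback and the 2021 height survive (2 rows)
    mixed = concat.loc[~concat["feature"].duplicated()]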
@@ -1064,7 +1298,7 @@ def _register_clean_file_names(self):
 
 def _valid_chunk(chunk):
     """True if chunk has content"""
-    return chunk and "no relevant text" not in chunk.lower()
+    return bool(chunk and "no relevant text" not in chunk.lower())
 
 
 def _validate_in_out_keys(consumers, producers):
@@ -1190,3 +1424,86 @@ async def _fill_in_all_sources(data_df, extraction_context, out_fn_stem):
     )
 
     return data_df
+
+
+def _filter_to_prohibition_cands_if_needed(candidates):
+    """Filter to just candidates with prohibitions, if any"""
+    prohibition_candidates = [
+        candidate
+        for candidate in candidates
+        if _has_prohibitions(candidate["data_df"])
+    ]
+    return prohibition_candidates or candidates
+
+
+def _prioritize_candidates(candidates):
+    """Sort candidates by year (only if all have years) and row count"""
+    if len(candidates) <= 1:
+        return candidates
+
+    if all(candidate["year"] is not None for candidate in candidates):
+        return sorted(
+            candidates,
+            key=operator.itemgetter("year", "row_count"),
+            reverse=True,
+        )
+
+    return sorted(
+        candidates,
+        key=operator.itemgetter("row_count"),
+        reverse=True,
+    )
+
+
+async def _merge_candidates(candidates, extraction_context, out_stem):
+    """Merge extracted features while respecting candidate priority"""
+    merged_rows = []
+    merged_features = set()
+    contributing_candidates = []
+    for candidate in candidates:
+        data_df = candidate["data_df"]
+        if data_df is None or data_df.empty or "feature" not in data_df:
+            continue
+
+        feature_keys = data_df["feature"].map(_feature_key)
+        keep_mask = feature_keys.notna()
+        if merged_features:
+            keep_mask &= ~feature_keys.isin(merged_features)
+        keep_mask &= ~feature_keys.duplicated()
+        if not keep_mask.any():
+            continue
+
+        selected_feature_keys = feature_keys.loc[keep_mask]
+        merged_features.update(selected_feature_keys.tolist())
+        merged_rows.extend(data_df.loc[keep_mask].to_dict("records"))
+        contributing_candidates.append(candidate)
+
+    if not merged_rows:
+        return None
+
+    for candidate in contributing_candidates:
+        await extraction_context.mark_doc_as_data_source(
+            candidate["doc"],
+            out_fn_stem=f"{out_stem}_{candidate['doc_ind']}",
+        )
+
+    return pd.DataFrame(merged_rows).reset_index(drop=True)
+
+
+def _feature_key(feature):
+    """Get normalized feature key"""
+    if pd.isna(feature):
+        return None
+    return str(feature).strip().casefold()
+
+
+def _has_prohibitions(data_df):
+    """Check for prohibition in data"""
+    if data_df is None or data_df.empty or "feature" not in data_df:
+        return False
+
+    prohibition_mask = data_df["feature"].map(_feature_key).eq("prohibitions")
+    if not prohibition_mask.any():
+        return False
+
+    return num_ordinances_dataframe(data_df.loc[prohibition_mask]) > 0
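The two-level priority in `_prioritize_candidates` is just a reverse tuple sort. A toy run with invented values, matching the behavior the new unit tests assert::

    import operator

    candidates = [
        {"year": 2021, "row_count": 5},
        {"year": 2024, "row_count": 1},
        {"year": 2024, "row_count": 3},
    ]
    ranked = sorted(
        candidates,
        key=operator.itemgetter("year", "row_count"),
        reverse=True,
    )
    # -> (2024, 3), (2024, 1), (2021, 5): newest year first,
    #    row count breaking ties within a year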
diff --git a/compass/scripts/download.py b/compass/scripts/download.py
index f2e5c285..fc9ddc8d 100644
--- a/compass/scripts/download.py
+++ b/compass/scripts/download.py
@@ -699,9 +699,7 @@ async def filter_ordinance_docs(
         "Found %d potential ordinance documents for %s\n\t- %s",
         len(docs),
         jurisdiction.full_name,
-        "\n\t- ".join(
-            [doc.attrs.get("source", "Unknown source") for doc in docs]
-        ),
+        "\n\t- ".join([str(doc) for doc in docs]),
     )
     return docs
 
@@ -817,7 +815,7 @@ def _sort_final_ord_docs(all_ord_docs):
 
 def _ord_doc_sorting_key(doc):
     """Compute a composite sorting score for ordinance documents"""
-    no_date = (-1, -1, -1)
+    no_date = (_NEG_INF, _NEG_INF, _NEG_INF)
    latest_year, latest_month, latest_day = doc.attrs.get("date") or no_date
    best_docs_from_website = doc.attrs.get(_SCORE_KEY, 0)
    prefer_pdf_files = isinstance(doc, PDFDocument)
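`_NEG_INF` is not defined in this diff; presumably it is a module-level constant along the lines of ``float("-inf")``. Assuming so, the switch away from ``-1`` guarantees undated documents sort strictly behind any dated document, no matter how small a parsed date component might be::

    _NEG_INF = float("-inf")  # assumed definition; not shown in this diff

    dated = (2024, 3, 1)
    undated = (_NEG_INF, _NEG_INF, _NEG_INF)
    assert sorted([undated, dated], reverse=True) == [dated, undated]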
diff --git a/examples/water_rights_demo/one-shot/plugin_config.yaml b/examples/water_rights_demo/one-shot/plugin_config.yaml
index b5f7e47a..0627bc8b 100755
--- a/examples/water_rights_demo/one-shot/plugin_config.yaml
+++ b/examples/water_rights_demo/one-shot/plugin_config.yaml
@@ -2,7 +2,7 @@
 schema: ./water_rights_schema.json5
 data_type_short_desc: water rights and regulations
 
-allow_multi_doc_extraction: True  # Important for water rights!
+doc_selection_method: "multi doc context"  # Important for water rights!
 
 query_templates:
   - "{jurisdiction} rules"
diff --git a/tests/python/unit/plugin/test_plugin_ordinances.py b/tests/python/unit/plugin/test_plugin_ordinances.py
index 074c2daa..6f0f2b03 100644
--- a/tests/python/unit/plugin/test_plugin_ordinances.py
+++ b/tests/python/unit/plugin/test_plugin_ordinances.py
@@ -1,16 +1,103 @@
 """COMPASS ordinance plugin tests"""
 
+import asyncio
+from collections import UserList
 from pathlib import Path
+from types import SimpleNamespace
 
+import pandas as pd
 import pytest
 
 from compass.plugin.ordinance import (
     BaseTextCollector,
     BaseTextExtractor,
     BaseParser,
+    DocSelectionMethod,
     OrdinanceExtractionPlugin,
+    _feature_key,
+    _fill_in_all_sources,
+    _fill_out_multi_file_sources,
+    _filter_to_prohibition_cands_if_needed,
+    _get_source_inds,
+    _has_prohibitions,
+    _merge_candidates,
+    _prioritize_candidates,
+    _valid_chunk,
+    _validate_in_out_keys,
 )
-from compass.exceptions import COMPASSPluginConfigurationError
+from compass.exceptions import (
+    COMPASSPluginConfigurationError,
+    COMPASSRuntimeError,
+)
+
+
+class MergePlugin(OrdinanceExtractionPlugin):
+    """Concrete ordinance plugin for merge tests"""
+
+    TEXT_COLLECTORS = []
+    TEXT_EXTRACTORS = []
+    PARSERS = []
+
+    IDENTIFIER = "test"
+    WEBSITE_KEYWORDS = ["test"]
+    QUERY_TEMPLATES = ["test"]
+    HEURISTIC = None
+
+    async def parse_docs_for_structured_data(self, extraction_context):
+        return extraction_context
+
+
+class FakeDoc:
+    """Minimal document stand-in exposing only an ``attrs`` dict"""
+
+    def __init__(self, source, year=None, structured_data=None):
+        self.attrs = {"source": source}
+        if year is not None:
+            self.attrs["date"] = (year, 1, 1)
+        if structured_data is not None:
+            self.attrs["structured_data"] = structured_data
+
+
+class FakeExtractionContext(UserList):
+    """List-like extraction context for merge tests"""
+
+    def __init__(self, docs):
+        super().__init__(docs)
+        self.attrs = {}
+        self.marked_sources = []
+
+    @property
+    def num_documents(self):
+        return len(self)
+
+    async def mark_doc_as_data_source(self, doc, out_fn_stem):
+        self.marked_sources.append((doc.attrs.get("source"), out_fn_stem))
+
+
+@pytest.fixture
+def merge_plugin():
+    """Build a concrete plugin for merge-path tests"""
+    plugin = MergePlugin(None, None, None)
+    plugin.jurisdiction = SimpleNamespace(full_name="Test County")
+    return plugin
+
+
+def _data_df(*rows):
+    """Build a DataFrame from row dicts"""
+    return pd.DataFrame(rows)
+
+
+async def _run_multi_doc_merge(plugin, context, data_dfs):
+    """Run the public merge path with controlled per-doc outputs"""
+    for doc, data_df in zip(context, data_dfs, strict=True):
+        doc.attrs["structured_data"] = data_df
+
+    async def _fake_parse_for_structured_data(doc):
+        await asyncio.sleep(0)
+        return doc.attrs["structured_data"]
+
+    plugin.parse_for_structured_data = _fake_parse_for_structured_data
+    out = await plugin.parse_multi_doc_merge(context)
+    return out.attrs["structured_data"]
 
 
 def test_plugin_validation_parse_key_same():
@@ -189,5 +276,445 @@ async def parse_docs_for_structured_data(self, extraction_context):
     MYPlugin(None, None, None).validate_plugin_configuration()
 
 
+@pytest.mark.asyncio
+async def test_parse_docs_for_structured_data_accepts_enum_value():
+    """Enum-valued doc selection should dispatch correctly"""
+
+    class MYPlugin(OrdinanceExtractionPlugin):
+        TEXT_COLLECTORS = []
+        TEXT_EXTRACTORS = []
+        PARSERS = []
+
+        IDENTIFIER = "test"
+        WEBSITE_KEYWORDS = ["test"]
+        QUERY_TEMPLATES = ["test"]
+        HEURISTIC = None
+        DOC_SELECTION_METHOD = DocSelectionMethod.MULTI_DOC_ALL
+
+        async def parse_single_doc_for_structured_data(
+            self, extraction_context
+        ):
+            raise AssertionError("wrong dispatch")
+
+        async def parse_multi_doc_context_for_structured_data(
+            self, extraction_context
+        ):
+            raise AssertionError("wrong dispatch")
+
+        async def parse_multi_doc_concat(self, extraction_context):
+            return "concat"
+
+        async def parse_multi_doc_merge(self, extraction_context):
+            raise AssertionError("wrong dispatch")
+
+    plugin = MYPlugin(None, None, None)
+
+    assert await plugin.parse_docs_for_structured_data(None) == "concat"
+
+
+@pytest.mark.asyncio
+async def test_merge_multi_doc_data_prefers_latest_year(merge_plugin):
+    """Latest dated doc should win overlapping features"""
+
+    context = FakeExtractionContext(
+        [
+            FakeDoc("older", 2021),
+            FakeDoc("newer", 2024),
+        ]
+    )
+    data_dfs = [
+        _data_df(
+            {"feature": "setback", "value": 100, "summary": "old"},
+            {"feature": "height", "value": 80, "summary": "old"},
+        ),
+        _data_df(
+            {"feature": "setback", "value": 150, "summary": "new"},
+        ),
+    ]
+
+    merged = await _run_multi_doc_merge(merge_plugin, context, data_dfs)
+
+    assert set(merged["feature"].str.casefold()) == {"setback", "height"}
+    setback = merged.loc[merged["feature"].str.casefold() == "setback"]
+    height = merged.loc[merged["feature"].str.casefold() == "height"]
+    assert setback.iloc[0]["value"] == 150
+    assert setback.iloc[0]["source"] == "newer"
+    assert setback.iloc[0]["year"] == 2024
+    assert height.iloc[0]["value"] == 80
+    assert height.iloc[0]["source"] == "older"
+    assert height.iloc[0]["year"] == 2021
+    assert context.marked_sources == [
+        ("newer", "Test County_2"),
+        ("older", "Test County_1"),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_merge_multi_doc_data_falls_back_to_ordinance_count(
+    merge_plugin,
+):
+    """Unknown years should fall back to ordinance count priority"""
+
+    context = FakeExtractionContext(
+        [
+            FakeDoc("unknown-year"),
+            FakeDoc("known-year", 2025),
+        ]
+    )
+    data_dfs = [
+        _data_df(
+            {"feature": "setback", "value": 100, "summary": "one"},
+            {"feature": "height", "value": 50, "summary": "two"},
+        ),
+        _data_df(
+            {"feature": "setback", "value": 200, "summary": "other"},
+        ),
+    ]
+
+    merged = await _run_multi_doc_merge(merge_plugin, context, data_dfs)
+
+    setback = merged.loc[merged["feature"].str.casefold() == "setback"]
+    assert setback.iloc[0]["value"] == 100
+    assert setback.iloc[0]["source"] == "unknown-year"
+    assert pd.isna(setback.iloc[0]["year"])
+
+
+@pytest.mark.asyncio
+async def test_merge_multi_doc_data_breaks_year_ties_by_row_count(
+    merge_plugin,
+):
+    """Equal years should break ties using ordinance count"""
+
+    context = FakeExtractionContext(
+        [
+            FakeDoc("fewer", 2024),
+            FakeDoc("more", 2024),
+        ]
+    )
+    data_dfs = [
+        _data_df(
+            {"feature": "setback", "value": 100, "summary": "one"},
+        ),
+        _data_df(
+            {"feature": "setback", "value": 200, "summary": "two"},
+            {"feature": "height", "value": 70, "summary": "two"},
+        ),
+    ]
+
+    merged = await _run_multi_doc_merge(merge_plugin, context, data_dfs)
+
+    setback = merged.loc[merged["feature"].str.casefold() == "setback"]
+    assert setback.iloc[0]["value"] == 200
+    assert setback.iloc[0]["source"] == "more"
prohibition", + }, + {"feature": "height", "value": 90, "summary": "older"}, + ), + _data_df( + { + "feature": "Prohibitions", + "value": None, + "summary": "newer prohibition", + }, + {"feature": "setback", "value": 300, "summary": "newer"}, + ), + _data_df( + {"feature": "noise", "value": 45, "summary": "ignored"}, + ), + ] + + merged = await _run_multi_doc_merge(merge_plugin, context, data_dfs) + + assert set(merged["feature"].str.casefold()) == { + "prohibitions", + "setback", + "height", + } + assert "noise" not in set(merged["feature"].str.casefold()) + prohibition = merged.loc[ + merged["feature"].str.casefold() == "prohibitions" + ] + assert prohibition.iloc[0]["source"] == "prohibition-newer" + assert context.marked_sources == [ + ("prohibition-newer", "Test County_2"), + ("prohibition-older", "Test County_1"), + ] + + +@pytest.mark.asyncio +async def test_parse_multi_doc_merge_returns_context(merge_plugin): + """Public merge path should attach merged structured data""" + + docs = [ + FakeDoc( + "older", + 2022, + _data_df( + {"feature": "height", "value": 60, "summary": "older"}, + ), + ), + FakeDoc( + "newer", + 2024, + _data_df( + {"feature": "setback", "value": 100, "summary": "newer"}, + ), + ), + ] + context = FakeExtractionContext(docs) + + async def _fake_parse_for_structured_data(doc): + await asyncio.sleep(0) + return doc.attrs["structured_data"] + + merge_plugin.parse_for_structured_data = _fake_parse_for_structured_data + + out = await merge_plugin.parse_multi_doc_merge(context) + + assert out is context + assert set(out.attrs["structured_data"]["feature"].str.casefold()) == { + "setback", + "height", + } + + +@pytest.mark.parametrize( + "chunk,expected", + [("Useful text", True), ("No relevant text.", False), ("", False)], +) +def test_valid_chunk(chunk, expected): + """Helper should reject empty and negative extraction responses""" + + assert _valid_chunk(chunk) == expected + + +def test_validate_in_out_keys_raises_for_missing_key(): + """Helper should fail when no producer satisfies a required input""" + + class Producer: + OUT_LABEL = "produced" + + class Consumer: + IN_LABEL = "missing" + + with pytest.raises( + COMPASSPluginConfigurationError, + match=r"IN_LABEL 'missing'", + ): + _validate_in_out_keys([Consumer], [Producer]) + + +def test_get_source_inds_returns_integer_indices(): + """Helper should extract integer source indices from rows""" + + data_df = _data_df( + {"feature": "setback", "source": 0}, + {"feature": "height", "source": 1}, + {"feature": "noise", "source": 1}, + ) + + source_inds = _get_source_inds(data_df, 3) + + assert list(source_inds) == [0, 1] + + +@pytest.mark.parametrize( + "data_df,num_docs,match", + [ + (_data_df({"feature": "setback"}), 2, "column not found"), + ( + _data_df({"feature": "setback", "source": "one"}), + 2, + "non-integer values", + ), + ( + _data_df({"feature": "setback", "source": 2}), + 2, + "out-of-bounds indices", + ), + ], +) +def test_get_source_inds_raises_for_invalid_source_values( + data_df, num_docs, match +): + """Helper should reject missing, invalid, and out-of-range sources""" + + with pytest.raises(COMPASSRuntimeError, match=match): + _get_source_inds(data_df, num_docs) + + +@pytest.mark.asyncio +async def test_fill_out_multi_file_sources_maps_valid_source_indices(): + """Helper should map per-row source indices back to document metadata""" + + context = FakeExtractionContext( + [FakeDoc("doc-one", 2021), FakeDoc("doc-two", 2024)] + ) + data_df = _data_df( + {"feature": "setback", "source": 0}, + 
{"feature": "height", "source": 1}, + ) + + filled = await _fill_out_multi_file_sources(data_df, context, "County") + + assert list(filled["source"]) == ["doc-one", "doc-two"] + assert list(filled["year"]) == [2021, 2024] + assert context.marked_sources == [ + ("doc-one", "County_1"), + ("doc-two", "County_2"), + ] + + +@pytest.mark.asyncio +async def test_fill_in_all_sources_reports_full_context_when_needed(): + """Fallback helper should report all documents when row sources fail""" + + context = FakeExtractionContext( + [FakeDoc("doc-one", 2020), FakeDoc("doc-two", 2024)] + ) + data_df = _data_df({"feature": "setback", "value": 100}) + + filled = await _fill_in_all_sources(data_df, context, "County") + + assert filled.iloc[0]["source"] == "doc-one ;\ndoc-two" + assert filled.iloc[0]["year"] == 2024 + assert context.marked_sources == [ + ("doc-one", "County_1"), + ("doc-two", "County_2"), + ] + + +def test_feature_key_normalizes_values_and_handles_missing(): + """Feature-key helper should normalize strings and preserve missing""" + + assert _feature_key(" Prohibitions ") == "prohibitions" + assert _feature_key(pd.NA) is None + + +def test_has_prohibitions_requires_ordinance_content(): + """Prohibition helper should only flag rows with actual ordinance data""" + + with_prohibition = _data_df( + {"feature": "Prohibitions", "summary": "Wind is prohibited."} + ) + without_prohibition = _data_df( + {"feature": "Prohibitions", "summary": None, "value": None} + ) + + assert _has_prohibitions(with_prohibition) + assert not _has_prohibitions(without_prohibition) + + +def test_filter_to_prohibition_candidates_only_when_present(): + """Candidate helper should narrow to prohibition-bearing documents""" + + candidates = [ + { + "data_df": _data_df( + {"feature": "setback", "summary": "Regular standard"} + ) + }, + { + "data_df": _data_df( + { + "feature": "prohibitions", + "summary": "Wind systems are prohibited.", + } + ) + }, + ] + + filtered = _filter_to_prohibition_cands_if_needed(candidates) + + assert filtered == [candidates[1]] + + +def test_prioritize_candidates_prefers_latest_year_then_row_count(): + """Priority helper should sort by year when every candidate has one""" + + candidates = [ + {"year": 2021, "row_count": 5}, + {"year": 2024, "row_count": 1}, + {"year": 2024, "row_count": 3}, + ] + + prioritized = _prioritize_candidates(candidates) + + assert prioritized == [candidates[2], candidates[1], candidates[0]] + + +def test_prioritize_candidates_falls_back_to_row_count_without_years(): + """Priority helper should ignore year sorting when any year is unknown""" + + candidates = [ + {"year": 2024, "row_count": 1}, + {"year": None, "row_count": 3}, + {"year": 2021, "row_count": 2}, + ] + + prioritized = _prioritize_candidates(candidates) + + assert prioritized == [candidates[1], candidates[2], candidates[0]] + + +@pytest.mark.asyncio +async def test_merge_candidates_keeps_first_feature_and_marks_sources(): + """Merge helper should keep first-seen features by candidate priority""" + + context = FakeExtractionContext([FakeDoc("older"), FakeDoc("newer")]) + candidates = [ + { + "data_df": _data_df( + {"feature": "setback", "value": 200, "source": "newer"}, + {"feature": "height", "value": 80, "source": "newer"}, + ), + "doc": context[1], + "doc_ind": 2, + }, + { + "data_df": _data_df( + {"feature": "setback", "value": 100, "source": "older"}, + {"feature": "noise", "value": 45, "source": "older"}, + ), + "doc": context[0], + "doc_ind": 1, + }, + ] + + merged = await 
+
+    merged = await _merge_candidates(candidates, context, "County")
+
+    assert set(merged["feature"].str.casefold()) == {
+        "setback",
+        "height",
+        "noise",
+    }
+    setback = merged.loc[merged["feature"].str.casefold() == "setback"]
+    assert setback.iloc[0]["value"] == 200
+    assert context.marked_sources == [
+        ("newer", "County_2"),
+        ("older", "County_1"),
+    ]
+
+
 if __name__ == "__main__":
     pytest.main(["-q", "--show-capture=all", Path(__file__), "-rapP"])
diff --git a/tox.ini b/tox.ini
index cbe07c74..ff3ad93b 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,7 +1,7 @@
 [tox]
 min_version = 4.26
 envlist =
-    py{312,313,314}-cl{817}-docl{290}-lts{1}-nx{342}-nltk{391}-np{224}-oai{234}-pd{223}-pw{149}-pj{5168}-rich{1394}-toml{0102}
+    py{312,313,314}-cl{817}-lts{1}-nx{342}-nltk{391}-np{224}-oai{234}-pd{223}-pw{149}-pj{5168}-rich{1394}-toml{0102}
 
 [gh-actions]
 python =
@@ -15,7 +15,6 @@ commands = pytest tests --dist loadscope {posargs}
 deps =
     cl817: click>=8.1.7,<9
     c4ai063: crawl4ai>=0.6.3,<0.7
-    docl290: docling>=2.90.0,<3
     lts1: langchain-text-splitters>=1.0.0,<2
     nx342: networkx>=3.4.2,<4
     nltk391: nltk>=3.9.1,<4
@@ -38,7 +37,6 @@ description = minimum supported versions
 deps=
     click~=8.1.7
     crawl4ai~=0.6.3
-    docling~=2.90.0
     langchain-text-splitters~=1.0.0
     networkx~=3.4.2
     nltk~=3.9.1