diff --git a/open-media-match/.devcontainer/omm_config.py b/open-media-match/.devcontainer/omm_config.py
index 62b197249..106ff5f50 100644
--- a/open-media-match/.devcontainer/omm_config.py
+++ b/open-media-match/.devcontainer/omm_config.py
@@ -1,8 +1,15 @@
+# This is the configuration used by default for the developer instance
+# that runs in the dev container. Every config field is present
+# to make it easier to copy this file.
+
+from OpenMediaMatch.storage.postgres.impl import DefaultOMMStore
 from threatexchange.signal_type.pdq.signal import PdqSignal
 from threatexchange.signal_type.md5 import VideoMD5Signal
+from threatexchange.content_type.photo import PhotoContent
+from threatexchange.content_type.video import VideoContent
+from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI
 
 # Database configuration
-PRODUCTION = False
 DBUSER = "media_match"
 DBPASS = "hunter2"
 DBHOST = "localhost"
@@ -10,13 +17,17 @@
 DATABASE_URI = f"postgresql+psycopg2://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"
 
 # Role configuration
+PRODUCTION = False
 ROLE_HASHER = True
 ROLE_MATCHER = True
 ROLE_CURATOR = True
-
-# Installed signal types
-SIGNAL_TYPES = [PdqSignal, VideoMD5Signal]
-
-
 # APScheduler (background threads for development)
-SCHEDULER_API_ENABLED = True
+TASK_FETCHER = True
+TASK_INDEXER = True
+
+# Core functionality configuration
+STORAGE_IFACE_INSTANCE = DefaultOMMStore(
+    signal_types=[PdqSignal, VideoMD5Signal],
+    content_types=[PhotoContent, VideoContent],
+    exchange_types=[StaticSampleSignalExchangeAPI],
+)
diff --git a/open-media-match/src/OpenMediaMatch/app.py b/open-media-match/src/OpenMediaMatch/app.py
index c9ba7360f..98470c621 100644
--- a/open-media-match/src/OpenMediaMatch/app.py
+++ b/open-media-match/src/OpenMediaMatch/app.py
@@ -75,11 +75,15 @@ def create_app() -> flask.Flask:
         SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"),
         SQLALCHEMY_TRACK_MODIFICATIONS=False,
     )
-    # Probably better to move this into a more normal looking default config
-    storage_cls = t.cast(
-        t.Type[IUnifiedStore], app.config.get("storage_cls", DefaultOMMStore)
-    )
-    app.config["storage_instance"] = storage_cls.init_flask(app)
+
+    if "STORAGE_IFACE_INSTANCE" not in app.config:
+        app.logger.warning("No storage instance provided, using the default")
+        app.config["STORAGE_IFACE_INSTANCE"] = DefaultOMMStore()
+    storage = app.config["STORAGE_IFACE_INSTANCE"]
+    assert isinstance(
+        storage, IUnifiedStore
+    ), "STORAGE_IFACE_INSTANCE is not an instance of IUnifiedStore"
+    storage.init_flask(app)
 
     _setup_task_logging(app.logger)
 
@@ -98,7 +102,7 @@ def status():
         storage = get_storage()
         if not storage.is_ready():
             return "NOT-READY", 503
-        return "I-AM-ALIVE\n", 200
+        return "I-AM-ALIVE", 200
 
     @app.route("/site-map")
     def site_map():
@@ -195,27 +199,32 @@ def build_indices():
     # We only want to run apscheduler in debug mode
     # and only in the "outer" reloader process
     if _is_dbg_werkzeug_reloaded_process():
-        app.logger.critical(
-            "DEVELOPMENT: Started background tasks with apscheduler."
-        )
-
         now = datetime.datetime.now()
         scheduler = dev_apscheduler.get_apscheduler()
         scheduler.init_app(app)
-        scheduler.add_job(
-            "Fetcher",
-            fetcher.apscheduler_fetch_all,
-            trigger="interval",
-            seconds=60 * 4,
-            start_date=now + datetime.timedelta(seconds=30),
-        )
-        scheduler.add_job(
-            "Indexer",
-            build_index.apscheduler_build_all_indices,
-            trigger="interval",
-            seconds=60,
-            start_date=now + datetime.timedelta(seconds=15),
-        )
+        tasks = []
+        if app.config.get("TASK_FETCHER", False):
+            tasks.append("Fetcher")
+            scheduler.add_job(
+                "Fetcher",
+                fetcher.apscheduler_fetch_all,
+                trigger="interval",
+                seconds=60 * 4,
+                start_date=now + datetime.timedelta(seconds=30),
+            )
+        if app.config.get("TASK_INDEXER", False):
+            tasks.append("Indexer")
+            scheduler.add_job(
+                "Indexer",
+                build_index.apscheduler_build_all_indices,
+                trigger="interval",
+                seconds=60,
+                start_date=now + datetime.timedelta(seconds=15),
+            )
+        if tasks:
+            app.logger.critical(
+                "DEVELOPMENT: Started %s with apscheduler.", ", ".join(tasks)
+            )
         scheduler.start()
 
     return app
diff --git a/open-media-match/src/OpenMediaMatch/background_tasks/fetcher.py b/open-media-match/src/OpenMediaMatch/background_tasks/fetcher.py
index 910272fe9..8d32a65f4 100644
--- a/open-media-match/src/OpenMediaMatch/background_tasks/fetcher.py
+++ b/open-media-match/src/OpenMediaMatch/background_tasks/fetcher.py
@@ -48,9 +48,25 @@ def fetch(
     collab_store: ISignalExchangeStore,
     signal_type_cfgs: t.Mapping[str, SignalTypeConfig],
     collab: CollaborationConfigBase,
+):
+    """Wrapper for exception recording"""
+    try:
+        collab_store.exchange_start_fetch(collab.name)
+        _fetch(collab_store, signal_type_cfgs, collab)
+    except Exception:
+        logger.exception("%s[%s] Failed to fetch!", collab.name, collab.api)
+        collab_store.exchange_complete_fetch(
+            collab.name, is_up_to_date=False, exception=True
+        )
+
+
+def _fetch(
+    collab_store: ISignalExchangeStore,
+    signal_type_cfgs: t.Mapping[str, SignalTypeConfig],
+    collab: CollaborationConfigBase,
 ):
     """
-    Fetch data from
+    Fetch data for a single collab.
 
     1. Attempt to authenticate with that collaboration's API
        using stored credentials.
@@ -65,14 +81,11 @@ def fetch(
     log("Fetching signals for %s from %s", collab.name, collab.api)
 
     api_cls = collab_store.exchange_get_type_configs().get(collab.api)
-    if api_cls is None:
-        log(
-            "No such SignalExchangeAPI '%s' - maybe it was deleted?"
-            " You might have serious misconfiguration",
-            level=logger.critical,
-        )
-        return
-    api_client = collab_store.exchange_get_api_instance(api_cls.get_name())
+    assert (
+        api_cls is not None
+    ), f"No such SignalExchangeAPI '{collab.api}' - maybe it was deleted?"
+
+    api_client = api_cls.for_collab(collab)
 
     starting_checkpoint = collab_store.exchange_get_fetch_checkpoint(collab.name)
     checkpoint = starting_checkpoint
@@ -105,52 +118,33 @@ def fetch(
 
     fetch_start = time.time()
     last_db_commit = fetch_start
-    pending_merge: t.Optional[FetchDeltaTyped] = None
     up_to_date = False
-    exception = False
+    pending_merge: t.Optional[FetchDeltaTyped] = None
 
-    collab_store.exchange_start_fetch(collab.name)
-    try:
-        it = api_client.fetch_iter(signal_types, checkpoint)
-        delta: FetchDeltaTyped
-        for delta in it:
-            assert delta.checkpoint is not None  # Infinite loop protection
-            progress_time = delta.checkpoint.get_progress_timestamp()
-            log(
-                "fetch_iter() with %d new records%s",
-                len(delta.updates),
-                ("" if progress_time is None else f" @ {_timeformat(progress_time)}"),
-                level=logger.debug,
-            )
-            pending_merge = _merge_delta(pending_merge, delta)
-            next_checkpoint = delta.checkpoint
-
-            if checkpoint is not None:
-                prev_time = checkpoint.get_progress_timestamp()
-                if prev_time is not None and progress_time is not None:
-                    assert prev_time <= progress_time, (
-                        "checkpoint time rewound? ",
-                        "This can indicate a serious ",
-                        "problem with the API and checkpointing",
-                    )
-            checkpoint = next_checkpoint  # Only used for the rewind check
-
-            if _should_commit(pending_merge, last_db_commit):
-                log("Committing progress...")
-                collab_store.exchange_commit_fetch(
-                    collab,
-                    starting_checkpoint,
-                    pending_merge.updates,
-                    pending_merge.checkpoint,
+    delta: FetchDeltaTyped
+    for delta in api_client.fetch_iter(signal_types, checkpoint):
+        assert delta.checkpoint is not None  # Infinite loop protection
+        progress_time = delta.checkpoint.get_progress_timestamp()
+        log(
+            "fetch_iter() with %d new records%s",
+            len(delta.updates),
+            ("" if progress_time is None else f" @ {_timeformat(progress_time)}"),
+            level=logger.debug,
+        )
+        pending_merge = _merge_delta(pending_merge, delta)
+        next_checkpoint = delta.checkpoint
+
+        if checkpoint is not None:
+            prev_time = checkpoint.get_progress_timestamp()
+            if prev_time is not None and progress_time is not None:
+                assert prev_time <= progress_time, (
+                    "checkpoint time rewound? ",
+                    "This can indicate a serious ",
+                    "problem with the API and checkpointing",
                 )
-                starting_checkpoint = pending_merge.checkpoint
-                pending_merge = None
-                last_db_commit = time.time()
-                if _hit_single_config_limit(fetch_start):
-                    log("Hit limit for one config fetch")
-                    return
+        checkpoint = next_checkpoint  # Only used for the rewind check
 
-        if pending_merge is not None:
+        if _should_commit(pending_merge, last_db_commit):
             log("Committing progress...")
             collab_store.exchange_commit_fetch(
                 collab,
@@ -158,17 +152,29 @@ def fetch(
                 pending_merge.updates,
                 pending_merge.checkpoint,
             )
+            starting_checkpoint = pending_merge.checkpoint
+            pending_merge = None
+            last_db_commit = time.time()
+            if _hit_single_config_limit(fetch_start):
+                log("Hit limit for one config fetch")
+                break
+    else:
 
         up_to_date = True
        log("Fetched all data! Up to date!")
Up to date!") - except Exception: - log("failed to fetch!", level=logger.exception) - exception = True - return - finally: - collab_store.exchange_complete_fetch( - collab.name, is_up_to_date=up_to_date, exception=exception + + if pending_merge is not None: + log("Committing progress...") + collab_store.exchange_commit_fetch( + collab, + starting_checkpoint, + pending_merge.updates, + pending_merge.checkpoint, ) + collab_store.exchange_complete_fetch( + collab.name, is_up_to_date=up_to_date, exception=False + ) + def _merge_delta( into: t.Optional[FetchDeltaTyped], new: FetchDeltaTyped diff --git a/open-media-match/src/OpenMediaMatch/persistence.py b/open-media-match/src/OpenMediaMatch/persistence.py index f269f0b8c..13e74af86 100644 --- a/open-media-match/src/OpenMediaMatch/persistence.py +++ b/open-media-match/src/OpenMediaMatch/persistence.py @@ -27,4 +27,4 @@ def get_storage() -> IUnifiedStore: Holdover from earlier development, maybe remove someday. """ - return t.cast(IUnifiedStore, current_app.config["storage_instance"]) + return t.cast(IUnifiedStore, current_app.config["STORAGE_IFACE_INSTANCE"]) diff --git a/open-media-match/src/OpenMediaMatch/storage/interface.py b/open-media-match/src/OpenMediaMatch/storage/interface.py index 85b06680b..433100ce5 100644 --- a/open-media-match/src/OpenMediaMatch/storage/interface.py +++ b/open-media-match/src/OpenMediaMatch/storage/interface.py @@ -229,12 +229,6 @@ def exchange_get_type_configs(self) -> t.Mapping[str, TSignalExchangeAPICls]: Return all installed SignalExchange types. """ - @abc.abstractmethod - def exchange_get_api_instance(self, api_cls_name: str) -> TSignalExchangeAPI: - """ - Returns an initialized and authenticated client for an API. - """ - @abc.abstractmethod def exchange_update( self, cfg: CollaborationConfigBase, *, create: bool = False @@ -505,8 +499,7 @@ class IUnifiedStore( mostly for typing. """ - @classmethod - def init_flask(cls, app: flask.Flask) -> t.Self: + def init_flask(cls, app: flask.Flask) -> None: """ Make any flask-specific initialization for this storage implementation @@ -514,7 +507,7 @@ def init_flask(cls, app: flask.Flask) -> t.Self: you to write __init__ how is most useful to your implementation for testing. 
""" - return cls() + return @abc.abstractmethod def is_ready(self) -> bool: diff --git a/open-media-match/src/OpenMediaMatch/storage/postgres/impl.py b/open-media-match/src/OpenMediaMatch/storage/postgres/impl.py index 3a35b205e..5f0765213 100644 --- a/open-media-match/src/OpenMediaMatch/storage/postgres/impl.py +++ b/open-media-match/src/OpenMediaMatch/storage/postgres/impl.py @@ -13,15 +13,18 @@ from sqlalchemy.sql.expression import ClauseElement, Executable from sqlalchemy.ext.compiler import compiles -from threatexchange.utils import dataclass_json +from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI from threatexchange.signal_type.pdq.signal import PdqSignal from threatexchange.signal_type.md5 import VideoMD5Signal +from threatexchange.content_type.photo import PhotoContent +from threatexchange.content_type.video import VideoContent from threatexchange.exchanges.signal_exchange_api import ( TSignalExchangeAPICls, TSignalExchangeAPI, ) from threatexchange.signal_type.index import SignalTypeIndex from threatexchange.signal_type.signal_base import SignalType +from threatexchange.content_type.content_base import ContentType from threatexchange.exchanges.fetch_state import ( FetchCheckpointBase, CollaborationConfigBase, @@ -55,53 +58,62 @@ class DefaultOMMStore(interface.IUnifiedStore): * Blobstore (e.g. built indices) """ - signal_types: list[t.Type[SignalType]] + signal_types: t.Mapping[str, t.Type[SignalType]] + content_types: t.Mapping[str, t.Type[ContentType]] + exchange_types: t.Mapping[str, TSignalExchangeAPICls] - def __init__(self, signal_types: list[t.Type[SignalType]]) -> None: - self.signal_types = signal_types - if signal_types is not None: - assert isinstance(signal_types, list) - for element in signal_types: - assert issubclass(element, SignalType) - self.signal_types = signal_types + def __init__( + self, + *, + signal_types: t.Sequence[t.Type[SignalType]] | None = None, + content_types: t.Sequence[t.Type[ContentType]] | None = None, + exchange_types: t.Sequence[TSignalExchangeAPICls] | None = None, + ) -> None: + if signal_types is None: + signal_types = [PdqSignal, VideoMD5Signal] + if content_types is None: + content_types = [PhotoContent, VideoContent] + if exchange_types is None: + exchange_types = [StaticSampleSignalExchangeAPI] + + self.signal_types = {st.get_name(): st for st in signal_types} + self.content_types = {ct.get_name(): ct for ct in content_types} + self.exchange_types = {et.get_name(): et for et in exchange_types} assert len(self.signal_types) == len( - {s.get_name() for s in self.signal_types} - ), "All signal must have unique names" + signal_types + ), "All signal types must have unique names" + assert len(self.content_types) == len( + content_types + ), "All content types must have unique names" + assert len(self.exchange_types) == len( + exchange_types + ), "All exchange types must have unique names" def is_ready(self) -> bool: """ Whether we have finished pre-loading indices. 
""" - return True + return True # TODO def get_content_type_configs(self) -> t.Mapping[str, interface.ContentTypeConfig]: - # TODO - return MockedUnifiedStore().get_content_type_configs() + return { + name: interface.ContentTypeConfig(True, ct) + for name, ct in self.content_types.items() + } def exchange_get_type_configs(self) -> t.Mapping[str, TSignalExchangeAPICls]: - # TODO - return MockedUnifiedStore().exchange_get_type_configs() - - def exchange_get_api_instance(self, api_cls_name: str) -> TSignalExchangeAPI: - # TODO - return MockedUnifiedStore().exchange_get_api_instance(api_cls_name) + return self.exchange_types def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]: # If a signal is installed, then it is enabled by default. But it may be disabled by an # override in the database. signal_type_overrides = self._query_signal_type_overrides() - get_enabled_ratio = ( - lambda s: signal_type_overrides[s.get_name()] - if s.get_name() in signal_type_overrides - else 1.0 - ) return { - s.get_name(): interface.SignalTypeConfig( - # Note - we do this logic here because this function is re-executed each request - get_enabled_ratio(s), - s, + name: interface.SignalTypeConfig( + signal_type_overrides.get(name, 1.0), + st, ) - for s in self.signal_types + for name, st in self.signal_types.items() } def _create_or_update_signal_type_override( @@ -562,17 +574,12 @@ def bank_yield_content( for row in partition: yield row._tuple()[0].as_iteration_item() - @classmethod - def init_flask(cls, app: flask.Flask) -> t.Self: + def init_flask(self, app: flask.Flask) -> None: migrate = flask_migrate.Migrate() database.db.init_app(app) migrate.init_app(app, database.db) - flask_utils.add_cli_commands(app) - signal_types = app.config.get("SIGNAL_TYPES", [PdqSignal, VideoMD5Signal]) - return cls(signal_types) - def explain(q, analyze: bool = False): """ diff --git a/open-media-match/src/OpenMediaMatch/tests/test_api.py b/open-media-match/src/OpenMediaMatch/tests/test_api.py index f8b3c9d87..98e1710c4 100644 --- a/open-media-match/src/OpenMediaMatch/tests/test_api.py +++ b/open-media-match/src/OpenMediaMatch/tests/test_api.py @@ -18,7 +18,7 @@ def test_status_response(client: FlaskClient): response = client.get("/status") assert response.status_code == 200 - assert response.data == b"I-AM-ALIVE\n" + assert response.data == b"I-AM-ALIVE" def test_banks_empty_index(client: FlaskClient):