[omm] Improvements to the omm_config #1480

Merged · 2 commits · Dec 8, 2023
Changes from all commits
25 changes: 18 additions & 7 deletions open-media-match/.devcontainer/omm_config.py
@@ -1,22 +1,33 @@
# This is the configuration that is used by default for the developer instance
# which runs in the dev container by default. Every config field is present
# to make it easier to copy this file

from OpenMediaMatch.storage.postgres.impl import DefaultOMMStore
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.content_type.photo import PhotoContent
from threatexchange.content_type.video import VideoContent
from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI

# Database configuration
PRODUCTION = False
DBUSER = "media_match"
DBPASS = "hunter2"
DBHOST = "localhost"
DBNAME = "media_match"
DATABASE_URI = f"postgresql+psycopg2://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"

# Role configuration
PRODUCTION = False
ROLE_HASHER = True
ROLE_MATCHER = True
ROLE_CURATOR = True

# Installed signal types
SIGNAL_TYPES = [PdqSignal, VideoMD5Signal]


# APScheduler (background threads for development)
SCHEDULER_API_ENABLED = True
TASK_FETCHER = True
TASK_INDEXER = True

# Core functionality configuration
STORAGE_IFACE_INSTANCE = DefaultOMMStore(
signal_types=[PdqSignal, VideoMD5Signal],
content_types=[PhotoContent, VideoContent],
exchange_types=[StaticSampleSignalExchangeAPI],
)
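
Since every config field in the dev file is now spelled out, a deployment can copy it and flip the role and task flags. As an illustration (a sketch, not part of the PR; only the config keys and the `DefaultOMMStore` constructor shown above come from the source), a matcher-only instance might look like:

```python
# Hypothetical omm_config.py for a matcher-only deployment. A sketch built
# from the keys in the dev config above; a real deployment may need more.
from OpenMediaMatch.storage.postgres.impl import DefaultOMMStore
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.content_type.photo import PhotoContent
from threatexchange.content_type.video import VideoContent
from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI

PRODUCTION = True
DATABASE_URI = "postgresql+psycopg2://media_match:hunter2@localhost/media_match"

# Serve only match requests; hashing and curation run on other instances.
ROLE_HASHER = False
ROLE_MATCHER = True
ROLE_CURATOR = False

# Keep the dev-only apscheduler threads off; fetching and indexing
# are handled elsewhere.
SCHEDULER_API_ENABLED = False
TASK_FETCHER = False
TASK_INDEXER = False

STORAGE_IFACE_INSTANCE = DefaultOMMStore(
    signal_types=[PdqSignal, VideoMD5Signal],
    content_types=[PhotoContent, VideoContent],
    exchange_types=[StaticSampleSignalExchangeAPI],
)
```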
57 changes: 33 additions & 24 deletions open-media-match/src/OpenMediaMatch/app.py
@@ -75,11 +75,15 @@ def create_app() -> flask.Flask:
SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"),
SQLALCHEMY_TRACK_MODIFICATIONS=False,
)
# Probably better to move this into a more normal looking default config
storage_cls = t.cast(
t.Type[IUnifiedStore], app.config.get("storage_cls", DefaultOMMStore)
)
app.config["storage_instance"] = storage_cls.init_flask(app)

if "STORAGE_IFACE_INSTANCE" not in app.config:
app.logger.warning("No storage class provided, using the default")
app.config["STORAGE_IFACE_INSTANCE"] = DefaultOMMStore()
storage = app.config["STORAGE_IFACE_INSTANCE"]
assert isinstance(
storage, IUnifiedStore
), "STORAGE_IFACE_INSTANCE is not an instance of IUnifiedStore"
storage.init_flask(app)

_setup_task_logging(app.logger)

@@ -98,7 +102,7 @@ def status():
storage = get_storage()
if not storage.is_ready():
return "NOT-READY", 503
return "I-AM-ALIVE\n", 200
return "I-AM-ALIVE", 200

@app.route("/site-map")
def site_map():
@@ -195,27 +199,32 @@ def build_indices():
# We only want to run apscheduler in debug mode
# and only in the "outer" reloader process
if _is_dbg_werkzeug_reloaded_process():
app.logger.critical(
"DEVELOPMENT: Started background tasks with apscheduler."
)

now = datetime.datetime.now()
scheduler = dev_apscheduler.get_apscheduler()
scheduler.init_app(app)
scheduler.add_job(
"Fetcher",
fetcher.apscheduler_fetch_all,
trigger="interval",
seconds=60 * 4,
start_date=now + datetime.timedelta(seconds=30),
)
scheduler.add_job(
"Indexer",
build_index.apscheduler_build_all_indices,
trigger="interval",
seconds=60,
start_date=now + datetime.timedelta(seconds=15),
)
tasks = []
if app.config.get("TASK_FETCHER", False):
tasks.append("Fetcher")
scheduler.add_job(
"Fetcher",
fetcher.apscheduler_fetch_all,
trigger="interval",
seconds=60 * 4,
start_date=now + datetime.timedelta(seconds=30),
)
if app.config.get("TASK_INDEXER", False):
tasks.append("Indexer")
scheduler.add_job(
"Indexer",
build_index.apscheduler_build_all_indices,
trigger="interval",
seconds=60,
start_date=now + datetime.timedelta(seconds=15),
)
if tasks:
app.logger.critical(
"DEVELOPMENT: Started %s with apscheduler.", ", ".join(tasks)
)
scheduler.start()

return app
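
With this change the two dev background jobs can be toggled independently. A self-contained sketch of the flag logic (assuming, per the `.get(..., False)` calls above, that a missing flag behaves as off):

```python
# Mirrors the create_app() gating above; runnable stand-alone.
def tasks_to_start(config: dict) -> list:
    tasks = []
    if config.get("TASK_FETCHER", False):
        tasks.append("Fetcher")
    if config.get("TASK_INDEXER", False):
        tasks.append("Indexer")
    return tasks

assert tasks_to_start({}) == []  # no jobs, and no startup log line
assert tasks_to_start({"TASK_INDEXER": True}) == ["Indexer"]
# Both on -> logs "DEVELOPMENT: Started Fetcher, Indexer with apscheduler."
assert tasks_to_start({"TASK_FETCHER": True, "TASK_INDEXER": True}) == [
    "Fetcher",
    "Indexer",
]
```

As written above, `scheduler.start()` is still called even when neither job is registered.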
122 changes: 64 additions & 58 deletions open-media-match/src/OpenMediaMatch/background_tasks/fetcher.py
@@ -48,9 +48,25 @@ def fetch(
collab_store: ISignalExchangeStore,
signal_type_cfgs: t.Mapping[str, SignalTypeConfig],
collab: CollaborationConfigBase,
):
"""Wrapper for exception recording"""
try:
collab_store.exchange_start_fetch(collab.name)
_fetch(collab_store, signal_type_cfgs, collab)
except Exception:
logger.exception("%s[%s] Failed to fetch!", collab.name, collab.api)
collab_store.exchange_complete_fetch(
collab.name, is_up_to_date=False, exception=True
)


def _fetch(
collab_store: ISignalExchangeStore,
signal_type_cfgs: t.Mapping[str, SignalTypeConfig],
collab: CollaborationConfigBase,
):
"""
Fetch data from
Fetch data for a single collab.

1. Attempt to authenticate with that collaboration's API
using stored credentials.
@@ -65,14 +81,11 @@ def fetch(
log("Fetching signals for %s from %s", collab.name, collab.api)

api_cls = collab_store.exchange_get_type_configs().get(collab.api)
if api_cls is None:
log(
"No such SignalExchangeAPI '%s' - maybe it was deleted?"
" You might have serious misconfiguration",
level=logger.critical,
)
return
api_client = collab_store.exchange_get_api_instance(api_cls.get_name())
assert (
api_cls is not None
), f"No such SignalExchangeAPI '{collab.api}' - maybe it was deleted?"

api_client = api_cls.for_collab(collab)

starting_checkpoint = collab_store.exchange_get_fetch_checkpoint(collab.name)
checkpoint = starting_checkpoint
@@ -105,70 +118,63 @@

fetch_start = time.time()
last_db_commit = fetch_start
pending_merge: t.Optional[FetchDeltaTyped] = None
up_to_date = False
exception = False
pending_merge: t.Optional[FetchDeltaTyped] = None

collab_store.exchange_start_fetch(collab.name)
try:
it = api_client.fetch_iter(signal_types, checkpoint)
delta: FetchDeltaTyped
for delta in it:
assert delta.checkpoint is not None # Infinite loop protection
progress_time = delta.checkpoint.get_progress_timestamp()
log(
"fetch_iter() with %d new records%s",
len(delta.updates),
("" if progress_time is None else f" @ {_timeformat(progress_time)}"),
level=logger.debug,
)
pending_merge = _merge_delta(pending_merge, delta)
next_checkpoint = delta.checkpoint

if checkpoint is not None:
prev_time = checkpoint.get_progress_timestamp()
if prev_time is not None and progress_time is not None:
assert prev_time <= progress_time, (
"checkpoint time rewound? ",
"This can indicate a serious ",
"problem with the API and checkpointing",
)
checkpoint = next_checkpoint # Only used for the rewind check

if _should_commit(pending_merge, last_db_commit):
log("Committing progress...")
collab_store.exchange_commit_fetch(
collab,
starting_checkpoint,
pending_merge.updates,
pending_merge.checkpoint,
delta: FetchDeltaTyped
for delta in api_client.fetch_iter(signal_types, checkpoint):
assert delta.checkpoint is not None # Infinite loop protection
progress_time = delta.checkpoint.get_progress_timestamp()
log(
"fetch_iter() with %d new records%s",
len(delta.updates),
("" if progress_time is None else f" @ {_timeformat(progress_time)}"),
level=logger.debug,
)
pending_merge = _merge_delta(pending_merge, delta)
next_checkpoint = delta.checkpoint

if checkpoint is not None:
prev_time = checkpoint.get_progress_timestamp()
if prev_time is not None and progress_time is not None:
assert prev_time <= progress_time, (
"checkpoint time rewound? ",
"This can indicate a serious ",
"problem with the API and checkpointing",
)
starting_checkpoint = pending_merge.checkpoint
pending_merge = None
last_db_commit = time.time()
if _hit_single_config_limit(fetch_start):
log("Hit limit for one config fetch")
return
checkpoint = next_checkpoint # Only used for the rewind check

if pending_merge is not None:
if _should_commit(pending_merge, last_db_commit):
log("Committing progress...")
collab_store.exchange_commit_fetch(
collab,
starting_checkpoint,
pending_merge.updates,
pending_merge.checkpoint,
)
starting_checkpoint = pending_merge.checkpoint
pending_merge = None
last_db_commit = time.time()
if _hit_single_config_limit(fetch_start):
log("Hit limit for one config fetch")
break
else:
up_to_date = True
log("Fetched all data! Up to date!")
except Exception:
log("failed to fetch!", level=logger.exception)
exception = True
return
finally:
collab_store.exchange_complete_fetch(
collab.name, is_up_to_date=up_to_date, exception=exception

if pending_merge is not None:
log("Committing progress...")
collab_store.exchange_commit_fetch(
collab,
starting_checkpoint,
pending_merge.updates,
pending_merge.checkpoint,
)

collab_store.exchange_complete_fetch(
collab.name, is_up_to_date=up_to_date, exception=False
)


def _merge_delta(
into: t.Optional[FetchDeltaTyped], new: FetchDeltaTyped
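The reworked loop in `_fetch()` is a batched-commit pattern: deltas are merged into a pending buffer, flushed via `exchange_commit_fetch()` whenever `_should_commit()` fires, and any remainder is committed once the iterator is exhausted. A generic, self-contained sketch of that shape (the names here are illustrative, not OMM APIs):

```python
import time
import typing as t

def drain(deltas: t.Iterable, merge, should_commit, commit) -> None:
    # Accumulate deltas, flush in batches, then flush the final remainder.
    pending = None
    last_commit = time.time()
    for delta in deltas:
        pending = merge(pending, delta)
        if should_commit(pending, last_commit):
            commit(pending)  # checkpoint advances with each flush
            pending = None
            last_commit = time.time()
    if pending is not None:
        commit(pending)  # final partial batch
```

Splitting `fetch()` into a thin exception-recording wrapper around `_fetch()` keeps this loop free of try/except bookkeeping.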
2 changes: 1 addition & 1 deletion open-media-match/src/OpenMediaMatch/persistence.py
@@ -27,4 +27,4 @@ def get_storage() -> IUnifiedStore:

Holdover from earlier development, maybe remove someday.
"""
return t.cast(IUnifiedStore, current_app.config["storage_instance"])
return t.cast(IUnifiedStore, current_app.config["STORAGE_IFACE_INSTANCE"])
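
Callers are unaffected by the key rename; `get_storage()` reads the same object `create_app()` now stores under `STORAGE_IFACE_INSTANCE`. A sketch of a typical caller (the endpoint body is hypothetical, modeled on the /status route above):

```python
from OpenMediaMatch.persistence import get_storage

def status_check():
    # Must run inside an app/request context, since get_storage()
    # reads current_app.config.
    storage = get_storage()
    if not storage.is_ready():
        return "NOT-READY", 503
    return "I-AM-ALIVE", 200
```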
11 changes: 2 additions & 9 deletions open-media-match/src/OpenMediaMatch/storage/interface.py
@@ -229,12 +229,6 @@ def exchange_get_type_configs(self) -> t.Mapping[str, TSignalExchangeAPICls]:
Return all installed SignalExchange types.
"""

@abc.abstractmethod
def exchange_get_api_instance(self, api_cls_name: str) -> TSignalExchangeAPI:
"""
Returns an initialized and authenticated client for an API.
"""

@abc.abstractmethod
def exchange_update(
self, cfg: CollaborationConfigBase, *, create: bool = False
@@ -505,16 +499,15 @@ class IUnifiedStore(
mostly for typing.
"""

@classmethod
def init_flask(cls, app: flask.Flask) -> t.Self:
def init_flask(cls, app: flask.Flask) -> None:
"""
Make any flask-specific initialization for this storage implementation

This serves as the normal constructor when used with OMM, which allows
you to write __init__ how is most useful to your implementation for
testing.
"""
return cls()
return

@abc.abstractmethod
def is_ready(self) -> bool:
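With `init_flask()` now an instance-level hook that returns `None` (the instance itself is constructed in `omm_config.py`), a custom store only wires up flask-specific state. A sketch under the new contract (the remaining abstract methods of `IUnifiedStore` are elided, so this class is not instantiable as shown):

```python
import flask

from OpenMediaMatch.storage.interface import IUnifiedStore

class MyStore(IUnifiedStore):  # hypothetical implementation
    def __init__(self, database_uri: str) -> None:
        # Plain constructor: easy to build directly in tests or in config.
        self.database_uri = database_uri

    def init_flask(self, app: flask.Flask) -> None:
        # Bind flask-specific resources here; nothing is returned.
        app.logger.info("MyStore attached to %s", app.name)

    def is_ready(self) -> bool:
        return True
```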