From 2a5c21198040e95a25ac8ee7d3eb5e2b588664c1 Mon Sep 17 00:00:00 2001
From: "Victor Mao (main)"
Date: Tue, 26 Aug 2025 16:45:35 -0400
Subject: [PATCH 1/7] Fixing task ID replacement for MNP jobs on AWS Batch

---
 metaflow/plugins/aws/batch/batch_client.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index bf0f6a824e7..6d020f54c82 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -4,6 +4,7 @@
 import random
 import time
 import hashlib
+import re
 
 try:
     unicode
@@ -104,11 +105,17 @@ def execute(self):
         )
         secondary_commands = self.payload["containerOverrides"]["command"][-1]
         # other tasks do not have control- prefix, and have the split id appended to the task -id
-        secondary_commands = secondary_commands.replace(
-            self._task_id,
-            self._task_id.replace("control-", "")
-            + "-node-$AWS_BATCH_JOB_NODE_INDEX",
-        )
+        # Fix: replace the task ID only where it is expected (the --task-id argument and MF_PATHSPEC), not anywhere it appears
+
+        # Replace task ID in --task-id argument
+        task_id_pattern = r'--task-id\s+' + re.escape(self._task_id)
+        replacement_task_id = self._task_id.replace("control-", "") + "-node-$AWS_BATCH_JOB_NODE_INDEX"
+        secondary_commands = re.sub(task_id_pattern, f'--task-id {replacement_task_id}', secondary_commands)
+
+        # Replace task ID in MF_PATHSPEC environment variable (pathspec format: flow_name/run_id/step_name/task_id)
+        pathspec_pattern = r'(MF_PATHSPEC=[^/]+/[^/]+/[^/]+/)' + re.escape(self._task_id) + r'(\s|$|;)'
+        pathspec_replacement = r'\g<1>' + replacement_task_id + r'\g<2>'
+        secondary_commands = re.sub(pathspec_pattern, pathspec_replacement, secondary_commands)
         secondary_commands = secondary_commands.replace(
             "ubf_control",
             "ubf_task",
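The rationale for the anchored patterns in this first patch, in miniature: the blanket str.replace being removed rewrites every occurrence of the bare task ID, including unrelated ones such as an S3 path that happens to embed it, while patterns anchored on the --task-id flag and on the MF_PATHSPEC prefix touch only the intended spots. A minimal sketch, using hypothetical flow, task ID, and path values:

    import re

    task_id = "control-12"  # hypothetical control task ID
    new_id = task_id.replace("control-", "") + "-node-$AWS_BATCH_JOB_NODE_INDEX"

    cmd = (
        "MF_PATHSPEC=MyFlow/77/train/control-12 python flow.py step train "
        "--task-id control-12 --input-paths s3://bucket/control-12/in"
    )

    # Anchored on the flag name, so the S3 path survives untouched
    cmd = re.sub(r"--task-id\s+" + re.escape(task_id), "--task-id " + new_id, cmd)
    # Anchored on the three pathspec components that precede the task ID
    cmd = re.sub(
        r"(MF_PATHSPEC=[^/]+/[^/]+/[^/]+/)" + re.escape(task_id) + r"(\s|$|;)",
        r"\g<1>" + new_id + r"\g<2>",
        cmd,
    )
    print(cmd)
    # MF_PATHSPEC=MyFlow/77/train/12-node-$AWS_BATCH_JOB_NODE_INDEX ... step train
    # --task-id 12-node-$AWS_BATCH_JOB_NODE_INDEX --input-paths s3://bucket/control-12/in

$AWS_BATCH_JOB_NODE_INDEX is deliberately left as a literal: it is expanded by the shell inside each Batch node's container, not at submission time.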
From ce15127b20975ce3d3c962dac62cb613597fe53a Mon Sep 17 00:00:00 2001
From: "Victor Mao (main)"
Date: Thu, 28 Aug 2025 10:48:34 -0400
Subject: [PATCH 2/7] Injecting node-index placeholders earlier for more reliable replacement

---
 metaflow/plugins/aws/batch/batch_cli.py    |  2 ++
 metaflow/plugins/aws/batch/batch_client.py | 28 +++++++---------------
 2 files changed, 11 insertions(+), 19 deletions(-)

diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 9d3eb4cbaa0..1cfcf6171dd 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -244,6 +244,8 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
     if num_parallel and num_parallel > 1:
         # For multinode, we need to add a placeholder that can be mutated by the caller
         step_args += " [multinode-args]"
+        # Add task ID placeholder for secondary nodes
+        step_args += f" --task-id {kwargs['task_id']}[NODE-INDEX]"
     step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
         entrypoint=entrypoint,
         top_args=top_args,
diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index 6d020f54c82..6b53aef4d9a 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -4,7 +4,6 @@
 import random
 import time
 import hashlib
-import re
 
 try:
     unicode
@@ -97,6 +96,8 @@ def execute(self):
         commands = self.payload["containerOverrides"]["command"][-1]
         # add split-index as this worker is also an ubf_task
         commands = commands.replace("[multinode-args]", "--split-index 0")
+        # The main node keeps its original task ID, so just strip the placeholder
+        commands = commands.replace("[NODE-INDEX]", "")
         main_task_override["command"][-1] = commands
 
         # secondary tasks
@@ -104,24 +105,13 @@
             self.payload["containerOverrides"]
         )
         secondary_commands = self.payload["containerOverrides"]["command"][-1]
-        # other tasks do not have control- prefix, and have the split id appended to the task -id
-        # Fix: replace the task ID only where it is expected (the --task-id argument and MF_PATHSPEC), not anywhere it appears
-
-        # Replace task ID in --task-id argument
-        task_id_pattern = r'--task-id\s+' + re.escape(self._task_id)
-        replacement_task_id = self._task_id.replace("control-", "") + "-node-$AWS_BATCH_JOB_NODE_INDEX"
-        secondary_commands = re.sub(task_id_pattern, f'--task-id {replacement_task_id}', secondary_commands)
-
-        # Replace task ID in MF_PATHSPEC environment variable (pathspec format: flow_name/run_id/step_name/task_id)
-        pathspec_pattern = r'(MF_PATHSPEC=[^/]+/[^/]+/[^/]+/)' + re.escape(self._task_id) + r'(\s|$|;)'
-        pathspec_replacement = r'\g<1>' + replacement_task_id + r'\g<2>'
-        secondary_commands = re.sub(pathspec_pattern, pathspec_replacement, secondary_commands)
-        secondary_commands = secondary_commands.replace(
-            "ubf_control",
-            "ubf_task",
-        )
-        secondary_commands = secondary_commands.replace(
-            "[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX"
+        # For secondary nodes: remove "control-" prefix and replace placeholders
+        secondary_commands = (
+            secondary_commands.replace(
+                "control-[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX"
+            )
+            .replace("ubf_control", "ubf_task")
+            .replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX")
         )
         secondary_task_container_override["command"][-1] = secondary_commands
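After this patch, the submit-time command carries two literal placeholders, [NODE-INDEX] and [multinode-args], and execute() specializes a copy of the same command per node role. A sketch of the main-node path, under a hypothetical command shape:

    # Hypothetical command as constructed by batch_cli.py at submit time
    cmd = "python flow.py step train --task-id control-12[NODE-INDEX] [multinode-args]"

    # Main node (node 0): runs the control task under its original ID
    main = cmd.replace("[multinode-args]", "--split-index 0")
    main = main.replace("[NODE-INDEX]", "")
    print(main)
    # python flow.py step train --task-id control-12 --split-index 0

The same job payload ships to every node; only the shell expansion of $AWS_BATCH_JOB_NODE_INDEX at container start differentiates the secondary workers.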
From 86c3b8473e8210d2d40e73fd424a7c5d47f979a4 Mon Sep 17 00:00:00 2001
From: "Victor Mao (main)"
Date: Thu, 28 Aug 2025 11:19:52 -0400
Subject: [PATCH 3/7] Fixing step_kwargs conflict with log writing

---
 metaflow/plugins/aws/batch/batch_cli.py    | 9 ++++++---
 metaflow/plugins/aws/batch/batch_client.py | 5 ++---
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 1cfcf6171dd..6aec01c6cb4 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -239,13 +239,16 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
         }
         kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())
 
-    step_args = " ".join(util.dict_to_cli_options(kwargs))
+    # For multinode, create modified kwargs for command construction only
     num_parallel = num_parallel or 0
+    step_kwargs = kwargs.copy()
+    if num_parallel and num_parallel > 1:
+        step_kwargs["task_id"] = f"{kwargs['task_id']}[NODE-INDEX]"
+
+    step_args = " ".join(util.dict_to_cli_options(step_kwargs))
     if num_parallel and num_parallel > 1:
         # For multinode, we need to add a placeholder that can be mutated by the caller
         step_args += " [multinode-args]"
-        # Add task ID placeholder for secondary nodes
-        step_args += f" --task-id {kwargs['task_id']}[NODE-INDEX]"
     step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
         entrypoint=entrypoint,
         top_args=top_args,
diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index 6b53aef4d9a..e296103bb46 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -107,9 +107,8 @@ def execute(self):
         secondary_commands = self.payload["containerOverrides"]["command"][-1]
         # For secondary nodes: remove "control-" prefix and replace placeholders
         secondary_commands = (
-            secondary_commands.replace(
-                "control-[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX"
-            )
+            secondary_commands.replace("control-", "")
+            .replace("[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX")
             .replace("ubf_control", "ubf_task")
             .replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX")
         )
         secondary_task_container_override["command"][-1] = secondary_commands

From f2ee2852c96b4a63f8cbba57e601b68448d01683 Mon Sep 17 00:00:00 2001
From: "Victor Mao (main)"
Date: Thu, 28 Aug 2025 11:50:19 -0400
Subject: [PATCH 4/7] Propagating the [NODE-INDEX] placeholder into MF_PATHSPEC

---
 metaflow/plugins/aws/batch/batch_cli.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 6aec01c6cb4..e8b82d256ca 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -266,15 +266,26 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
             retry_deco[0].attributes.get("minutes_between_retries", 1)
         )
 
-    # Set batch attributes
+    # Set batch attributes - use modified task_id for multinode to ensure MF_PATHSPEC has the placeholder
+    task_spec_task_id = (
+        step_kwargs["task_id"] if num_parallel > 1 else kwargs["task_id"]
+    )
     task_spec = {
+        "flow_name": ctx.obj.flow.name,
+        "step_name": step_name,
+        "run_id": kwargs["run_id"],
+        "task_id": task_spec_task_id,
+        "retry_count": str(retry_count),
+    }
+    # Keep attrs clean with original task_id for metadata
+    main_task_spec = {
         "flow_name": ctx.obj.flow.name,
         "step_name": step_name,
         "run_id": kwargs["run_id"],
         "task_id": kwargs["task_id"],
         "retry_count": str(retry_count),
     }
-    attrs = {"metaflow.%s" % k: v for k, v in task_spec.items()}
+    attrs = {"metaflow.%s" % k: v for k, v in main_task_spec.items()}
     attrs["metaflow.user"] = util.get_username()
     attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[
         "metaflow_version"
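Taken together, patches 3 and 4 push the placeholder into both the --task-id option and the MF_PATHSPEC attribute, so the secondary-node rewrite in batch_client.py reduces to a single chain of plain string replacements. A sketch with hypothetical values (the exact command shape, including the --ubf-context argument, is assumed for illustration):

    cmd = (
        "MF_PATHSPEC=MyFlow/77/train/control-12[NODE-INDEX] python flow.py step train "
        "--task-id control-12[NODE-INDEX] --ubf-context ubf_control [multinode-args]"
    )
    secondary = (
        cmd.replace("control-", "")
        .replace("[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX")
        .replace("ubf_control", "ubf_task")
        .replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX")
    )
    print(secondary)
    # MF_PATHSPEC=MyFlow/77/train/12-node-$AWS_BATCH_JOB_NODE_INDEX python flow.py step train
    # --task-id 12-node-$AWS_BATCH_JOB_NODE_INDEX --ubf-context ubf_task
    # --split-index $AWS_BATCH_JOB_NODE_INDEX

Note that "ubf_control" is untouched by the "control-" removal because it lacks the trailing hyphen.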
From cc5b44e7f2df83fe7268fedbca25abbd9e40c8f6 Mon Sep 17 00:00:00 2001
From: "Victor Mao (main)"
Date: Mon, 22 Sep 2025 23:45:30 -0400
Subject: [PATCH 5/7] Updating the MNP flow: env-based task IDs and datastore completion waits

---
 metaflow/plugins/aws/batch/batch_cli.py       |  9 +-
 metaflow/plugins/aws/batch/batch_client.py    | 38 +++++++-
 metaflow/plugins/aws/batch/batch_decorator.py | 86 ++++++++++++++++++-
 3 files changed, 130 insertions(+), 3 deletions(-)

diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 7bf69448361..ec4f5304044 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -252,7 +252,9 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
     num_parallel = num_parallel or 0
     step_kwargs = kwargs.copy()
     if num_parallel and num_parallel > 1:
-        step_kwargs["task_id"] = f"{kwargs['task_id']}[NODE-INDEX]"
+        # Pass task_id via an env var so the shell can expand the node index at runtime.
+        # Using a value starting with '$' prevents quoting in dict_to_cli_options.
+        step_kwargs["task_id"] = "$MF_TASK_ID_BASE[NODE-INDEX]"
 
     step_args = " ".join(util.dict_to_cli_options(step_kwargs))
     if num_parallel and num_parallel > 1:
@@ -318,6 +320,11 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
     if split_vars:
         env.update(split_vars)
 
+    # For multinode, provide the base task id to be expanded in the container
+    if num_parallel and num_parallel > 1:
+        # Ensure we don't carry a possible 'control-' prefix into worker IDs
+        env["MF_TASK_ID_BASE"] = str(kwargs["task_id"]).replace("control-", "")
+
     if retry_count:
         ctx.obj.echo_always(
             "Sleeping %d minutes before the next AWS Batch retry"
diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index e296103bb46..4c3eb9c750b 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -4,6 +4,7 @@
 import random
 import time
 import hashlib
+import os
 
 try:
     unicode
@@ -19,7 +20,34 @@ class BatchClient(object):
     def __init__(self):
         from ..aws_client import get_aws_client
 
-        self._client = get_aws_client("batch")
+        # Prefer the task role by default when running inside AWS Batch containers
+        # by temporarily removing higher-precedence env credentials for this process.
+        # This prevents AMI-injected AWS_* env vars from overriding the task role.
+        # Outside of Batch, we leave env vars untouched unless explicitly opted in.
+        if "AWS_BATCH_JOB_ID" in os.environ:
+            _aws_env_keys = [
+                "AWS_ACCESS_KEY_ID",
+                "AWS_SECRET_ACCESS_KEY",
+                "AWS_SESSION_TOKEN",
+                "AWS_PROFILE",
+                "AWS_DEFAULT_PROFILE",
+            ]
+            _present = [k for k in _aws_env_keys if k in os.environ]
+            print(
+                "[Metaflow] AWS credential-related env vars present before Batch client init:",
+                _present,
+            )
+            _saved_env = {
+                k: os.environ.pop(k) for k in _aws_env_keys if k in os.environ
+            }
+            try:
+                self._client = get_aws_client("batch")
+            finally:
+                # Restore prior env for the rest of the process
+                for k, v in _saved_env.items():
+                    os.environ[k] = v
+        else:
+            self._client = get_aws_client("batch")
 
     def active_job_queues(self):
         paginator = self._client.get_paginator("describe_job_queues")
@@ -404,6 +432,14 @@ def _register_job_definition(
 
         self.num_parallel = num_parallel or 0
         if self.num_parallel >= 1:
+            # Set the open-files ulimit to 65536, since it cannot easily be raised once worker processes start on Batch.
+            # job_definition["containerProperties"]["linuxParameters"]["ulimits"] = [
+            #     {
+            #         "name": "nofile",
+            #         "softLimit": 65536,
+            #         "hardLimit": 65536,
+            #     }
+            # ]
             job_definition["type"] = "multinode"
             job_definition["nodeProperties"] = {
                 "numNodes": self.num_parallel,
diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py
index 6d64ed994aa..8d51db49d45 100644
--- a/metaflow/plugins/aws/batch/batch_decorator.py
+++ b/metaflow/plugins/aws/batch/batch_decorator.py
@@ -421,6 +421,89 @@ def _wait_for_mapper_tasks(self, flow, step_name):
         TIMEOUT = 600
         last_completion_timeout = time.time() + TIMEOUT
         print("Waiting for batch secondary tasks to finish")
+
+        # Prefer waiting on the shared datastore when metadata is local (nodes can't share local metadata files).
+        # If metadata isn't bound yet but we are on Batch, also prefer the datastore.
+        md = getattr(self, "metadata", None)
+        if md is not None and md.TYPE == "local":
+            return self._wait_for_mapper_tasks_batch_api(
+                flow, step_name, last_completion_timeout
+            )
+        if md is None and "AWS_BATCH_JOB_ID" in os.environ:
+            return self._wait_for_mapper_tasks_batch_api(
+                flow, step_name, last_completion_timeout
+            )
+        return self._wait_for_mapper_tasks_metadata(
+            flow, step_name, last_completion_timeout
+        )
+
+    def _wait_for_mapper_tasks_batch_api(
+        self, flow, step_name, last_completion_timeout
+    ):
+        """
+        Poll the shared datastore (S3) for DONE markers for each mapper task.
+        This avoids relying on a metadata service or local metadata files.
+        """
+        from metaflow.datastore.task_datastore import TaskDataStore
+
+        pathspecs = getattr(flow, "_control_mapper_tasks", [])
+        total = len(pathspecs)
+        if total == 0:
+            print("No mapper tasks discovered for datastore wait; returning")
+            return True
+
+        print("Waiting for mapper DONE markers in datastore for %d tasks" % total)
+        poll_sleep = 3.0
+        while last_completion_timeout > time.time():
+            time.sleep(poll_sleep)
+            completed = 0
+            for ps in pathspecs:
+                try:
+                    parts = ps.split("/")
+                    if len(parts) == 3:
+                        run_id, step, task_id = parts
+                    else:
+                        # Fallback in case of unexpected format
+                        run_id, step, task_id = self.run_id, step_name, parts[-1]
+                    tds = TaskDataStore(
+                        self.flow_datastore,
+                        run_id,
+                        step,
+                        task_id,
+                        mode="r",
+                        allow_not_done=True,
+                    )
+                    if tds.has_metadata(TaskDataStore.METADATA_DONE_SUFFIX):
+                        completed += 1
+                except Exception as e:
+                    if os.environ.get("METAFLOW_DEBUG_BATCH_POLL") in (
+                        "1",
+                        "true",
+                        "True",
+                    ):
+                        print("Datastore wait: error checking %s: %s" % (ps, e))
+                    continue
+            if completed == total:
+                print("All mapper tasks have written DONE markers to datastore")
+                return True
+            print(
+                "Waiting for mapper DONE markers. Finished: %d/%d" % (completed, total)
+            )
+            poll_sleep = min(poll_sleep * 1.25, 10.0)
+
+        raise Exception(
+            "Batch secondary workers did not finish in %s seconds (datastore wait)"
+            % (time.time() - (last_completion_timeout - 600))
+        )
+
+    def _wait_for_mapper_tasks_metadata(self, flow, step_name, last_completion_timeout):
+        """
+        Polls Metaflow metadata (Step client) for task completion.
+        Works with service-backed metadata providers but can fail with local metadata
+        in multi-node setups due to isolated per-node filesystems.
+        """
+        from metaflow import Step
+
         while last_completion_timeout > time.time():
             time.sleep(2)
             try:
@@ -441,7 +524,8 @@
             except Exception:
                 pass
         raise Exception(
-            "Batch secondary workers did not finish in %s seconds" % TIMEOUT
+            "Batch secondary workers did not finish in %s seconds"
+            % (time.time() - (last_completion_timeout - 600))
        )
 
     @classmethod
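The credential-shielding block in BatchClient.__init__ (pop the AWS_* variables, build the client, restore them) is a reusable pattern. Factored into a hypothetical helper it might look like the following; like the inline version, it mutates process-global state, so it assumes client construction is not racing other threads:

    import os
    from contextlib import contextmanager

    @contextmanager
    def shielded_env(keys):
        """Temporarily remove the given env vars so that lower-precedence
        credential sources (e.g. the Batch task role) win."""
        saved = {k: os.environ.pop(k) for k in list(keys) if k in os.environ}
        try:
            yield
        finally:
            os.environ.update(saved)

    # Usage sketch (get_aws_client as in the patch):
    # with shielded_env(["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY",
    #                    "AWS_SESSION_TOKEN", "AWS_PROFILE", "AWS_DEFAULT_PROFILE"]):
    #     client = get_aws_client("batch")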
From 9fa43918af2f628a05ea65e5df1e94cf6ccd9299 Mon Sep 17 00:00:00 2001
From: "Victor Mao (main)"
Date: Thu, 30 Oct 2025 16:44:58 -0400
Subject: [PATCH 6/7] Resolving review comments

---
 metaflow/plugins/aws/batch/batch_decorator.py | 21 ++++++++-----------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py
index 8d51db49d45..b6a4802d040 100644
--- a/metaflow/plugins/aws/batch/batch_decorator.py
+++ b/metaflow/plugins/aws/batch/batch_decorator.py
@@ -426,18 +426,18 @@ def _wait_for_mapper_tasks(self, flow, step_name):
         # If metadata isn't bound yet but we are on Batch, also prefer the datastore.
         md = getattr(self, "metadata", None)
         if md is not None and md.TYPE == "local":
-            return self._wait_for_mapper_tasks_batch_api(
+            return self._wait_for_mapper_tasks_datastore(
                 flow, step_name, last_completion_timeout
             )
         if md is None and "AWS_BATCH_JOB_ID" in os.environ:
-            return self._wait_for_mapper_tasks_batch_api(
+            return self._wait_for_mapper_tasks_datastore(
                 flow, step_name, last_completion_timeout
             )
         return self._wait_for_mapper_tasks_metadata(
             flow, step_name, last_completion_timeout
         )
 
-    def _wait_for_mapper_tasks_batch_api(
+    def _wait_for_mapper_tasks_datastore(
         self, flow, step_name, last_completion_timeout
     ):
         """
@@ -476,17 +476,14 @@ def _wait_for_mapper_tasks_batch_api(
                     if tds.has_metadata(TaskDataStore.METADATA_DONE_SUFFIX):
                         completed += 1
                 except Exception as e:
-                    if os.environ.get("METAFLOW_DEBUG_BATCH_POLL") in (
-                        "1",
-                        "true",
-                        "True",
-                    ):
-                        print("Datastore wait: error checking %s: %s" % (ps, e))
+                    self.logger.warning("Datastore wait: error checking %s: %s", ps, e)
                     continue
             if completed == total:
-                print("All mapper tasks have written DONE markers to datastore")
+                self.logger.info(
+                    "All mapper tasks have written DONE markers to datastore"
+                )
                 return True
-            print(
+            self.logger.info(
                 "Waiting for mapper DONE markers. Finished: %d/%d" % (completed, total)
             )
             poll_sleep = min(poll_sleep * 1.25, 10.0)
@@ -515,7 +512,7 @@ def _wait_for_mapper_tasks_metadata(self, flow, step_name, last_completion_timeo
             ):  # for some reason task.finished fails
                 return True
             else:
-                print(
+                self.logger.info(
                     "Waiting for all parallel tasks to finish. Finished: {}/{}".format(
                         len(tasks),
                         len(flow._control_mapper_tasks),
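Structurally, the datastore wait introduced in patch 5 and polished here is a poll-until-deadline loop with capped exponential backoff (3.0s, 3.75s, ..., capped at 10s). Distilled into a hypothetical helper:

    import time

    def wait_until(predicate, deadline, sleep=3.0, factor=1.25, cap=10.0):
        """Call predicate() until it returns True or the deadline passes,
        backing off between polls."""
        start = time.time()
        while time.time() < deadline:
            time.sleep(sleep)
            if predicate():
                return True
            sleep = min(sleep * factor, cap)
        raise Exception("timed out after %.0f seconds" % (time.time() - start))

    # e.g.: wait_until(lambda: count_done_markers() == total, time.time() + 600)

The backoff matters because each poll checks one DONE marker per mapper task, so request volume against the datastore scales with the fan-out.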
From a0f68ca95b2695b7110b8706eea1fa231589d588 Mon Sep 17 00:00:00 2001
From: "Victor Mao (main)"
Date: Thu, 30 Oct 2025 16:46:33 -0400
Subject: [PATCH 7/7] Cleaning up commented-out ulimit code

---
 metaflow/plugins/aws/batch/batch_client.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index 4c3eb9c750b..28ef783cc60 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -432,14 +432,6 @@ def _register_job_definition(
 
         self.num_parallel = num_parallel or 0
         if self.num_parallel >= 1:
-            # Set the open-files ulimit to 65536, since it cannot easily be raised once worker processes start on Batch.
-            # job_definition["containerProperties"]["linuxParameters"]["ulimits"] = [
-            #     {
-            #         "name": "nofile",
-            #         "softLimit": 65536,
-            #         "hardLimit": 65536,
-            #     }
-            # ]
             job_definition["type"] = "multinode"
             job_definition["nodeProperties"] = {
                 "numNodes": self.num_parallel,
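For reference on the nodeProperties that _register_job_definition fills in for multi-node jobs, an AWS Batch multi-node job definition takes roughly the following shape (all values hypothetical):

    container_properties = {
        "image": "python:3.11",  # hypothetical image
        "vcpus": 4,
        "memory": 16384,
        "command": ["echo", "placeholder"],
    }

    job_definition = {
        "type": "multinode",
        "nodeProperties": {
            "numNodes": 4,  # self.num_parallel in the client
            "mainNode": 0,  # the node that runs the control task
            "nodeRangeProperties": [
                {"targetNodes": "0:", "container": container_properties}
            ],
        },
    }

Batch injects AWS_BATCH_JOB_NODE_INDEX (and AWS_BATCH_JOB_MAIN_NODE_INDEX) into every node's container, which is what the placeholder rewrites in the earlier patches rely on.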