fixes for batch tagging branch #8

Open · wants to merge 6 commits into base: fsat/pr-1627--rebased--2.12.28

Changes from all commits:

7 changes: 1 addition & 6 deletions metaflow/plugins/aws/batch/batch.py
@@ -189,7 +189,6 @@ def create_job(
efs_volumes=None,
use_tmpfs=None,
aws_batch_tags=None,
cli_aws_batch_tags=None,
tmpfs_tempdir=None,
tmpfs_size=None,
tmpfs_path=None,
@@ -329,14 +328,10 @@ def create_job(
if key in attrs:
k, v = sanitize_batch_tag(key, attrs.get(key))
job.tag(k, v)

if cli_aws_batch_tags is not None:
for tag in cli_aws_batch_tags:
job.tag(tag['key'], tag['value'])

if aws_batch_tags is not None:
for tag in aws_batch_tags:
job.tag(tag['key'], tag['value'])
job.tag(tag["key"], tag["value"])

return job

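Below is a self-contained sketch of the tag loop that create_job keeps after this change: each {"key": ..., "value": ...} entry in aws_batch_tags becomes one job.tag(...) call. FakeJob is a hypothetical stand-in for the real Batch job object, used only to make the snippet runnable.

# Hypothetical stand-in for the real Batch job object, so the sketch runs on its own.
class FakeJob:
    def __init__(self):
        self.tags = {}

    def tag(self, key, value):
        self.tags[key] = value
        return self

# Shape assumed here: a list of key/value dicts, as produced by the batch decorator.
aws_batch_tags = [
    {"key": "team", "value": "ml-platform"},
    {"key": "env", "value": "prod"},
]

job = FakeJob()
if aws_batch_tags is not None:
    for tag in aws_batch_tags:
        job.tag(tag["key"], tag["value"])

print(job.tags)  # {'team': 'ml-platform', 'env': 'prod'}
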
20 changes: 5 additions & 15 deletions metaflow/plugins/aws/batch/batch_cli.py
@@ -10,9 +10,9 @@
from metaflow.metadata_provider.util import sync_local_metadata_from_datastore
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.mflog import TASK_LOG_SOURCE
from metaflow.parameters import JSONTypeClass
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
from .batch import Batch, BatchKilledException
from metaflow.tagging_util import validate_tags, validate_aws_tag


@click.group()
@@ -47,6 +47,7 @@ def _execute_cmd(func, flow_name, run_id, user, my_runs, echo):

func(flow_name, run_id, user, echo)


@batch.command(help="List unfinished AWS Batch tasks of this flow")
@click.option(
"--my-runs",
@@ -146,7 +147,9 @@ def kill(ctx, run_id, user, my_runs):
help="Activate designated number of elastic fabric adapter devices. "
"EFA driver must be installed and instance type compatible with EFA",
)
@click.option("--aws-batch-tags", multiple=True, default=None, help="AWS tags. Format: key=value, multiple allowed")
@click.option(
"--aws-batch-tags", type=JSONTypeClass(), default=None, help="AWS Batch tags."
)
@click.option("--use-tmpfs", is_flag=True, help="tmpfs requirement for AWS Batch.")
@click.option("--tmpfs-tempdir", is_flag=True, help="tmpfs requirement for AWS Batch.")
@click.option("--tmpfs-size", help="tmpfs requirement for AWS Batch.")
@@ -275,19 +278,6 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
"metaflow_version"
]



if aws_batch_tags is not None:
if not isinstance(aws_batch_tags, list[str]):
raise CommandException("aws_tags must be list[str]")
aws_tags_list = [
{'key': item.split('=')[0],
'value': item.split('=')[1]} for item in aws_batch_tags.items()
]
for tag in aws_tags_list:
validate_aws_tag(tag)


env_deco = [deco for deco in node.decorators if deco.name == "environment"]
if env_deco:
env = env_deco[0].attributes["vars"]
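A minimal sketch of the round trip this option change assumes: runtime_step_cli in the batch decorator (next file) serializes the tag list with json.dumps, and --aws-batch-tags, now typed as JSONTypeClass(), parses that JSON back before the list reaches Batch.create_job. The tag values below are illustrative only.

import json

# Illustrative tag list in the [{"key": ..., "value": ...}] shape built by the decorator.
tags = [{"key": "team", "value": "ml-platform"}, {"key": "env", "value": "prod"}]

serialized = json.dumps(tags)    # the string passed on the command line as --aws-batch-tags
parsed = json.loads(serialized)  # roughly what a JSON-typed click option hands back

assert parsed == tags
print(serialized)
# [{"key": "team", "value": "ml-platform"}, {"key": "env", "value": "prod"}]
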
43 changes: 31 additions & 12 deletions metaflow/plugins/aws/batch/batch_decorator.py
@@ -1,3 +1,4 @@
import json
import os
import platform
import sys
@@ -30,7 +31,7 @@
get_ec2_instance_metadata,
)
from .batch import BatchException
from metaflow.tagging_util import validate_tags, validate_aws_tag
from metaflow.tagging_util import validate_aws_tag


class BatchDecorator(StepDecorator):
@@ -185,24 +186,32 @@ def __init__(self, attributes=None, statically_defined=False):
if self.attributes["trainium"] is not None:
self.attributes["inferentia"] = self.attributes["trainium"]

if not isinstance(BATCH_DEFAULT_TAGS, dict) and not all(instance(k, str) and isinstance(v, str) for k, v in BATCH_DEFAULT_TAGS.items()):
raise BatchException("BATCH_DEFAULT_TAGS environment variable must be Dict[str, str]")
if not isinstance(BATCH_DEFAULT_TAGS, dict) and not all(
isinstance(k, str) and isinstance(v, str)
for k, v in BATCH_DEFAULT_TAGS.items()
):
raise BatchException(
"BATCH_DEFAULT_TAGS environment variable must be Dict[str, str]"
)
if self.attributes["aws_batch_tags"] is None:
self.attributes["aws_batch_tags"] = BATCH_DEFAULT_TAGS

if self.attributes["aws_batch_tags"] is not None:
if not isinstance(self.attributes["aws_tags"], dict) and not all(isinstance(k, str) and isinstance(v, str) for k, v in self.attributes["aws_batch_tags"].items()):
if not isinstance(self.attributes["aws_batch_tags"], dict) and not all(
isinstance(k, str) and isinstance(v, str)
for k, v in self.attributes["aws_batch_tags"].items()
):
raise BatchException("aws_batch_tags must be Dict[str, str]")

batch_default_tags_copy = BATCH_DEFAULT_TAGS.copy()
self.attributes["aws_batch_tags"] = batch_default_tags_copy.update(self.attributes["aws_batch_tags"])

self.attributes["aws_batch_tags"] = {
**BATCH_DEFAULT_TAGS,
**self.attributes["aws_batch_tags"],
}

decorator_aws_tags_list = [
{'key': key,
'value': val} for key, val in self.attributes["aws_batch_tags"].items()
{"key": key, "value": val}
for key, val in self.attributes["aws_batch_tags"].items()
]
for tag in decorator_aws_tags_list:
validate_aws_tag(tag)
self.attributes["aws_batch_tags"] = decorator_aws_tags_list

# clean up the alias attribute so it is not passed on.
@@ -237,6 +246,11 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
if self.attributes["tmpfs_path"] and self.attributes["tmpfs_path"][0] != "/":
raise BatchException("'tmpfs_path' needs to be an absolute path")

# Validate Batch tags
if self.attributes["aws_batch_tags"]:
for tag in self.attributes["aws_batch_tags"]:
validate_aws_tag(tag)

def runtime_init(self, flow, graph, package, run_id):
# Set some more internal state.
self.flow = flow
@@ -260,7 +274,12 @@ def runtime_step_cli(
cli_args.commands = ["batch", "step"]
cli_args.command_args.append(self.package_sha)
cli_args.command_args.append(self.package_url)
cli_args.command_options.update(self.attributes)
for k, v in self.attributes.items():
# Some attributes need to be serialized for the CLI
if k in ["aws_batch_tags"]:
cli_args.command_options[k] = json.dumps(v)
else:
cli_args.command_options[k] = v
cli_args.command_options["run-time-limit"] = self.run_time_limit
if not R.use_r():
cli_args.entrypoint[0] = sys.executable
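A minimal sketch of the merge order introduced in __init__ above, assuming BATCH_DEFAULT_TAGS is a plain Dict[str, str]: step-level aws_batch_tags override the defaults, and the merged dict is flattened into the [{"key": ..., "value": ...}] list that validate_aws_tag and create_job consume. The values are made up for illustration.

# Assumed default tags from configuration (illustrative values).
BATCH_DEFAULT_TAGS = {"team": "ml-platform", "env": "dev"}

# Tags passed to the decorator, e.g. @batch(aws_batch_tags={"env": "prod"}).
step_tags = {"env": "prod"}

# Step-level tags win, mirroring {**BATCH_DEFAULT_TAGS, **self.attributes["aws_batch_tags"]}.
merged = {**BATCH_DEFAULT_TAGS, **step_tags}

# Flatten into the list-of-dicts shape handed to validate_aws_tag and create_job.
as_list = [{"key": k, "value": v} for k, v in merged.items()]

print(as_list)
# [{'key': 'team', 'value': 'ml-platform'}, {'key': 'env', 'value': 'prod'}]
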
3 changes: 0 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions.py
@@ -48,7 +48,6 @@ def __init__(
event_logger,
monitor,
tags=None,
aws_batch_tags=None,
namespace=None,
username=None,
max_workers=None,
@@ -68,7 +67,6 @@ def __init__(
self.event_logger = event_logger
self.monitor = monitor
self.tags = tags
self.aws_batch_tags = aws_batch_tags
self.namespace = namespace
self.username = username
self.max_workers = max_workers
@@ -835,7 +833,6 @@ def _batch(self, node):
efa=resources["efa"],
use_tmpfs=resources["use_tmpfs"],
aws_batch_tags=resources["aws_batch_tags"],
cli_aws_batch_tags=self.aws_batch_tags,
tmpfs_tempdir=resources["tmpfs_tempdir"],
tmpfs_size=resources["tmpfs_size"],
tmpfs_path=resources["tmpfs_path"],
27 changes: 0 additions & 27 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
@@ -98,12 +98,6 @@ def step_functions(obj, name=None):
"with the given tag. You can specify this option multiple "
"times to attach multiple tags.",
)
@click.option(
"--aws-batch-tags",
"aws_batch_tags",
multiple=True,
default=None,
help="AWS tags.")
@click.option(
"--namespace",
"user_namespace",
@@ -150,7 +144,6 @@ def step_functions(obj, name=None):
def create(
obj,
tags=None,
aws_batch_tags=None,
user_namespace=None,
only_json=False,
authorize=None,
@@ -204,7 +197,6 @@ def create(
token,
obj.state_machine_name,
tags,
aws_batch_tags,
user_namespace,
max_workers,
workflow_timeout,
@@ -324,7 +316,6 @@ def make_flow(
token,
name,
tags,
aws_batch_tags,
namespace,
max_workers,
workflow_timeout,
@@ -347,23 +338,6 @@ def make_flow(
[obj.package.blob], len_hint=1
)[0]


if aws_batch_tags is not None:
if not all(isinstance(item, str) for item in aws_batch_tags):
raise MetaflowException("AWS Step Functions --aws-tags all items in list must be strings")
for item in aws_batch_tags:
if len(item.split('=')) != 2:
raise MetaflowException("AWS Step Functions --aws-tags strings must be in format 'key=value'")
aws_tags_list = [
{'key': item.split('=')[0],
'value': item.split('=')[1]} for item in aws_batch_tags
]
for tag in aws_tags_list:
validate_aws_tag(tag)
else: aws_tags_list = None



return StepFunctions(
name,
obj.graph,
@@ -377,7 +351,6 @@ def make_flow(
obj.event_logger,
obj.monitor,
tags=tags,
aws_batch_tags=aws_tags_list,
namespace=namespace,
max_workers=max_workers,
username=get_username(),