Sagemaker integration (#151)
* remove relative imports

* sm-integration

* folder rename + args for region, arn, s3

* setup.py fixes

* cleanup tri mentions + address nits

* Updates for sagemaker integration

* sm fixes

* format + conflicts

---------

Co-authored-by: Achal Dave <[email protected]>
Co-authored-by: Achal Dave <[email protected]>
3 people authored Dec 20, 2023
1 parent 176f869 commit 0da1e0c
Showing 7 changed files with 286 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .dockerignore
@@ -0,0 +1,4 @@
venv
wandb
logs
checkpoints
1 change: 1 addition & 0 deletions .gitignore
@@ -139,5 +139,6 @@ weights*
out*
tests/assets/*
.vscode/
secrets.env
checkpoints/
experiments/
Empty file added sagemaker_train/.dockerignore
Empty file.
34 changes: 34 additions & 0 deletions sagemaker_train/Dockerfile
@@ -0,0 +1,34 @@
ARG AWS_REGION

# SageMaker PyTorch image
FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker

# Run custom installation of libraries
# RUN pip install xxx
# RUN apt-get update && apt-get install -y xxx
# ENV <your environment variables>
# etc....

# Remove the conda installed symlink for libcurl, which causes an error with curl.
# Fixes the following error:
# curl: /opt/conda/lib/libcurl.so.4: no version information available (required by curl)
RUN rm /opt/conda/lib/libcurl.so.4

ENV PATH="/opt/ml/code:${PATH}"

# this environment variable is used by the SageMaker PyTorch container to determine our user code directory.
ENV SAGEMAKER_SUBMIT_DIRECTORY /opt/ml/code

# /opt/ml and all subdirectories are utilized by SageMaker, use the /code subdirectory to store your user code.
COPY . /opt/ml/code/
RUN rm /opt/ml/code/setup.py

RUN pip install -r /opt/ml/code/requirements.txt
RUN pip uninstall flash-attn -y
# Quote the requirement so the shell does not treat ">=" as an output redirect.
RUN pip install "flash-attn>=2.2"
# # Prevent sagemaker from installing requirements again.
# RUN rm /opt/ml/code/setup.py
RUN rm /opt/ml/code/requirements.txt

# Defines a script entrypoint
ENV SAGEMAKER_PROGRAM open_lm/main.py
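
Note: when launch_sagemaker_train.py (added below) is run with --build-type full, it builds this image via docker build --progress=plain -f Dockerfile --build-arg AWS_REGION=<region> -t <algorithm_name> ., then tags and pushes it to the user's ECR repository.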
15 changes: 15 additions & 0 deletions sagemaker_train/Dockerfile_update
@@ -0,0 +1,15 @@
ARG BASE_DOCKER
# Dockerfile that updates the container with new code.
# SageMaker PyTorch image
FROM ${BASE_DOCKER}

# /opt/ml and all subdirectories are utilized by SageMaker, use the /code subdirectory to store your user code.
COPY . /opt/ml/code/

# RUN pip install -e /opt/ml/code/

# Prevent sagemaker from installing requirements again.
RUN rm /opt/ml/code/setup.py
RUN rm /opt/ml/code/requirements.txt

ENV SAGEMAKER_PROGRAM open_lm/main.py
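
Note: with --build-type update, the launch script rebuilds from this file using --build-arg BASE_DOCKER=<algorithm_name>, so only the COPY layer with fresh code is rebuilt. Deleting setup.py and requirements.txt keeps the SageMaker toolkit from reinstalling dependencies at job start.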
38 changes: 38 additions & 0 deletions sagemaker_train/cfg_sample.yaml
@@ -0,0 +1,38 @@
accum-freq: 4
beta1: 0.9
beta2: 0.95
data-key: "json"
dataset-resampled: True
# delete-previous-checkpoint: False
# Total 25B * 40 = 1T tokens
epochs: 40
fsdp: True
fsdp-limit-all-gathers: True
# grad-checkpointing: False
grad-clip-norm: 1
log-every-n-steps: 20
model: "open_lm_7b"
name: "sample_7b"
precision: "amp_bfloat16"
report-to: "wandb"
seed: 124
train-data-mix-weights: [0.725, 0.275]
train-data: ["TODO"]
train-num-samples: 25_000_000_000
wandb-project-name: "lm1"
workers: 4
logs: /opt/ml/checkpoints/

# Some important parameters, double checked with Mitchell:
batch-size: 16
ffn-type: swiglu
# fsdp-amp: False
fsdp-pure-bf16: True
fsdp-backward-prefetch: True
lr: 3.e-4
lr-cooldown-end: 3.e-5
model-norm: "gain_only_lp_layer_norm"
qk-norm: True
warmup: 5000
wd: 0.1
z-loss-coefficient: 1.e-4
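
SageMaker passes the estimator's hyperparameters={"config": ...} to the entry point as --config <path>. As a reference for how a config like the one above maps onto CLI flags, here is a minimal sketch that expands the YAML mapping into --key value arguments; it assumes open_lm/main.py accepts these key names as flags, and the helper name is illustrative, not part of this commit:

import yaml


def config_to_argv(path):
    """Expand a YAML config into argparse-style CLI arguments (sketch)."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    argv = []
    for key, value in cfg.items():
        if isinstance(value, bool):
            # Booleans such as "fsdp: True" become bare --fsdp switches.
            if value:
                argv.append(f"--{key}")
        elif isinstance(value, list):
            # Lists such as train-data-mix-weights become one flag with
            # multiple values (nargs-style).
            argv.append(f"--{key}")
            argv.extend(str(v) for v in value)
        else:
            argv.extend([f"--{key}", str(value)])
    return argv
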
194 changes: 194 additions & 0 deletions sagemaker_train/launch_sagemaker_train.py
@@ -0,0 +1,194 @@
import argparse
import time
import os
import subprocess
from datetime import datetime
from pathlib import Path

import boto3
import sagemaker
from sagemaker.pytorch import PyTorch


NAME = "openlm-main"
INSTANCE_MAPPER = {
"p4": "ml.p4d.24xlarge",
"p4de": "ml.p4de.24xlarge",
"p5": "ml.p5.48xlarge",
}


def run_command(command):
    print(f"=> {command}")
    subprocess.run(command, shell=True, check=True)


def get_image(user, instance_type, build_type=None, profile="poweruser", region="us-east-1"):
    os.environ["AWS_PROFILE"] = f"{profile}"
    account = subprocess.getoutput(
        f"aws --region {region} --profile {profile} sts get-caller-identity --query Account --output text"
    )
    docker_dir = Path(__file__).parent
    if instance_type in ("p4", "p4de"):
        algorithm_name = f"{user}-{NAME}-p4"
        dockerfile_base = docker_dir / "Dockerfile"
        dockerfile_update = docker_dir / "Dockerfile_update"
    elif instance_type == "p5":
        algorithm_name = f"{user}-{NAME}-p5"
        dockerfile_base = docker_dir / "Dockerfile"
        dockerfile_update = docker_dir / "Dockerfile_update"
    else:
        raise ValueError(f"Unknown instance_type: {instance_type}")
    fullname = f"{account}.dkr.ecr.{region}.amazonaws.com/{algorithm_name}:latest"
    if build_type is None:
        return fullname

    login_cmd = f"aws ecr get-login-password --region {region} --profile {profile} | docker login --username AWS --password-stdin"

    if build_type == "full":
        print("Building container")
        commands = [
            # Log in to Sagemaker account to get image.
            f"{login_cmd} 763104351884.dkr.ecr.{region}.amazonaws.com",
            f"docker build --progress=plain -f {dockerfile_base} --build-arg AWS_REGION={region} -t {algorithm_name} .",
            f"docker tag {algorithm_name} {fullname}",
            f"{login_cmd} {fullname}",
            (
                f"aws --region {region} ecr describe-repositories --repository-names {algorithm_name} || "
                f"aws --region {region} ecr create-repository --repository-name {algorithm_name}"
            ),
        ]
    elif build_type == "update":
        print("Updating container")
        commands = [
            f"docker build --progress=plain -f {dockerfile_update} --build-arg BASE_DOCKER={algorithm_name} -t {algorithm_name} .",
            f"docker tag {algorithm_name} {fullname}",
            f"{login_cmd} {fullname}",
        ]
    else:
        raise ValueError(f"Unknown build_type: {build_type}")

    # Create command, making sure to exit if any part breaks.
    command = "\n".join([f"{x} || exit 1" for x in commands])
    run_command(command)
    run_command(f"docker push {fullname}")
    print("Sleeping for 5 seconds to ensure push succeeded")
    time.sleep(5)
    return fullname


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--build-type", choices=["full", "update"], help="Build image from scratch")
    parser.add_argument("--local", action="store_true")
    parser.add_argument("--user", required=True, help="User name")
    parser.add_argument("--cfg-path", required=True, help="Launch config")

    # AWS profile args
    parser.add_argument("--region", default="us-east-1", help="AWS region")
    parser.add_argument("--profile", default="poweruser", help="AWS profile to use")
    parser.add_argument("--arn", default=None, help="If None, reads from SAGEMAKER_ARN env var")
    parser.add_argument(
        "--s3-remote-sync", default=None, help="S3 path to sync to. If none, reads from S3_REMOTE_SYNC env var"
    )

    # Instance args
    parser.add_argument("--instance-count", default=1, type=int, help="Number of instances")
    parser.add_argument("--instance-type", default="p4de", choices=list(INSTANCE_MAPPER.keys()))
    parser.add_argument("--spot-instance", action="store_true")

    args = parser.parse_args()
    main_after_setup_move(args)


def main_after_setup_move(args):
    if args.arn is None:
        assert "SAGEMAKER_ARN" in os.environ, "Please specify --arn or set the SAGEMAKER_ARN environment variable"
        args.arn = os.environ["SAGEMAKER_ARN"]

    if args.s3_remote_sync is None:
        assert (
            "S3_REMOTE_SYNC" in os.environ
        ), "Please specify --s3-remote-sync or set the S3_REMOTE_SYNC environment variable"
        args.s3_remote_sync = os.environ["S3_REMOTE_SYNC"]

    image = get_image(
        args.user,
        args.instance_type,
        region=args.region,
        build_type=args.build_type,
        profile=args.profile,
    )

    ##########
    # Create the session and verify the account and region
    ##########
    sagemaker_session = sagemaker.Session(boto_session=boto3.session.Session(region_name=args.region))

    role = args.arn
    # Provide a pre-existing role ARN as an alternative to creating a new role.
    role_name = role.split("/")[-1]
    print(f"SageMaker Execution Role: {role}")
    print(f"The name of the Execution role: {role_name}")

    client = boto3.client("sts")
    account = client.get_caller_identity()["Account"]
    print(f"AWS account: {account}")

    session = boto3.session.Session()
    region = session.region_name
    print(f"AWS region: {region}")

    ##########
    # Configure the training
    ##########
    base_job_name = f"{args.user.replace('.', '-')}-{NAME}"

    checkpoint_local_path = "/opt/ml/checkpoints"

    def get_job_name(base):
        now = datetime.now()
        # Format example: 2023-03-03-10-14-02-324
        now_ms_str = f"{now.microsecond // 1000:03d}"
        date_str = f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{now_ms_str}"

        job_name = "_".join([base, date_str])

        return job_name

    job_name = get_job_name(base_job_name)

    output_root = f"{args.s3_remote_sync}/sagemaker/{args.user}/{NAME}/"
    output_s3 = os.path.join(output_root, job_name)

    estimator = PyTorch(
        entry_point="open_lm/main.py",
        sagemaker_session=sagemaker_session,
        base_job_name=base_job_name,
        hyperparameters={"config": args.cfg_path},
        role=role,
        image_uri=image,
        instance_count=args.instance_count,
        instance_type="local_gpu" if args.local else INSTANCE_MAPPER[args.instance_type],
        # use_spot_instances is the SageMaker SDK v2 name (train_use_spot_instances in v1).
        use_spot_instances=args.spot_instance,
        output_path=output_s3,
        job_name=job_name,
        checkpoint_s3_uri=None if args.local else f"{output_s3}/checkpoint",
        checkpoint_local_path=None if args.local else checkpoint_local_path,
        code_location=output_s3,
        # Launch with torch.distributed (torchrun) as the distributed training framework.
        distribution={"torch_distributed": {"enabled": True}},
        # Max run 5 days
        max_run=5 * 24 * 60 * 60,
        max_wait=5 * 24 * 60 * 60 if args.spot_instance else None,
        input_mode="FastFile",
        # environment={"TORCH_DISTRIBUTED_DEBUG": "DETAIL", "TORCH_CPP_LOG_LEVEL": "INFO"},
        keep_alive_period_in_seconds=30 * 60 if not args.spot_instance else None,  # 30 minutes
    )

    estimator.fit()


if __name__ == "__main__":
    main()
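
A hypothetical invocation (user and instance settings are placeholders), assuming SAGEMAKER_ARN and S3_REMOTE_SYNC are set in the environment: python sagemaker_train/launch_sagemaker_train.py --user jane --cfg-path sagemaker_train/cfg_sample.yaml --build-type full --instance-type p4de --instance-count 2. This builds and pushes the image, then submits a job named like jane-openlm-main_2023-12-20-10-14-02-324, with checkpoints synced to <S3_REMOTE_SYNC>/sagemaker/jane/openlm-main/<job_name>/checkpoint.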
