forked from NVIDIA/NeMo-Skills
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstart_server.py
166 lines (146 loc) · 6.17 KB
/
start_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import os
import subprocess
import sys
import time
from argparse import ArgumentParser
from pathlib import Path
# huggingface_hub is an optional dependency: when it is not installed, fall
# back to reading the token straight from the environment.
try:
    from huggingface_hub import get_token
except ImportError:  # ModuleNotFoundError is a subclass of ImportError

    def get_token() -> str:
        """Return the HF auth token from the environment ('' when unset)."""
        return os.environ.get('HF_TOKEN', '')
# adding nemo_skills to python path to avoid requiring installation
sys.path.append(str(Path(__file__).absolute().parents[1]))
from launcher import CLUSTER_CONFIG, NEMO_SKILLS_CODE, get_server_command, launch_job
from nemo_skills.utils import setup_logging
# Shell command template executed on every task of the allocation.
# Rank 0 starts the server in the background, tees its output to
# /tmp/server_logs.txt, blocks until {server_wait_string} appears in the log,
# echoes the host the server is reachable at, then sleeps forever so the
# allocation stays alive; all other ranks just run the server command.
# Placeholders are filled from format_dict in __main__.
# NOTE(review): "hosthame_cmd" is a typo for "hostname_cmd" but must stay in
# sync with the format_dict key below — rename both together or neither.
SLURM_CMD = """
nvidia-smi && \
cd /code && \
export PYTHONPATH=/code && \
export HF_TOKEN={HF_TOKEN} && \
if [ $SLURM_PROCID -eq 0 ]; then \
{{ {server_start_cmd} 2>&1 | tee /tmp/server_logs.txt & }} && sleep 1 && \
echo "Waiting for the server to start" && \
tail -n0 -f /tmp/server_logs.txt | sed '/{server_wait_string}/ q' && \
tail -n10 /tmp/server_logs.txt && \
export NEMO_SKILLS_SERVER_HOST={hosthame_cmd} && \
echo "Server is running on $NEMO_SKILLS_SERVER_HOST" && \
{sandbox_echo} \
sleep infinity; \
else \
{server_start_cmd}; \
fi \
"""
# Base container mount spec (host:container, comma-separated); __main__ appends
# the model checkpoint and HF cache mounts when applicable.
MOUNTS = "{NEMO_SKILLS_CODE}:/code"
# Display name for the launched job, filled from format_dict.
JOB_NAME = "interactive-server-{server_type}-{model_name}"
# TODO: nemo does not exit on ctrl+c, need to fix that
if __name__ == "__main__":
    setup_logging(disable_hydra_logs=False)
    parser = ArgumentParser()
    parser.add_argument("--model_path", required=False, default=None, help="Path to the model file")
    parser.add_argument('--model_name', required=False, default=None, help="Name of the HF model")
    parser.add_argument("--server_type", choices=('nemo', 'tensorrt_llm', 'vllm'), default='tensorrt_llm')
    parser.add_argument("--num_gpus", type=int, required=True)
    parser.add_argument(
        "--num_nodes",
        type=int,
        default=1,
        help="Number of nodes required for hosting LLM server.",
    )
    parser.add_argument(
        "--partition",
        required=False,
        help="Can specify if need interactive jobs or a specific non-default partition",
    )
    parser.add_argument(
        "--no_sandbox", action="store_true", help="Disables sandbox if code execution is not required."
    )
    args = parser.parse_args()

    # model_path (local checkpoint) and model_name (HF hub id) are mutually
    # exclusive: exactly one of the two must be provided.
    if args.model_path is not None and args.model_name is not None:
        raise ValueError("Both model_path and model_name cannot be provided")
    elif args.model_path is None and args.model_name is None:
        raise ValueError("Either model_path or model_name must be provided")

    if args.model_path is not None:
        args.model_path = Path(args.model_path).absolute()

    server_start_cmd, num_tasks, server_wait_string = get_server_command(
        args.server_type,
        args.num_gpus,
        args.num_nodes,
        args.model_path.name if args.model_path is not None else args.model_name,
    )
    # TODO: VLLM
    sandbox_echo = 'echo "Sandbox is running on ${NEMO_SKILLS_SANDBOX_HOST:-$NEMO_SKILLS_SERVER_HOST}" &&'
    # `hostname -I` (IP address) locally, plain `hostname` on clusters.
    # NOTE(review): "hosthame_cmd" is a typo for "hostname_cmd", kept because
    # it must match the placeholder name inside SLURM_CMD.
    hosthame_cmd = '`hostname -I`' if CLUSTER_CONFIG['cluster'] == 'local' else '`hostname`'
    format_dict = {
        "model_path": args.model_path,
        "model_name": args.model_path.name if args.model_path is not None else args.model_name,
        "num_gpus": args.num_gpus,
        "server_start_cmd": server_start_cmd,
        "server_type": args.server_type,
        "NEMO_SKILLS_CODE": NEMO_SKILLS_CODE,
        "HF_TOKEN": get_token(),  # needed for some of the models, so making an option to pass it in
        "server_wait_string": server_wait_string,
        "sandbox_echo": sandbox_echo if not args.no_sandbox else "",
        "hosthame_cmd": hosthame_cmd,
    }

    # Build the final mount list locally instead of mutating the module-level
    # MOUNTS constant.
    mounts = MOUNTS
    if args.model_path is not None:
        mounts += f",{args.model_path}:/model"
    # expose the HF cache so models/tokenizers are not re-downloaded inside the container
    if os.environ.get("HF_HOME", None) is not None:
        mounts += f",{os.environ['HF_HOME']}:/cache/huggingface"

    job_id = launch_job(
        cmd=SLURM_CMD.format(**format_dict),
        num_nodes=args.num_nodes,
        tasks_per_node=num_tasks,
        gpus_per_node=format_dict["num_gpus"],
        job_name=JOB_NAME.format(**format_dict),
        container=CLUSTER_CONFIG["containers"][args.server_type],
        mounts=mounts.format(**format_dict),
        partition=args.partition,
        with_sandbox=(not args.no_sandbox),
        extra_sbatch_args=["--parsable"],
    )

    # the rest is only applicable for slurm execution - local execution will block on the launch_job call
    if CLUSTER_CONFIG["cluster"] != "slurm":
        sys.exit(0)

    log_file = f"slurm-{job_id}.out"

    def _cancel_job():
        """Best-effort scancel: the job may already be gone, so never raise at exit."""
        subprocess.run(f"scancel {job_id}", shell=True, check=False)

    def _remove_logs():
        """Remove the slurm log; it may never have been created if the job failed early."""
        try:
            os.remove(log_file)
        except FileNotFoundError:
            pass

    # kill the serving job and clean up its logs when exiting this script
    atexit.register(_cancel_job)
    atexit.register(_remove_logs)

    print("Please wait while the server is starting!")
    server_host = None
    server_started = False
    while True:  # waiting for the server to start
        time.sleep(1)
        # re-scan the whole log file each second until the wait string shows up
        if not os.path.isfile(log_file):
            continue
        with open(log_file) as fin:
            for line in fin:
                # picks up the "Server is running on <host>" line echoed by SLURM_CMD
                # NOTE(review): matching "running on node" assumes cluster
                # hostnames start with "node" — confirm for your cluster
                if "running on node" in line:
                    server_host = line.split()[-1].strip()
                if server_wait_string in line:
                    server_started = True
        if server_started:
            print(f"Server has started at {server_host}")
            break

    print("Streaming server logs")
    while True:  # waiting for the kill signal and streaming logs
        subprocess.run(f"tail -f {log_file}", shell=True, check=True)