Skip to content

Commit 5433bd4

Browse files
authored
Cancel in progress gpu jobs (#20)
1 parent 478b069 commit 5433bd4

File tree

4 files changed

+76
-16
lines changed

4 files changed

+76
-16
lines changed

.github/actions/pytest-gpu/action.yaml

+19
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,22 @@ runs:
125125
--gpu_num ${{ inputs.gpu_num }} \
126126
--git_ssh_clone ${{ inputs.git_ssh_clone }} \
127127
${REF_ARGS}
128+
- name: Follow Run Logs
129+
shell: bash
130+
env:
131+
MOSAICML_API_KEY: ${{ inputs.mcloud_api_key }}
132+
run: |
133+
set -ex
134+
135+
python .github/mcli/follow_mcli_logs.py \
136+
--name '${{ steps.tests.outputs.RUN_NAME }}'
137+
- name: Stop Run if Cancelled
138+
if: ${{ cancelled() }}
139+
shell: bash
140+
env:
141+
MOSAICML_API_KEY: ${{ inputs.mcloud_api_key }}
142+
run: |
143+
set -ex
144+
145+
python .github/mcli/cancel_mcli_run.py \
146+
--name '${{ steps.tests.outputs.RUN_NAME }}'

.github/mcli/cancel_mcli_run.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2024 MosaicML CI-Testing authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import argparse
5+
6+
from mcli import RunStatus, get_run, stop_run, wait_for_run_status
7+
8+
"""Cancel an MCLI run."""
9+
10+
def _parse_run_name():
    """Return the value of the required ``--name`` command-line option."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--name', type=str, required=True, help='Name of run')
    return cli.parse_args().name


if __name__ == '__main__':
    # Script entry point: look up the run named on the command line, ask MCLI
    # to stop it, and block until the wait for the STOPPED status returns.
    target_run = get_run(_parse_run_name())

    print('[GHA] Stopping run.')
    stop_run(target_run)

    # Hold the step open until the stop has taken effect; the returned run
    # object is not needed afterwards, so it is not kept.
    wait_for_run_status(target_run, status=RunStatus.STOPPED)
    print('[GHA] Run stopped.')

.github/mcli/follow_mcli_logs.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2024 MosaicML CI-Testing authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import argparse
5+
6+
from mcli import RunStatus, follow_run_logs, get_run, wait_for_run_status
7+
8+
"""Follow MCLI run logs."""
9+
10+
if __name__ == '__main__':
    # Script entry point: attach to an already-created MCLI run, stream its
    # logs to stdout, and exit non-zero unless the run ends as COMPLETED.
    #
    # Command-line arguments:
    #   --name  (required) name of the run to follow.

    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, required=True, help='Name of run')
    args = parser.parse_args()

    run = get_run(args.name)

    # Wait until run starts before fetching logs. The enum member is used
    # (rather than the string 'running') to match the COMPLETED check below.
    run = wait_for_run_status(run, status=RunStatus.RUNNING)
    print('[GHA] Run started. Following logs...')

    # Print logs. end='' avoids adding a second newline -- presumably each
    # yielded chunk already carries its own line ending (TODO confirm).
    for line in follow_run_logs(run):
        print(line, end='')

    print('[GHA] Run completed. Waiting for run to finish...')
    run = wait_for_run_status(run, status=RunStatus.COMPLETED)

    # Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED).
    # This is an explicit check rather than an `assert`: assertions are removed
    # when Python runs with -O, which would let a failed run pass CI silently.
    # SystemExit with a string prints it to stderr and exits with status 1.
    if run.status != RunStatus.COMPLETED:
        raise SystemExit(f'Run {run.name} did not complete: {run.status} ({run.reason})')

.github/mcli/mcli_pytest.py

+4-16
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
"""Run pytest using MCLI."""
55

66
import argparse
7-
import time
7+
import os
88

9-
from mcli import RunConfig, RunStatus, create_run, follow_run_logs, wait_for_run_status
9+
from mcli import RunConfig, create_run
1010

1111
if __name__ == '__main__':
1212

@@ -111,17 +111,5 @@
111111
run = create_run(config)
112112
print(f'[GHA] Run created: {run.name}')
113113

114-
# Wait until run starts before fetching logs
115-
run = wait_for_run_status(run, status='running')
116-
start_time = time.time()
117-
print('[GHA] Run started. Following logs...')
118-
119-
# Print logs
120-
for line in follow_run_logs(run):
121-
print(line, end='')
122-
123-
print('[GHA] Run completed. Waiting for run to finish...')
124-
run = wait_for_run_status(run, status=RunStatus.COMPLETED)
125-
126-
# Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED)
127-
assert run.status == RunStatus.COMPLETED, f'Run {run.name} did not complete: {run.status} ({run.reason})'
114+
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
115+
print(f'RUN_NAME={run.name}', file=fh)

0 commit comments

Comments
 (0)