Skip to content

Commit cda9465

Browse files
authored
Feat: support cuda on modal (#35)
1 parent 593b1d6 commit cda9465

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

src/discord-cluster-manager/cogs/modal_cog.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ async def run_modal(
3030
script: discord.Attachment,
3131
gpu_type: app_commands.Choice[str],
3232
):
33-
if not script.filename.endswith(".py"):
33+
if not script.filename.endswith(".py") and not script.filename.endswith(".cu"):
3434
await interaction.response.send_message(
35-
"Please provide a Python (.py) file"
35+
"Please provide a Python (.py) or CUDA (.cu) file"
3636
)
3737
return
3838

@@ -55,12 +55,19 @@ async def run_modal(
5555
async def trigger_modal_run(self, script_content: str, filename: str) -> str:
5656
logger.info("Attempting to trigger Modal run")
5757

58-
from modal_runner import modal_app, run_script
58+
from modal_runner import modal_app
5959

6060
try:
61+
print(f"Running {filename} with Modal")
6162
with modal.enable_output():
6263
with modal_app.run():
63-
result = run_script.remote(script_content)
64+
if filename.endswith(".py"):
65+
from modal_runner import run_script
66+
67+
result = run_script.remote(script_content)
68+
elif filename.endswith(".cu"):
69+
from modal_runner import run_cuda_script
70+
result = run_cuda_script.remote(script_content)
6471
return result
6572
except Exception as e:
6673
logger.error(f"Error in trigger_modal_run: {str(e)}", exc_info=True)

src/discord-cluster-manager/modal_runner.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def timeout_handler(signum, frame):
3030
signal.signal(signal.SIGALRM, original_handler)
3131

3232

33-
3433
@modal_app.function(
3534
gpu="T4", image=Image.debian_slim(python_version="3.10").pip_install(["torch"])
3635
)
@@ -66,3 +65,51 @@ def run_script(script_content: str, timeout_seconds: int = 300) -> str:
6665
return f"Error executing script: {str(e)}"
6766
finally:
6867
sys.stdout = sys.__stdout__
68+
69+
70+
@modal_app.function(
71+
gpu="T4",
72+
image=Image.from_registry(
73+
"nvidia/cuda:12.6.0-devel-ubuntu24.04", add_python="3.11"
74+
),
75+
)
76+
def run_cuda_script(script_content: str, timeout_seconds: int = 600) -> str:
77+
import sys
78+
from io import StringIO
79+
import subprocess
80+
import os
81+
82+
output = StringIO()
83+
sys.stdout = output
84+
85+
try:
86+
with timeout(timeout_seconds):
87+
with open("script.cu", "w") as f:
88+
f.write(script_content)
89+
90+
# Compile the CUDA code
91+
compile_process = subprocess.run(
92+
["nvcc", "script.cu", "-o", "script.out"],
93+
capture_output=True,
94+
text=True,
95+
)
96+
97+
if compile_process.returncode != 0:
98+
return f"Compilation Error:\n{compile_process.stderr}"
99+
100+
run_process = subprocess.run(
101+
["./script.out"], capture_output=True, text=True
102+
)
103+
104+
return run_process.stdout
105+
106+
except TimeoutException as e:
107+
return f"Timeout Error: {str(e)}"
108+
except Exception as e:
109+
return f"Error: {str(e)}"
110+
finally:
111+
if os.path.exists("script.cu"):
112+
os.remove("script.cu")
113+
if os.path.exists("script.out"):
114+
os.remove("script.out")
115+
sys.stdout = sys.__stdout__

0 commit comments

Comments
 (0)