Skip to content

Commit 8036194

Browse files
authored
[RAPTOR-13911] refactor to remove distutils usage (#1528)
1 parent 95164b1 commit 8036194

File tree

5 files changed

+114
-40
lines changed

5 files changed

+114
-40
lines changed

custom_model_runner/datarobot_drum/drum/args_parser.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77
import argparse
88
import os
9-
import subprocess
109
import sys
1110
from urllib.parse import urlparse
1211

custom_model_runner/datarobot_drum/drum/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import logging
99
import os
1010
import sys
11+
import trafaret as t
1112
from contextvars import ContextVar
12-
from distutils.util import strtobool
1313
from urllib.parse import urlparse, urlunparse
1414

1515
from contextlib import contextmanager
@@ -126,7 +126,7 @@ def to_bool(value):
126126
return False
127127
if isinstance(value, bool):
128128
return value
129-
return strtobool(value)
129+
return t.ToBool().check(value)
130130

131131

132132
FIT_METADATA_FILENAME = "fit_runtime_data.json"

custom_model_runner/datarobot_drum/drum/data_marshalling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Released under the terms of DataRobot Tool and Utility Agreement.
66
"""
77
import logging
8-
from distutils.util import strtobool
8+
import trafaret as t
99
from typing import Any, List, Optional, Union
1010

1111
import numpy as np
@@ -81,9 +81,9 @@ def _standardize(label):
8181
except ValueError:
8282
pass
8383

84-
# Maybe if its a boolean we can make it floaty anyways
84+
# Maybe if it's a boolean we can make it floaty anyways
8585
try:
86-
return float(strtobool(label))
86+
return float(t.ToBool().check(label))
8787
except ValueError:
8888
pass
8989

custom_model_runner/datarobot_drum/drum/drum.py

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77
import contextlib
88
import copy
9-
import glob
109
import json
1110
import logging
1211
import os
@@ -17,7 +16,6 @@
1716
import sys
1817
import tempfile
1918
import time
20-
from distutils.dir_util import copy_tree
2119
from pathlib import Path
2220
from tempfile import NamedTemporaryFile
2321
from typing import Callable, Optional
@@ -1211,35 +1209,74 @@ def _generate_runtime_report_file(self, fit_mem_usage: float, pred_mem_usage: fl
12111209
json.dump(report_information, open(output_path, "w"))
12121210

12131211

1214-
def output_in_code_dir(code_dir, output_dir):
1215-
"""Does the code directory house the output directory?"""
1216-
code_abs_path = os.path.abspath(code_dir)
1217-
output_abs_path = os.path.abspath(output_dir)
1218-
return os.path.commonpath([code_dir, output_abs_path]) == code_abs_path
1212+
def _output_in_code_dir(code_dir: Path, output_dir: Path) -> bool:
1213+
"""Return True if output_dir is inside code_dir."""
1214+
try:
1215+
output_dir.resolve().relative_to(code_dir.resolve())
1216+
return True
1217+
except ValueError:
1218+
return False
12191219

12201220

1221-
def create_custom_inference_model_folder(code_dir, output_dir):
1222-
readme = """
1223-
This folder was generated by the DRUM tool. It provides functionality for making
1224-
predictions using the model trained by DRUM
1221+
def _copy_tree(src_dir: Path, dst_dir: Path) -> set[Path]:
12251222
"""
1226-
files_in_output = set(glob.glob(output_dir + "/**"))
1227-
if output_in_code_dir(code_dir, output_dir):
1228-
# since the output directory is in the code directory use a tempdir to copy into first and
1229-
# cleanup files and prevent errors related to copying the output into itself.
1230-
with tempfile.TemporaryDirectory() as tempdir:
1231-
copy_tree(code_dir, tempdir)
1232-
# remove the temporary version of the target dir
1233-
shutil.rmtree(os.path.join(tempdir, os.path.relpath(output_dir, code_dir)))
1234-
shutil.rmtree(os.path.join(tempdir, "__pycache__"), ignore_errors=True)
1235-
copied_files = set(copy_tree(tempdir, output_dir))
1223+
Recursively copy contents of src_dir into dst_dir.
1224+
Returns a set of all copied file paths.
1225+
"""
1226+
copied_files = set()
1227+
dst_dir.mkdir(parents=True, exist_ok=True)
1228+
1229+
for item in src_dir.iterdir():
1230+
dst_item = dst_dir / item.name
1231+
if item.is_dir():
1232+
shutil.copytree(item, dst_item, dirs_exist_ok=True)
1233+
copied_files.update(p for p in dst_item.rglob("*") if p.is_file())
1234+
else:
1235+
shutil.copy2(item, dst_item)
1236+
copied_files.add(dst_item)
1237+
1238+
return copied_files
1239+
1240+
1241+
def create_custom_inference_model_folder(code_dir: str, output_dir: str) -> None:
1242+
"""
1243+
Prepares a model inference folder by copying code_dir into output_dir,
1244+
avoiding recursive self-copying if output_dir is inside code_dir.
1245+
"""
1246+
code_path = Path(code_dir).resolve()
1247+
output_path = Path(output_dir).resolve()
1248+
1249+
readme_content = (
1250+
"This folder was generated by the DRUM tool. It provides functionality for making\n"
1251+
"predictions using the model trained by DRUM\n"
1252+
)
1253+
1254+
existing_files = set(p for p in output_path.rglob("*") if p.is_file())
1255+
1256+
if _output_in_code_dir(code_path, output_path):
1257+
with tempfile.TemporaryDirectory() as tmp:
1258+
tmp_path = Path(tmp)
1259+
_copy_tree(code_path, tmp_path)
1260+
1261+
# Remove the output subdir inside the copied tree
1262+
rel_output = output_path.relative_to(code_path)
1263+
shutil.rmtree(tmp_path / rel_output, ignore_errors=True)
1264+
1265+
# Clean up __pycache__
1266+
shutil.rmtree(tmp_path / "__pycache__", ignore_errors=True)
1267+
1268+
copied_files = _copy_tree(tmp_path, output_path)
12361269
else:
1237-
copied_files = set(copy_tree(code_dir, output_dir))
1238-
shutil.rmtree(os.path.join(output_dir, "__pycache__"), ignore_errors=True)
1239-
with open(os.path.join(output_dir, "README.md"), "w") as fp:
1240-
fp.write(readme)
1241-
if files_in_output & copied_files:
1242-
print("Files were overwritten: {}".format(files_in_output & copied_files))
1270+
copied_files = _copy_tree(code_path, output_path)
1271+
shutil.rmtree(output_path / "__pycache__", ignore_errors=True)
1272+
1273+
# Add README
1274+
(output_path / "README.md").write_text(readme_content)
1275+
1276+
# Check overwritten files
1277+
overwritten = existing_files & copied_files
1278+
if overwritten:
1279+
print(f"Files were overwritten: {sorted(overwritten)}")
12431280

12441281

12451282
def _get_default_numeric_param_value(param_config: Dict, cast_to_int: bool) -> Union[int, float]:

tests/unit/datarobot_drum/drum/test_drum.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import os
1010
import sys
11+
import tempfile
1112
from copy import deepcopy
1213
from pathlib import Path
1314
from tempfile import TemporaryDirectory, NamedTemporaryFile
@@ -23,7 +24,7 @@
2324
from datarobot_drum.drum.drum import (
2425
CMRunner,
2526
create_custom_inference_model_folder,
26-
output_in_code_dir,
27+
_output_in_code_dir,
2728
)
2829
from datarobot_drum.drum.enum import (
2930
RunLanguage,
@@ -498,12 +499,49 @@ def test_output_dir_copy(self):
498499
assert not Path(out_dir, "__pycache__").exists()
499500
assert not Path(out_dir, "out").exists()
500501

501-
def test_output_in_code_dir(self):
502-
code_dir = "/test/code/is/here"
503-
output_other = "/test/not/code"
504-
output_code_dir = "/test/code/is/here/output"
505-
assert not output_in_code_dir(code_dir, output_other)
506-
assert output_in_code_dir(code_dir, output_code_dir)
502+
503+
class TestUtilityFunctionsOutputInCommonDir:
504+
def test_output_inside_code_dir(self):
505+
with tempfile.TemporaryDirectory() as tmp:
506+
code_dir = Path(tmp)
507+
output_dir = code_dir / "subdir"
508+
output_dir.mkdir()
509+
assert _output_in_code_dir(code_dir, output_dir) is True
510+
511+
def test_output_is_code_dir(self):
512+
with tempfile.TemporaryDirectory() as tmp:
513+
path = Path(tmp)
514+
assert _output_in_code_dir(path, path) is True
515+
516+
def test_output_outside_code_dir(self):
517+
with tempfile.TemporaryDirectory() as tmp1, tempfile.TemporaryDirectory() as tmp2:
518+
code_dir = Path(tmp1)
519+
output_dir = Path(tmp2)
520+
assert _output_in_code_dir(code_dir, output_dir) is False
521+
522+
def test_output_sibling_directory(self):
523+
with tempfile.TemporaryDirectory() as tmp:
524+
parent = Path(tmp)
525+
code_dir = parent / "code"
526+
output_dir = parent / "output"
527+
code_dir.mkdir()
528+
output_dir.mkdir()
529+
assert _output_in_code_dir(code_dir, output_dir) is False
530+
531+
def test_relative_path_handling(self):
532+
with tempfile.TemporaryDirectory() as tmp:
533+
base = Path(tmp)
534+
code_dir = base / "code"
535+
output_dir = code_dir / "out"
536+
output_dir.mkdir(parents=True)
537+
538+
# Use relative paths
539+
cwd = Path.cwd()
540+
try:
541+
os.chdir(base)
542+
assert _output_in_code_dir(Path("code"), Path("code/out")) is True
543+
finally:
544+
os.chdir(cwd)
507545

508546

509547
class TestRuntimeParametersDockerCommand:

0 commit comments

Comments
 (0)