|
6 | 6 | """ |
7 | 7 | import contextlib |
8 | 8 | import copy |
9 | | -import glob |
10 | 9 | import json |
11 | 10 | import logging |
12 | 11 | import os |
|
17 | 16 | import sys |
18 | 17 | import tempfile |
19 | 18 | import time |
20 | | -from distutils.dir_util import copy_tree |
21 | 19 | from pathlib import Path |
22 | 20 | from tempfile import NamedTemporaryFile |
23 | 21 | from typing import Callable, Optional |
@@ -1211,35 +1209,74 @@ def _generate_runtime_report_file(self, fit_mem_usage: float, pred_mem_usage: fl |
1211 | 1209 | json.dump(report_information, open(output_path, "w")) |
1212 | 1210 |
|
1213 | 1211 |
|
1214 | | -def output_in_code_dir(code_dir, output_dir): |
1215 | | - """Does the code directory house the output directory?""" |
1216 | | - code_abs_path = os.path.abspath(code_dir) |
1217 | | - output_abs_path = os.path.abspath(output_dir) |
1218 | | - return os.path.commonpath([code_dir, output_abs_path]) == code_abs_path |
| 1212 | +def _output_in_code_dir(code_dir: Path, output_dir: Path) -> bool: |
| 1213 | + """Return True if output_dir is inside code_dir.""" |
| 1214 | + try: |
| 1215 | + output_dir.resolve().relative_to(code_dir.resolve()) |
| 1216 | + return True |
| 1217 | + except ValueError: |
| 1218 | + return False |
1219 | 1219 |
|
1220 | 1220 |
|
1221 | | -def create_custom_inference_model_folder(code_dir, output_dir): |
1222 | | - readme = """ |
1223 | | - This folder was generated by the DRUM tool. It provides functionality for making |
1224 | | - predictions using the model trained by DRUM |
| 1221 | +def _copy_tree(src_dir: Path, dst_dir: Path) -> set[Path]: |
1225 | 1222 | """ |
1226 | | - files_in_output = set(glob.glob(output_dir + "/**")) |
1227 | | - if output_in_code_dir(code_dir, output_dir): |
1228 | | - # since the output directory is in the code directory use a tempdir to copy into first and |
1229 | | - # cleanup files and prevent errors related to copying the output into itself. |
1230 | | - with tempfile.TemporaryDirectory() as tempdir: |
1231 | | - copy_tree(code_dir, tempdir) |
1232 | | - # remove the temporary version of the target dir |
1233 | | - shutil.rmtree(os.path.join(tempdir, os.path.relpath(output_dir, code_dir))) |
1234 | | - shutil.rmtree(os.path.join(tempdir, "__pycache__"), ignore_errors=True) |
1235 | | - copied_files = set(copy_tree(tempdir, output_dir)) |
| 1223 | + Recursively copy contents of src_dir into dst_dir. |
| 1224 | + Returns a set of all copied file paths. |
| 1225 | + """ |
| 1226 | + copied_files = set() |
| 1227 | + dst_dir.mkdir(parents=True, exist_ok=True) |
| 1228 | + |
| 1229 | + for item in src_dir.iterdir(): |
| 1230 | + dst_item = dst_dir / item.name |
| 1231 | + if item.is_dir(): |
| 1232 | + shutil.copytree(item, dst_item, dirs_exist_ok=True) |
| 1233 | + copied_files.update(p for p in dst_item.rglob("*") if p.is_file()) |
| 1234 | + else: |
| 1235 | + shutil.copy2(item, dst_item) |
| 1236 | + copied_files.add(dst_item) |
| 1237 | + |
| 1238 | + return copied_files |
| 1239 | + |
| 1240 | + |
| 1241 | +def create_custom_inference_model_folder(code_dir: str, output_dir: str) -> None: |
| 1242 | + """ |
| 1243 | + Prepares a model inference folder by copying code_dir into output_dir, |
| 1244 | + avoiding recursive self-copying if output_dir is inside code_dir. |
| 1245 | + """ |
| 1246 | + code_path = Path(code_dir).resolve() |
| 1247 | + output_path = Path(output_dir).resolve() |
| 1248 | + |
| 1249 | + readme_content = ( |
| 1250 | + "This folder was generated by the DRUM tool. It provides functionality for making\n" |
| 1251 | + "predictions using the model trained by DRUM\n" |
| 1252 | + ) |
| 1253 | + |
| 1254 | + existing_files = set(p for p in output_path.rglob("*") if p.is_file()) |
| 1255 | + |
| 1256 | + if _output_in_code_dir(code_path, output_path): |
| 1257 | + with tempfile.TemporaryDirectory() as tmp: |
| 1258 | + tmp_path = Path(tmp) |
| 1259 | + _copy_tree(code_path, tmp_path) |
| 1260 | + |
| 1261 | + # Remove the output subdir inside the copied tree |
| 1262 | + rel_output = output_path.relative_to(code_path) |
| 1263 | + shutil.rmtree(tmp_path / rel_output, ignore_errors=True) |
| 1264 | + |
| 1265 | + # Clean up __pycache__ |
| 1266 | + shutil.rmtree(tmp_path / "__pycache__", ignore_errors=True) |
| 1267 | + |
| 1268 | + copied_files = _copy_tree(tmp_path, output_path) |
1236 | 1269 | else: |
1237 | | - copied_files = set(copy_tree(code_dir, output_dir)) |
1238 | | - shutil.rmtree(os.path.join(output_dir, "__pycache__"), ignore_errors=True) |
1239 | | - with open(os.path.join(output_dir, "README.md"), "w") as fp: |
1240 | | - fp.write(readme) |
1241 | | - if files_in_output & copied_files: |
1242 | | - print("Files were overwritten: {}".format(files_in_output & copied_files)) |
| 1270 | + copied_files = _copy_tree(code_path, output_path) |
| 1271 | + shutil.rmtree(output_path / "__pycache__", ignore_errors=True) |
| 1272 | + |
| 1273 | + # Add README |
| 1274 | + (output_path / "README.md").write_text(readme_content) |
| 1275 | + |
| 1276 | + # Check overwritten files |
| 1277 | + overwritten = existing_files & copied_files |
| 1278 | + if overwritten: |
| 1279 | + print(f"Files were overwritten: {sorted(overwritten)}") |
1243 | 1280 |
|
1244 | 1281 |
|
1245 | 1282 | def _get_default_numeric_param_value(param_config: Dict, cast_to_int: bool) -> Union[int, float]: |
|
0 commit comments