diff --git a/noxfile.py b/noxfile.py index 7b89265..2e903fc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,6 +3,9 @@ # > External packages import nox +# > Ignore certain dirs +ASSETS_DIR = "src/oet/assets" + # > Making sure Nox session only see their packages and not any globally installed packages. os.environ.pop("PYTHONPATH", None) # > Hiding any virtual environments from outside. @@ -37,7 +40,7 @@ @nox.session(tags=["static_check"]) def type_check(session): session.install(".[type-check]") - session.run("mypy") + session.run("mypy", "--exclude", f"^{ASSETS_DIR}/") # ////////////////////////////////////////////////// @@ -47,7 +50,7 @@ def type_check(session): def remove_unused_imports(session): session.install(".[lint]") # > Sorting imports with ruff instead of isort - session.run("ruff", "check", "--fix", "--select", "F401") + session.run("ruff", "check", "--fix", "--select", "F401", "--exclude", ASSETS_DIR) # ////////////////////////////////////////// @@ -57,7 +60,7 @@ def remove_unused_imports(session): def sort_imports(session): session.install(".[lint]") # > Sorting imports with ruff instead of isort - session.run("ruff", "check", "--fix", "--select", "I") + session.run("ruff", "check", "--fix", "--select", "I", "--exclude", ASSETS_DIR) # //////////////////////////////////////// @@ -66,7 +69,7 @@ def sort_imports(session): @nox.session(tags=["style", "static_check"]) def lint(session): session.install(".[lint]") - session.run("ruff", "check", "--fix") + session.run("ruff", "check", "--fix", "--exclude", ASSETS_DIR) # ////////////////////////////////////////// @@ -76,7 +79,7 @@ def lint(session): def format_code(session): # Installs the project + the "lint" extra into this nox venv using pip session.install(".[lint]") - session.run("ruff", "format") + session.run("ruff", "format", "--exclude", ASSETS_DIR) # //////////////////////////////////////////////////// @@ -85,7 +88,7 @@ def format_code(session): @nox.session(tags=["static_check"]) def spell_check(session): session.install(".[spell-check]") - session.run("codespell", "src/oet") + session.run("codespell", "src/oet", "--skip", ASSETS_DIR) # ////////////////////////////////////////////// @@ -94,4 +97,4 @@ def spell_check(session): @nox.session(tags=["static_check"], default=True) def dead_code(session): session.install(".[dead-code]") - session.run("vulture") + session.run("vulture", "src", "--exclude", ASSETS_DIR) diff --git a/src/oet/core/test_utilities.py b/src/oet/core/test_utilities.py index 068acea..8cb8188 100644 --- a/src/oet/core/test_utilities.py +++ b/src/oet/core/test_utilities.py @@ -2,8 +2,14 @@ Utilities used in the test suite """ +import multiprocessing as mp import subprocess +from enum import StrEnum from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from multiprocessing.queues import Queue WATER = [ ("O", 0.0000, 0.0000, 0.0000), @@ -195,3 +201,112 @@ def clear_files(basename: str) -> None: for f in dir_path.glob(basename + "*"): if f.is_file(): f.unlink() # remove file + + +def _worker( + fn: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], q: "Queue[bool]" +) -> None: + """ + Helper for executing a function. + + Parameters + ---------- + fn: Callable[..., T] + Callable function that is executed. + args: tuple[Any] + Any positional arguments that are given to the function call. + kwargs: dict[str, Any] + Any keyword arguments that are given to the function call. + q: Queue[bool] + Queue used to put the function call. + """ + try: + # Call the function + _ = fn(*args, **kwargs) + # Don't check what it did, just return ok, if the function didn't crash + q.put(True) + except Exception: + q.put(False) + + +class TimeoutCallError(StrEnum): + """Possible errors that are returned by TimeoutCall""" + + # Function timed out + TIMEOUT = "timeout" + # Function crashed + CRASH = "crash" + # General error + ERROR = "error" + + +class TimeoutCall: + """ + Class for calling a function with a certain timeout. + Useful for functions that, e.g., download files. + Doesn't return the result of the function as it might not be pickled. + """ + + def __init__(self, fn: Callable[..., Any]) -> None: + """ + Initialization of the class. + + Parameters + ---------- + fn: Callable[..., T] + Callable function that is executed. + """ + self.fn = fn + self.timed_out = False + + def __call__( + self, *args: Any, timeout: float = 10, **kwargs: Any + ) -> tuple[bool, TimeoutCallError | None]: + """ + Execute the function set in __init__ with the timeout defined there. + + Parameters + ---------- + args: Any + Any positional arguments that are given to the function call. + timeout: float, default: 10 sec. + Timeout in sec. + kwargs: Any + Any keyword arguments that are given to the function call. + + Returns + ------- + bool + True, if everything was ok. False otherwise. + TimeoutCallError | None + Either the error type if failed or None. + """ + + # Start process and wait the timeout + q: "Queue[bool]" = mp.Queue() + p: mp.Process = mp.Process(target=_worker, args=(self.fn, args, kwargs, q)) + p.start() + p.join(timeout) + + # Check if there was any error + if p.exitcode not in (0, None): + return False, TimeoutCallError.CRASH + + # Check if the process is still alive. If yes, it has timed out. + if p.is_alive(): + p.terminate() + p.join() + return False, TimeoutCallError.TIMEOUT + + # Check if the worker provides the correct result + try: + status_ok = q.get(timeout=1) + except Exception: + return False, TimeoutCallError.CRASH + + # Check if there was a general error + if not status_ok: + return False, TimeoutCallError.ERROR + + # If everything went well, return True and no Error. + return True, None diff --git a/tests/aimnet2/test_aiment2_client.py b/tests/aimnet2/test_aiment2_client.py index 3c7e0ef..8244011 100644 --- a/tests/aimnet2/test_aiment2_client.py +++ b/tests/aimnet2/test_aiment2_client.py @@ -15,6 +15,7 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. aimnet2_script_path = ROOT_DIR / "../../bin/oet_client" aimnet2_server_path = ROOT_DIR / "../../bin/oet_server" # Default ID and port of server. Change if needed diff --git a/tests/aimnet2/test_aiment2_standalone.py b/tests/aimnet2/test_aiment2_standalone.py index cedfe56..2c88e3a 100644 --- a/tests/aimnet2/test_aiment2_standalone.py +++ b/tests/aimnet2/test_aiment2_standalone.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the script, adjust if needed. aimnet2_script_path = ROOT_DIR / "../../bin/oet_aimnet2" diff --git a/tests/g-xtb/test_gxtb.py b/tests/g-xtb/test_gxtb.py index c1ca583..709d008 100644 --- a/tests/g-xtb/test_gxtb.py +++ b/tests/g-xtb/test_gxtb.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. gxtb_script_path = ROOT_DIR / "../../bin/oet_gxtb" # Leave uma_executable_path empty, if gxtb from system path should be called gxtb_executable_path = "" diff --git a/tests/mlatom/test_mlatom.py b/tests/mlatom/test_mlatom.py index 12eb806..790d3f0 100644 --- a/tests/mlatom/test_mlatom.py +++ b/tests/mlatom/test_mlatom.py @@ -6,6 +6,8 @@ from oet.core.test_utilities import ( OH, WATER, + TimeoutCall, + TimeoutCallError, get_filenames, read_result_file, run_wrapper, @@ -13,7 +15,10 @@ write_xyz_file, ) +# Path to the script, adjust if needed. mlatom_script_path = ROOT_DIR / "../../bin/oet_mlatom" +# Default maximum time (in sec) to download the model files if not present +timeout = 300 # Leave mlatom_executable_path empty, if mlatom from system path should be called mlatom_executable_path = "" @@ -33,7 +38,30 @@ class MLatomTests(unittest.TestCase): @classmethod def setUpClass(cls): # Force download / initialization of ANI-1ccx once - torchani.models.ANI1ccx(periodic_table_index=True) + print("Checking the model files and downloading them if necessary.") + # Make a timeout call to avoid hanging forever + get_ani1ccx_timeout = TimeoutCall(fn=torchani.models.ANI1ccx) + ok, payload = get_ani1ccx_timeout(timeout=timeout, periodic_table_index=True) + # Check if the model files could not be loaded + if not ok: + # Timeout + if payload == TimeoutCallError.TIMEOUT: + print( + "Loading the model files timed out. " + "Please check your internet connection and consider increasing the time before timing out." + ) + raise unittest.SkipTest("Timed out.") + # General errors and crashes + elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + print( + "Loading the model files failed. Make sure that " + "the virtual environment with MLAtom installed is active." + ) + raise unittest.SkipTest("Loading failed.") + # Unresolved error + else: + print("Could not load the model files.") + raise unittest.SkipTest("Loading failed.") def test_H2O_engrad(self): xyz_file, input_file, engrad_out, output_file = get_filenames("H2O") diff --git a/tests/mopac/test_mopac.py b/tests/mopac/test_mopac.py index b5a7644..c8860ea 100644 --- a/tests/mopac/test_mopac.py +++ b/tests/mopac/test_mopac.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the script, adjust if needed. mopac_script_path = ROOT_DIR / "../../bin/oet_mopac" # Leave moppac_executable_path empty, if mopac from system path should be called mopac_executable_path = "" diff --git a/tests/uma/test_uma_client.py b/tests/uma/test_uma_client.py index 83986ae..0b253a1 100644 --- a/tests/uma/test_uma_client.py +++ b/tests/uma/test_uma_client.py @@ -5,9 +5,12 @@ import unittest from oet import ROOT_DIR +from oet.calculator.uma import DEFAULT_CACHE_DIR, UmaCalc from oet.core.test_utilities import ( OH, WATER, + TimeoutCall, + TimeoutCallError, get_filenames, read_result_file, run_wrapper, @@ -15,10 +18,34 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. uma_script_path = ROOT_DIR / "../../bin/oet_client" uma_server_path = ROOT_DIR / "../../bin/oet_server" +# Default maximum time (in sec) to download the model files if not present +timeout = 600 # Default ID and port of server. Change if needed id_port = "127.0.0.1:9000" +# UMA model to use +uma_model = "uma-s-1p1" + + +def cache_model_files( + basemodel: str, param: str = "omol", device: str = "cpu", cache_dir: str = DEFAULT_CACHE_DIR +) -> None: + """ + Wrapper to set up an UMA calculator that downloads the model files into the same cache-directory used for actual oet calculations. + + basemodel: str + Basemodel used to calculate the test cases + param: str, default: omol + Parameter set. + device str, default: cpu + Device used for the calculations. + cache_dir: str, default: DEFAULT_CACHE_DIR + The cache directory used to store the model data. + """ + calculator = UmaCalc() + calculator.set_calculator(param=param, basemodel=basemodel, device=device, cache_dir=cache_dir) def run_uma(inputfile: str, output_file: str) -> None: @@ -26,7 +53,7 @@ def run_uma(inputfile: str, output_file: str) -> None: inputfile=inputfile, script_path=uma_script_path, outfile=output_file, - args=["--bind", id_port], + args=["--bind", id_port, "--model", uma_model], ) @@ -36,6 +63,31 @@ def setUpClass(cls): """ Test starting the server """ + # Pre-download UMA model files + print("Checking the model files and downloading them if necessary.") + # Make a timeout call to avoid hanging forever + get_pretrained_mlip_timeout = TimeoutCall(fn=cache_model_files) + ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout) + # Check if the model files could not be loaded + if not ok: + # Timeout + if payload == TimeoutCallError.TIMEOUT: + print( + "Loading the model files timed out. " + "Please check your internet connection and consider increasing the time before timing out." + ) + raise unittest.SkipTest("Timed out.") + # General errors and crashes + elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + print( + "Loading the model files failed. Make sure that " + "the virtual environment with UMA installed is active." + ) + raise unittest.SkipTest("Loading failed.") + # Unresolved error + else: + print("Could not load the model files.") + raise unittest.SkipTest("Loading failed.") print("Starting the server. A detailed server log can be found on file server.out") with open("server.out", "a") as f: cls.server = subprocess.Popen( diff --git a/tests/uma/test_uma_standalone.py b/tests/uma/test_uma_standalone.py index 7a675ff..a8924c0 100644 --- a/tests/uma/test_uma_standalone.py +++ b/tests/uma/test_uma_standalone.py @@ -1,9 +1,12 @@ import unittest from oet import ROOT_DIR +from oet.calculator.uma import DEFAULT_CACHE_DIR, UmaCalc from oet.core.test_utilities import ( OH, WATER, + TimeoutCall, + TimeoutCallError, get_filenames, read_result_file, run_wrapper, @@ -11,14 +14,70 @@ write_xyz_file, ) +# Path to the script, adjust if needed. uma_script_path = ROOT_DIR / "../../bin/oet_uma" +# Default maximum time (in sec) to download the model files if not present +timeout = 600 +# UMA model to use +uma_model = "uma-s-1p1" + + +def cache_model_files( + basemodel: str, param: str = "omol", device: str = "cpu", cache_dir: str = DEFAULT_CACHE_DIR +) -> None: + """ + Wrapper to set up an UMA calculator that downloads the model files into the same cache-directory used for actual oet calculations. + + basemodel: str + Basemodel used to calculate the test cases + param: str, default: omol + Parameter set. + device str, default: cpu + Device used for the calculations. + cache_dir: str, default: DEFAULT_CACHE_DIR + The cache directory used to store the model data. + """ + calculator = UmaCalc() + calculator.set_calculator(param=param, basemodel=basemodel, device=device, cache_dir=cache_dir) def run_uma(inputfile: str, output_file: str) -> None: - run_wrapper(inputfile=inputfile, script_path=uma_script_path, outfile=output_file) + # Run the wrapper with an increased timeout as loading the UMA model files might take a while + run_wrapper(inputfile=inputfile, script_path=uma_script_path, outfile=output_file, timeout=30) class UmaTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + """ + Test starting the server + """ + # Pre-download UMA model files + print("Checking the model files and downloading them if necessary.") + # Make a timeout call to avoid hanging forever + get_pretrained_mlip_timeout = TimeoutCall(fn=cache_model_files) + ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout) + # Check if the model files could not be loaded + if not ok: + # Timeout + if payload == TimeoutCallError.TIMEOUT: + print( + "Loading the model files timed out. " + "Please check your internet connection and consider increasing the time before timing out." + ) + raise unittest.SkipTest("Timed out.") + # General errors and crashes + elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + print( + "Loading the model files failed. Make sure that " + "the virtual environment with UMA installed is active." + ) + raise unittest.SkipTest("Loading failed.") + # Unresolved error + else: + print("Could not load the model files.") + raise unittest.SkipTest("Loading failed.") + def test_H2O_engrad(self): xyz_file, input_file, engrad_out, output_file = get_filenames("H2O") diff --git a/tests/xtb/test_xtb.py b/tests/xtb/test_xtb.py index bdd2ada..90463fd 100644 --- a/tests/xtb/test_xtb.py +++ b/tests/xtb/test_xtb.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. xtb_script_path = ROOT_DIR / "../../bin/oet_xtb" # Leave xtb_executable_path empty, if xtb from system path should be called xtb_executable_path = ""