From fe6250812b263b77e6a9103f75ad05fa97644c0c Mon Sep 17 00:00:00 2001 From: Christoph Plett Date: Wed, 10 Dec 2025 15:45:38 +0100 Subject: [PATCH 1/4] Fixing UMA test Signed-off-by: Christoph Plett --- tests/uma/test_uma_client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/uma/test_uma_client.py b/tests/uma/test_uma_client.py index 83986ae..f00527f 100644 --- a/tests/uma/test_uma_client.py +++ b/tests/uma/test_uma_client.py @@ -3,6 +3,7 @@ import subprocess import time import unittest +from fairchem.core import pretrained_mlip from oet import ROOT_DIR from oet.core.test_utilities import ( @@ -19,6 +20,8 @@ uma_server_path = ROOT_DIR / "../../bin/oet_server" # Default ID and port of server. Change if needed id_port = "127.0.0.1:9000" +# UMA model to use +uma_model = "uma-s-1p1" def run_uma(inputfile: str, output_file: str) -> None: @@ -26,7 +29,7 @@ def run_uma(inputfile: str, output_file: str) -> None: inputfile=inputfile, script_path=uma_script_path, outfile=output_file, - args=["--bind", id_port], + args=["--bind", id_port, "--model", uma_model] ) @@ -36,6 +39,9 @@ def setUpClass(cls): """ Test starting the server """ + # Pre-download UMA model files + print("Checking the model files and downloading them if necessary.") + pretrained_mlip.get_predict_unit(uma_model, device="cpu") print("Starting the server. A detailed server log can be found on file server.out") with open("server.out", "a") as f: cls.server = subprocess.Popen( From 5d929fd66dabdd6eeddec3b2077052fc7384c769 Mon Sep 17 00:00:00 2001 From: Christoph Plett Date: Tue, 16 Dec 2025 10:52:51 +0100 Subject: [PATCH 2/4] Updating testsuite Signed-off-by: Christoph Plett --- noxfile.py | 17 ++-- src/oet/core/test_utilities.py | 115 +++++++++++++++++++++++ tests/aimnet2/test_aiment2_client.py | 1 + tests/aimnet2/test_aiment2_standalone.py | 1 + tests/g-xtb/test_gxtb.py | 1 + tests/mlatom/test_mlatom.py | 23 ++++- tests/mopac/test_mopac.py | 1 + tests/uma/test_uma_client.py | 25 ++++- tests/uma/test_uma_standalone.py | 1 + tests/xtb/test_xtb.py | 1 + 10 files changed, 176 insertions(+), 10 deletions(-) diff --git a/noxfile.py b/noxfile.py index 7b89265..2e903fc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,6 +3,9 @@ # > External packages import nox +# > Ignore certain dirs +ASSETS_DIR = "src/oet/assets" + # > Making sure Nox session only see their packages and not any globally installed packages. os.environ.pop("PYTHONPATH", None) # > Hiding any virtual environments from outside. @@ -37,7 +40,7 @@ @nox.session(tags=["static_check"]) def type_check(session): session.install(".[type-check]") - session.run("mypy") + session.run("mypy", "--exclude", f"^{ASSETS_DIR}/") # ////////////////////////////////////////////////// @@ -47,7 +50,7 @@ def type_check(session): def remove_unused_imports(session): session.install(".[lint]") # > Sorting imports with ruff instead of isort - session.run("ruff", "check", "--fix", "--select", "F401") + session.run("ruff", "check", "--fix", "--select", "F401", "--exclude", ASSETS_DIR) # ////////////////////////////////////////// @@ -57,7 +60,7 @@ def remove_unused_imports(session): def sort_imports(session): session.install(".[lint]") # > Sorting imports with ruff instead of isort - session.run("ruff", "check", "--fix", "--select", "I") + session.run("ruff", "check", "--fix", "--select", "I", "--exclude", ASSETS_DIR) # //////////////////////////////////////// @@ -66,7 +69,7 @@ def sort_imports(session): @nox.session(tags=["style", "static_check"]) def lint(session): session.install(".[lint]") - session.run("ruff", "check", "--fix") + session.run("ruff", "check", "--fix", "--exclude", ASSETS_DIR) # ////////////////////////////////////////// @@ -76,7 +79,7 @@ def lint(session): def format_code(session): # Installs the project + the "lint" extra into this nox venv using pip session.install(".[lint]") - session.run("ruff", "format") + session.run("ruff", "format", "--exclude", ASSETS_DIR) # //////////////////////////////////////////////////// @@ -85,7 +88,7 @@ def format_code(session): @nox.session(tags=["static_check"]) def spell_check(session): session.install(".[spell-check]") - session.run("codespell", "src/oet") + session.run("codespell", "src/oet", "--skip", ASSETS_DIR) # ////////////////////////////////////////////// @@ -94,4 +97,4 @@ def spell_check(session): @nox.session(tags=["static_check"], default=True) def dead_code(session): session.install(".[dead-code]") - session.run("vulture") + session.run("vulture", "src", "--exclude", ASSETS_DIR) diff --git a/src/oet/core/test_utilities.py b/src/oet/core/test_utilities.py index 068acea..3a47c6e 100644 --- a/src/oet/core/test_utilities.py +++ b/src/oet/core/test_utilities.py @@ -2,8 +2,20 @@ Utilities used in the test suite """ +import multiprocessing as mp import subprocess +import traceback +from enum import StrEnum from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeAlias, TypeVar + +if TYPE_CHECKING: + from multiprocessing.queues import Queue + +T = TypeVar("T") +Status: TypeAlias = Literal["ok", "err"] +Payload: TypeAlias = bool | str +QueueItem: TypeAlias = tuple[Status, Payload] WATER = [ ("O", 0.0000, 0.0000, 0.0000), @@ -195,3 +207,106 @@ def clear_files(basename: str) -> None: for f in dir_path.glob(basename + "*"): if f.is_file(): f.unlink() # remove file + + +def _worker( + fn: Callable[..., T], args: tuple[Any], kwargs: dict[str, Any], q: "Queue[QueueItem]" +) -> None: + """ + Helper for executing a function. + + Parameters + ---------- + fn: Callable[..., T] + Callable function that is executed. + args: tuple[Any] + Any positional arguments that are given to the function call. + kwargs: dict[str, Any] + Any keyword arguments that are given to the function call. + q: Queue[QueueItem] + Queue used to put the function call. + """ + try: + # Call the function + _ = fn(*args, **kwargs) + # Don't check what it did, just return ok, if the function didn't crash + q.put(("ok", True)) + except Exception: + q.put(("err", traceback.format_exc())) + + +class TimeoutCallError(StrEnum): + """Possible errors that are returned by TimeoutCall""" + + # Function timed out + TIMEOUT = "timeout" + # Function crashed + CRASH = "crash" + # General error + ERROR = "error" + + +class TimeoutCall: + """ + Class for calling a function with a certain timeout. + Useful for functions that, e.g., download files. + Doesn't return the result of the function as it might not be pickled. + """ + + def __init__(self, fn: Callable[..., T]) -> None: + """ + Initialization of the class. + + Parameters + ---------- + fn: Callable[..., T] + Callable function that is executed. + """ + self.fn = fn + self.timed_out = False + + def __call__(self, *args: Any, timeout: float = 10, **kwargs: Any) -> tuple[bool, Any]: + """ + Execute the function set in __init__ with the timeout defined there. + + Parameters + ---------- + args: Any + Any positional arguments that are given to the function call. + timeout: float, default: 10 sec. + Timeout in sec. + kwargs: Any + Any keyword arguments that are given to the function call. + + Returns + ------- + bool + True, if everything was ok. False otherwise. + Any + Either the error type if failed or the result of the function. + """ + + # Start process and wait the timeout + q: "Queue[QueueItem]" = mp.Queue() + p: mp.Process = mp.Process(target=_worker, args=(self.fn, args, kwargs, q)) + p.start() + p.join(timeout) + + # Check if the process is still alive. If yes, it has timed out. + if p.is_alive(): + p.terminate() + p.join() + return False, TimeoutCallError.TIMEOUT + + # Check if there are any results, otherwise the function is crashed + try: + status, payload = q.get(timeout=1) + except Exception: + return False, TimeoutCallError.CRASH + + # Check if there was a general error + if status == "err": + return False, TimeoutCallError.ERROR + + # If everything went well, return the function result. + return True, payload diff --git a/tests/aimnet2/test_aiment2_client.py b/tests/aimnet2/test_aiment2_client.py index 3c7e0ef..8244011 100644 --- a/tests/aimnet2/test_aiment2_client.py +++ b/tests/aimnet2/test_aiment2_client.py @@ -15,6 +15,7 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. aimnet2_script_path = ROOT_DIR / "../../bin/oet_client" aimnet2_server_path = ROOT_DIR / "../../bin/oet_server" # Default ID and port of server. Change if needed diff --git a/tests/aimnet2/test_aiment2_standalone.py b/tests/aimnet2/test_aiment2_standalone.py index cedfe56..2c88e3a 100644 --- a/tests/aimnet2/test_aiment2_standalone.py +++ b/tests/aimnet2/test_aiment2_standalone.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the script, adjust if needed. aimnet2_script_path = ROOT_DIR / "../../bin/oet_aimnet2" diff --git a/tests/g-xtb/test_gxtb.py b/tests/g-xtb/test_gxtb.py index c1ca583..709d008 100644 --- a/tests/g-xtb/test_gxtb.py +++ b/tests/g-xtb/test_gxtb.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. gxtb_script_path = ROOT_DIR / "../../bin/oet_gxtb" # Leave uma_executable_path empty, if gxtb from system path should be called gxtb_executable_path = "" diff --git a/tests/mlatom/test_mlatom.py b/tests/mlatom/test_mlatom.py index 12eb806..f9fc6f7 100644 --- a/tests/mlatom/test_mlatom.py +++ b/tests/mlatom/test_mlatom.py @@ -6,6 +6,8 @@ from oet.core.test_utilities import ( OH, WATER, + TimeoutCall, + TimeoutCallError, get_filenames, read_result_file, run_wrapper, @@ -13,7 +15,10 @@ write_xyz_file, ) +# Path to the script, adjust if needed. mlatom_script_path = ROOT_DIR / "../../bin/oet_mlatom" +# Default maximum time (in sec) to download the model files if not present +timeout = 300 # Leave mlatom_executable_path empty, if mlatom from system path should be called mlatom_executable_path = "" @@ -33,7 +38,23 @@ class MLatomTests(unittest.TestCase): @classmethod def setUpClass(cls): # Force download / initialization of ANI-1ccx once - torchani.models.ANI1ccx(periodic_table_index=True) + print("Checking the model files and downloading them if necessary.") + # Make a timeout call to avoid hanging forever + get_ani1ccx_timeout = TimeoutCall(fn=torchani.models.ANI1ccx) + ok, payload = get_ani1ccx_timeout(timeout=timeout, periodic_table_index=True) + if not ok: + if payload == TimeoutCallError.TIMEOUT: + print( + "Loading the model files timed out. " + "Please check your internet connection and consider increasing the time before timing out." + ) + raise unittest.SkipTest("Timed out.") + if payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + print( + "Loading the model files failed. Make sure that " + "the virtual environment with MLAtoms installed is active." + ) + raise unittest.SkipTest("Loading failed.") def test_H2O_engrad(self): xyz_file, input_file, engrad_out, output_file = get_filenames("H2O") diff --git a/tests/mopac/test_mopac.py b/tests/mopac/test_mopac.py index b5a7644..c8860ea 100644 --- a/tests/mopac/test_mopac.py +++ b/tests/mopac/test_mopac.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the script, adjust if needed. mopac_script_path = ROOT_DIR / "../../bin/oet_mopac" # Leave moppac_executable_path empty, if mopac from system path should be called mopac_executable_path = "" diff --git a/tests/uma/test_uma_client.py b/tests/uma/test_uma_client.py index f00527f..8c541de 100644 --- a/tests/uma/test_uma_client.py +++ b/tests/uma/test_uma_client.py @@ -3,12 +3,15 @@ import subprocess import time import unittest + from fairchem.core import pretrained_mlip from oet import ROOT_DIR from oet.core.test_utilities import ( OH, WATER, + TimeoutCall, + TimeoutCallError, get_filenames, read_result_file, run_wrapper, @@ -16,8 +19,11 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. uma_script_path = ROOT_DIR / "../../bin/oet_client" uma_server_path = ROOT_DIR / "../../bin/oet_server" +# Default maximum time (in sec) to download the model files if not present +timeout = 600 # Default ID and port of server. Change if needed id_port = "127.0.0.1:9000" # UMA model to use @@ -29,7 +35,7 @@ def run_uma(inputfile: str, output_file: str) -> None: inputfile=inputfile, script_path=uma_script_path, outfile=output_file, - args=["--bind", id_port, "--model", uma_model] + args=["--bind", id_port, "--model", uma_model], ) @@ -41,7 +47,22 @@ def setUpClass(cls): """ # Pre-download UMA model files print("Checking the model files and downloading them if necessary.") - pretrained_mlip.get_predict_unit(uma_model, device="cpu") + # Make a timeout call to avoid hanging forever + get_pretrained_mlip_timeout = TimeoutCall(fn=pretrained_mlip.get_predict_unit) + ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout, device="cpu") + if not ok: + if payload == TimeoutCallError.TIMEOUT: + print( + "Loading the model files timed out. " + "Please check your internet connection and consider increasing the time before timing out." + ) + raise unittest.SkipTest("Timed out.") + if payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + print( + "Loading the model files failed. Make sure that " + "the virtual environment with UMA installed is active." + ) + raise unittest.SkipTest("Loading failed.") print("Starting the server. A detailed server log can be found on file server.out") with open("server.out", "a") as f: cls.server = subprocess.Popen( diff --git a/tests/uma/test_uma_standalone.py b/tests/uma/test_uma_standalone.py index 7a675ff..a1ed04f 100644 --- a/tests/uma/test_uma_standalone.py +++ b/tests/uma/test_uma_standalone.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the script, adjust if needed. uma_script_path = ROOT_DIR / "../../bin/oet_uma" diff --git a/tests/xtb/test_xtb.py b/tests/xtb/test_xtb.py index bdd2ada..90463fd 100644 --- a/tests/xtb/test_xtb.py +++ b/tests/xtb/test_xtb.py @@ -11,6 +11,7 @@ write_xyz_file, ) +# Path to the scripts, adjust if needed. xtb_script_path = ROOT_DIR / "../../bin/oet_xtb" # Leave xtb_executable_path empty, if xtb from system path should be called xtb_executable_path = "" From 33bc3138ef092392f106fd01d02040bb4a98b744 Mon Sep 17 00:00:00 2001 From: Christoph Plett Date: Tue, 16 Dec 2025 15:50:33 +0100 Subject: [PATCH 3/4] Test updates Signed-off-by: Christoph Plett --- src/oet/core/test_utilities.py | 40 ++++++++++++------------- tests/uma/test_uma_client.py | 26 ++++++++++++++--- tests/uma/test_uma_standalone.py | 50 ++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 24 deletions(-) diff --git a/src/oet/core/test_utilities.py b/src/oet/core/test_utilities.py index 3a47c6e..eeabdcb 100644 --- a/src/oet/core/test_utilities.py +++ b/src/oet/core/test_utilities.py @@ -4,19 +4,13 @@ import multiprocessing as mp import subprocess -import traceback from enum import StrEnum from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: from multiprocessing.queues import Queue -T = TypeVar("T") -Status: TypeAlias = Literal["ok", "err"] -Payload: TypeAlias = bool | str -QueueItem: TypeAlias = tuple[Status, Payload] - WATER = [ ("O", 0.0000, 0.0000, 0.0000), ("H", 0.2774, 0.8929, 0.2544), @@ -210,7 +204,7 @@ def clear_files(basename: str) -> None: def _worker( - fn: Callable[..., T], args: tuple[Any], kwargs: dict[str, Any], q: "Queue[QueueItem]" + fn: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], q: "Queue[bool]" ) -> None: """ Helper for executing a function. @@ -223,16 +217,16 @@ def _worker( Any positional arguments that are given to the function call. kwargs: dict[str, Any] Any keyword arguments that are given to the function call. - q: Queue[QueueItem] + q: Queue[bool] Queue used to put the function call. """ try: # Call the function _ = fn(*args, **kwargs) # Don't check what it did, just return ok, if the function didn't crash - q.put(("ok", True)) + q.put(True) except Exception: - q.put(("err", traceback.format_exc())) + q.put(False) class TimeoutCallError(StrEnum): @@ -253,7 +247,7 @@ class TimeoutCall: Doesn't return the result of the function as it might not be pickled. """ - def __init__(self, fn: Callable[..., T]) -> None: + def __init__(self, fn: Callable[..., Any]) -> None: """ Initialization of the class. @@ -265,7 +259,9 @@ def __init__(self, fn: Callable[..., T]) -> None: self.fn = fn self.timed_out = False - def __call__(self, *args: Any, timeout: float = 10, **kwargs: Any) -> tuple[bool, Any]: + def __call__( + self, *args: Any, timeout: float = 10, **kwargs: Any + ) -> tuple[bool, TimeoutCallError | None]: """ Execute the function set in __init__ with the timeout defined there. @@ -282,31 +278,35 @@ def __call__(self, *args: Any, timeout: float = 10, **kwargs: Any) -> tuple[bool ------- bool True, if everything was ok. False otherwise. - Any - Either the error type if failed or the result of the function. + TimeoutCallError | None + Either the error type if failed or None. """ # Start process and wait the timeout - q: "Queue[QueueItem]" = mp.Queue() + q: "Queue[bool]" = mp.Queue() p: mp.Process = mp.Process(target=_worker, args=(self.fn, args, kwargs, q)) p.start() p.join(timeout) + # Check if there was any error + if p.exitcode not in (0, None): + return False, TimeoutCallError.CRASH + # Check if the process is still alive. If yes, it has timed out. if p.is_alive(): p.terminate() p.join() return False, TimeoutCallError.TIMEOUT - # Check if there are any results, otherwise the function is crashed + # Check if the worker provides the correct result try: - status, payload = q.get(timeout=1) + status_ok = q.get(timeout=1) except Exception: return False, TimeoutCallError.CRASH # Check if there was a general error - if status == "err": + if not status_ok: return False, TimeoutCallError.ERROR # If everything went well, return the function result. - return True, payload + return True, None diff --git a/tests/uma/test_uma_client.py b/tests/uma/test_uma_client.py index 8c541de..24ba2bf 100644 --- a/tests/uma/test_uma_client.py +++ b/tests/uma/test_uma_client.py @@ -4,9 +4,8 @@ import time import unittest -from fairchem.core import pretrained_mlip - from oet import ROOT_DIR +from oet.calculator.uma import DEFAULT_CACHE_DIR, UmaCalc from oet.core.test_utilities import ( OH, WATER, @@ -30,6 +29,25 @@ uma_model = "uma-s-1p1" +def cache_model_files( + basemodel: str, param: str = "omol", device: str = "cpu", cache_dir: str = DEFAULT_CACHE_DIR +) -> None: + """ + Wrapper to set up an UMA calculator that downloads the model files into the same cache-directory used for actual oet calculations. + + basemodel: str + Basemodel used to calculate the test cases + param: str, default: omol + Parameter set. + device str, default: cpu + Device used for the calculations. + cache_dir: str, default: DEFAULT_CACHE_DIR + The cache directory used to store the model data. + """ + calculator = UmaCalc() + calculator.set_calculator(param=param, basemodel=basemodel, device=device, cache_dir=cache_dir) + + def run_uma(inputfile: str, output_file: str) -> None: run_wrapper( inputfile=inputfile, @@ -48,8 +66,8 @@ def setUpClass(cls): # Pre-download UMA model files print("Checking the model files and downloading them if necessary.") # Make a timeout call to avoid hanging forever - get_pretrained_mlip_timeout = TimeoutCall(fn=pretrained_mlip.get_predict_unit) - ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout, device="cpu") + get_pretrained_mlip_timeout = TimeoutCall(fn=cache_model_files) + ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout) if not ok: if payload == TimeoutCallError.TIMEOUT: print( diff --git a/tests/uma/test_uma_standalone.py b/tests/uma/test_uma_standalone.py index a1ed04f..cd1ccbd 100644 --- a/tests/uma/test_uma_standalone.py +++ b/tests/uma/test_uma_standalone.py @@ -1,9 +1,12 @@ import unittest from oet import ROOT_DIR +from oet.calculator.uma import DEFAULT_CACHE_DIR, UmaCalc from oet.core.test_utilities import ( OH, WATER, + TimeoutCall, + TimeoutCallError, get_filenames, read_result_file, run_wrapper, @@ -13,6 +16,29 @@ # Path to the script, adjust if needed. uma_script_path = ROOT_DIR / "../../bin/oet_uma" +# Default maximum time (in sec) to download the model files if not present +timeout = 600 +# UMA model to use +uma_model = "uma-s-1p1" + + +def cache_model_files( + basemodel: str, param: str = "omol", device: str = "cpu", cache_dir: str = DEFAULT_CACHE_DIR +) -> None: + """ + Wrapper to set up an UMA calculator that downloads the model files into the same cache-directory used for actual oet calculations. + + basemodel: str + Basemodel used to calculate the test cases + param: str, default: omol + Parameter set. + device str, default: cpu + Device used for the calculations. + cache_dir: str, default: DEFAULT_CACHE_DIR + The cache directory used to store the model data. + """ + calculator = UmaCalc() + calculator.set_calculator(param=param, basemodel=basemodel, device=device, cache_dir=cache_dir) def run_uma(inputfile: str, output_file: str) -> None: @@ -20,6 +46,30 @@ def run_uma(inputfile: str, output_file: str) -> None: class UmaTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + """ + Test starting the server + """ + # Pre-download UMA model files + print("Checking the model files and downloading them if necessary.") + # Make a timeout call to avoid hanging forever + get_pretrained_mlip_timeout = TimeoutCall(fn=cache_model_files) + ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout) + if not ok: + if payload == TimeoutCallError.TIMEOUT: + print( + "Loading the model files timed out. " + "Please check your internet connection and consider increasing the time before timing out." + ) + raise unittest.SkipTest("Timed out.") + if payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + print( + "Loading the model files failed. Make sure that " + "the virtual environment with UMA installed is active." + ) + raise unittest.SkipTest("Loading failed.") + def test_H2O_engrad(self): xyz_file, input_file, engrad_out, output_file = get_filenames("H2O") From d732ddc897fa512267847a7e26ab441c2327968f Mon Sep 17 00:00:00 2001 From: Christoph Plett Date: Wed, 17 Dec 2025 08:30:30 +0100 Subject: [PATCH 4/4] Formatting Signed-off-by: Christoph Plett --- src/oet/core/test_utilities.py | 2 +- tests/mlatom/test_mlatom.py | 11 +++++++++-- tests/uma/test_uma_client.py | 9 ++++++++- tests/uma/test_uma_standalone.py | 12 ++++++++++-- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/oet/core/test_utilities.py b/src/oet/core/test_utilities.py index eeabdcb..8cb8188 100644 --- a/src/oet/core/test_utilities.py +++ b/src/oet/core/test_utilities.py @@ -308,5 +308,5 @@ def __call__( if not status_ok: return False, TimeoutCallError.ERROR - # If everything went well, return the function result. + # If everything went well, return True and no Error. return True, None diff --git a/tests/mlatom/test_mlatom.py b/tests/mlatom/test_mlatom.py index f9fc6f7..790d3f0 100644 --- a/tests/mlatom/test_mlatom.py +++ b/tests/mlatom/test_mlatom.py @@ -42,19 +42,26 @@ def setUpClass(cls): # Make a timeout call to avoid hanging forever get_ani1ccx_timeout = TimeoutCall(fn=torchani.models.ANI1ccx) ok, payload = get_ani1ccx_timeout(timeout=timeout, periodic_table_index=True) + # Check if the model files could not be loaded if not ok: + # Timeout if payload == TimeoutCallError.TIMEOUT: print( "Loading the model files timed out. " "Please check your internet connection and consider increasing the time before timing out." ) raise unittest.SkipTest("Timed out.") - if payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + # General errors and crashes + elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: print( "Loading the model files failed. Make sure that " - "the virtual environment with MLAtoms installed is active." + "the virtual environment with MLAtom installed is active." ) raise unittest.SkipTest("Loading failed.") + # Unresolved error + else: + print("Could not load the model files.") + raise unittest.SkipTest("Loading failed.") def test_H2O_engrad(self): xyz_file, input_file, engrad_out, output_file = get_filenames("H2O") diff --git a/tests/uma/test_uma_client.py b/tests/uma/test_uma_client.py index 24ba2bf..0b253a1 100644 --- a/tests/uma/test_uma_client.py +++ b/tests/uma/test_uma_client.py @@ -68,19 +68,26 @@ def setUpClass(cls): # Make a timeout call to avoid hanging forever get_pretrained_mlip_timeout = TimeoutCall(fn=cache_model_files) ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout) + # Check if the model files could not be loaded if not ok: + # Timeout if payload == TimeoutCallError.TIMEOUT: print( "Loading the model files timed out. " "Please check your internet connection and consider increasing the time before timing out." ) raise unittest.SkipTest("Timed out.") - if payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + # General errors and crashes + elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: print( "Loading the model files failed. Make sure that " "the virtual environment with UMA installed is active." ) raise unittest.SkipTest("Loading failed.") + # Unresolved error + else: + print("Could not load the model files.") + raise unittest.SkipTest("Loading failed.") print("Starting the server. A detailed server log can be found on file server.out") with open("server.out", "a") as f: cls.server = subprocess.Popen( diff --git a/tests/uma/test_uma_standalone.py b/tests/uma/test_uma_standalone.py index cd1ccbd..a8924c0 100644 --- a/tests/uma/test_uma_standalone.py +++ b/tests/uma/test_uma_standalone.py @@ -42,7 +42,8 @@ def cache_model_files( def run_uma(inputfile: str, output_file: str) -> None: - run_wrapper(inputfile=inputfile, script_path=uma_script_path, outfile=output_file) + # Run the wrapper with an increased timeout as loading the UMA model files might take a while + run_wrapper(inputfile=inputfile, script_path=uma_script_path, outfile=output_file, timeout=30) class UmaTests(unittest.TestCase): @@ -56,19 +57,26 @@ def setUpClass(cls): # Make a timeout call to avoid hanging forever get_pretrained_mlip_timeout = TimeoutCall(fn=cache_model_files) ok, payload = get_pretrained_mlip_timeout(uma_model, timeout=timeout) + # Check if the model files could not be loaded if not ok: + # Timeout if payload == TimeoutCallError.TIMEOUT: print( "Loading the model files timed out. " "Please check your internet connection and consider increasing the time before timing out." ) raise unittest.SkipTest("Timed out.") - if payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: + # General errors and crashes + elif payload == TimeoutCallError.CRASH or payload == TimeoutCallError.ERROR: print( "Loading the model files failed. Make sure that " "the virtual environment with UMA installed is active." ) raise unittest.SkipTest("Loading failed.") + # Unresolved error + else: + print("Could not load the model files.") + raise unittest.SkipTest("Loading failed.") def test_H2O_engrad(self): xyz_file, input_file, engrad_out, output_file = get_filenames("H2O")