Skip to content

Commit

Permalink
slurmrunner import
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffwoollard committed Jan 16, 2025
1 parent f7cf274 commit 232357b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dask import delayed, compute
from dask.distributed import Client
from dask.diagnostics import ProgressBar
from dask_hpc_runner import SlurmRunner

from cryo_challenge._preprocessing.fourier_utils import downsample_volume

Expand Down Expand Up @@ -409,14 +410,13 @@ def main(args):
if __name__ == "__main__":
args = parse_args()
if args.slurm:
pass
# job_id = os.environ["SLURM_JOB_ID"]
# with SlurmRunner(
# scheduler_file=args.scheduler_file,
# ) as runner:
# # The runner object contains the scheduler address and can be passed directly to a client
# with Client(runner) as client:
# get_distance_matrix_dask_gw = main(args)
job_id = os.environ["SLURM_JOB_ID"]
with SlurmRunner(
scheduler_file=args.scheduler_file,
) as runner:
# The runner object contains the scheduler address and can be passed directly to a client
with Client(runner) as client:
get_distance_matrix_dask_gw = main(args)

else:
with Client(local_directory=args.local_directory) as client:
Expand Down
40 changes: 20 additions & 20 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import mrcfile
import numpy as np
from dask.distributed import Client
from dask_hpc_runner import SlurmRunner

from .gromov_wasserstein.gw_weighted_voxels import get_distance_matrix_dask_gw

Expand Down Expand Up @@ -498,26 +499,25 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
maps2 = maps2.reshape((len(maps2),) + maps1.shape[1:])

if extra_params["slurm"]:
pass
# job_id = os.environ["SLURM_JOB_ID"]
# scheduler_file = os.path.join(
# extra_params["scheduler_file_dir"], f"scheduler-{job_id}.json"
# )
# with SlurmRunner(
# scheduler_file=scheduler_file,
# ) as runner:
# # The runner object contains the scheduler address and can be passed directly to a client
# with Client(runner) as client:
# distance_matrix_dask_gw = get_distance_matrix_dask_gw(
# volumes_i=maps1,
# volumes_j=maps2,
# top_k=extra_params["top_k"],
# n_downsample_pix=extra_params["n_downsample_pix"],
# exponent=extra_params["exponent"],
# cost_scale_factor=extra_params["cost_scale_factor"],
# scheduler=extra_params["scheduler"],
# element_wise=extra_params["element_wise"],
# )
job_id = os.environ["SLURM_JOB_ID"]
scheduler_file = os.path.join(
extra_params["scheduler_file_dir"], f"scheduler-{job_id}.json"
)
with SlurmRunner(
scheduler_file=scheduler_file,
) as runner:
# The runner object contains the scheduler address and can be passed directly to a client
with Client(runner) as client:
distance_matrix_dask_gw = get_distance_matrix_dask_gw(
volumes_i=maps1,
volumes_j=maps2,
top_k=extra_params["top_k"],
n_downsample_pix=extra_params["n_downsample_pix"],
exponent=extra_params["exponent"],
cost_scale_factor=extra_params["cost_scale_factor"],
scheduler=extra_params["scheduler"],
element_wise=extra_params["element_wise"],
)

else:
local_directory = extra_params["local_directory"]
Expand Down

0 comments on commit 232357b

Please sign in to comment.