Skip to content

Commit

Permalink
Merge pull request #34 from eurunuela/33-dask-scheduler-gets-killed-when-running-with-big-data
Browse files Browse the repository at this point in the history

[FIX] Use dask scheduler with big data
  • Loading branch information
eurunuela authored Oct 5, 2022
2 parents 0c7afe7 + 716079e commit cb38b5f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 17 deletions.
17 changes: 11 additions & 6 deletions pySPFM/deconvolution/stability_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_subsampling_indices(n_scans, n_echos, mode="same"):
return subsample_idx


def calculate_auc(coefs, lambdas):
def calculate_auc(coefs, lambdas, n_surrogates):

# Create shared space of lambdas and coefficients
lambdas_shared = np.zeros((lambdas.shape[0] * lambdas.shape[1]))
Expand All @@ -54,14 +54,19 @@ def calculate_auc(coefs, lambdas):
lambdas_sorted_idx = np.argsort(lambdas_shared)
lambdas_sorted = np.sort(lambdas_shared)

# Sort coefficients to match lambdas
# Sum of all lambdas
sum_lambdas = np.sum(lambdas_sorted)

# Sort coefficients
coefs_sorted = coefs_shared[lambdas_sorted_idx]

# Turn coefs_sorted into a binary vector
# Make coefs_sorted binary
coefs_sorted[coefs_sorted != 0] = 1

# Calculate the AUC as the normalized area under the curve
auc = np.trapz(coefs_sorted, lambdas_sorted) / np.sum(lambdas_sorted) / coefs.shape[-1]
# Calculate the AUC
auc = 0
for i in range(lambdas_shared.shape[0]):
auc += coefs_sorted[i] * lambdas_sorted[i] / sum_lambdas

return auc

Expand Down Expand Up @@ -94,6 +99,6 @@ def stability_selection(hrf_norm, data, n_lambdas, n_surrogates):
# Calculate the AUC for each TR
auc = np.zeros((n_scans))
for tr_idx in range(n_scans):
auc[tr_idx] = calculate_auc(estimates[tr_idx, :, :], lambdas)
auc[tr_idx] = calculate_auc(estimates[tr_idx, :, :], lambdas, n_surrogates)

return auc
41 changes: 30 additions & 11 deletions pySPFM/workflows/pySPFM.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def pySPFM(
final_estimates = np.empty((n_scans, n_voxels))

# Iterate between temporal and spatial regularizations
_, cluster = dask_scheduler(n_jobs)
client, _ = dask_scheduler(n_jobs)
for iter_idx in range(max_iter):
if spatial_weight > 0:
data_temp_reg = final_estimates - estimates_temporal + data_masked
Expand All @@ -491,18 +491,28 @@ def pySPFM(
estimates = np.zeros((n_scans, n_voxels))
lambda_map = np.zeros(n_voxels)

# Scatter data to workers if client is not None
if client is not None:
hrf_norm_fut = client.scatter(hrf_norm)
else:
hrf_norm_fut = hrf_norm

if criterion in lars_criteria:
LGR.info("Solving inverse problem with LARS...")
n_lambdas = int(np.ceil(max_iter_factor * n_scans))
# Solve LARS for each voxel with parallelization
futures = []
for vox_idx in range(n_voxels):
fut = delayed_dask(solve_regularization_path, pure=False)(
hrf_norm, data_temp_reg[:, vox_idx], n_lambdas, criterion
hrf_norm_fut, data_temp_reg[:, vox_idx], n_lambdas, criterion
)
futures.append(fut)

lars_estimates = compute(futures)[0]
# Gather results
if client is not None:
lars_estimates = compute(futures)[0]
else:
lars_estimates = compute(futures, scheduler="single-threaded")[0]

for vox_idx in range(n_voxels):
estimates[:, vox_idx] = np.squeeze(lars_estimates[vox_idx][0])
Expand All @@ -514,7 +524,7 @@ def pySPFM(
futures = []
for vox_idx in range(n_voxels):
fut = delayed_dask(fista, pure=False)(
hrf_norm,
hrf_norm_fut,
data_temp_reg[:, vox_idx],
criterion,
max_iter_fista,
Expand All @@ -527,7 +537,11 @@ def pySPFM(
)
futures.append(fut)

fista_estimates = compute(futures)[0]
# Gather results
if client is not None:
fista_estimates = compute(futures)[0]
else:
fista_estimates = compute(futures, scheduler="single-threaded")[0]

for vox_idx in range(n_voxels):
estimates[:, vox_idx] = np.squeeze(fista_estimates[vox_idx][0])
Expand All @@ -539,17 +553,22 @@ def pySPFM(
auc = np.zeros((n_scans, n_voxels))

# Solve stability regularization
futures = []
for vox_idx in range(n_voxels):
fut = delayed_dask(stability_selection)(
hrf_norm,
futures = [
delayed_dask(stability_selection)(
hrf_norm_fut,
data_temp_reg[:, vox_idx],
n_lambdas,
n_surrogates,
)
futures.append(fut)
for vox_idx in range(n_voxels)
]

# Gather results
if client is not None:
stability_estimates = compute(futures)[0]
else:
stability_estimates = compute(futures, scheduler="single-threaded")[0]

stability_estimates = compute(futures)[0]
for vox_idx in range(n_voxels):
auc[:, vox_idx] = np.squeeze(stability_estimates[vox_idx])

Expand Down

0 comments on commit cb38b5f

Please sign in to comment.