Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions src/romitask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ class FileByFileTask(RomiTask):
and applies some function to it and saves it back to the target.
"""
query = luigi.DictParameter(default={})
# n_jobs = luigi.IntParameter(default=-1) # control number of parallel jobs
n_jobs = luigi.IntParameter(default=-1) # control number of parallel jobs
type = None

reader = None
Expand All @@ -485,31 +485,23 @@ def run(self):
"""Run the task on every file in the fileset."""
input_fileset = self.input().get()
output_fileset = self.output().get()
for fi in tqdm(input_fileset.get_files(query=self.query), unit="file"):
outfi = self.f(fi, output_fileset)

def apply_f(f, fi, output_fileset):
outfi = f(fi, output_fileset)
return fi, outfi

from joblib import Parallel
from joblib import delayed
files = Parallel(n_jobs=self.n_jobs)(
delayed(apply_f)(self.f, fi, output_fileset) for fi in
tqdm(input_fileset.get_files(query=self.query), unit="file"))

for (infi, outfi) in files:
if outfi is not None:
m = fi.get_metadata()
m = infi.get_metadata()
outm = outfi.get_metadata()
outfi.set_metadata({**m, **outm})

# ATTEMPT to parallelize:
# def apply_f(f, fi, output_fileset):
# outfi = f(fi, output_fileset)
# return fi, outfi
#
# from joblib import Parallel
# from joblib import delayed
# files = Parallel(n_jobs=self.n_jobs)(
# delayed(apply_f)(self.f, fi, output_fileset) for fi in
# tqdm(input_fileset.get_files(query=self.query), unit="file"))
#
# for (infi, outfi) in files:
# if outfi is not None:
# m = infi.get_metadata()
# outm = outfi.get_metadata()
# outfi.set_metadata({**m, **outm})



@RomiTask.event_handler(luigi.Event.FAILURE)
def mourn_failure(task, exception):
Expand Down