diff --git a/src/romitask/task.py b/src/romitask/task.py index 7a28993..bb98d59 100644 --- a/src/romitask/task.py +++ b/src/romitask/task.py @@ -458,7 +458,7 @@ class FileByFileTask(RomiTask): and applies some function to it and saves it back to the target. """ query = luigi.DictParameter(default={}) - # n_jobs = luigi.IntParameter(default=-1) # control number of parallel jobs + n_jobs = luigi.IntParameter(default=-1) # control number of parallel jobs type = None reader = None @@ -485,31 +485,23 @@ def run(self): """Run the task on every file in the fileset.""" input_fileset = self.input().get() output_fileset = self.output().get() - for fi in tqdm(input_fileset.get_files(query=self.query), unit="file"): - outfi = self.f(fi, output_fileset) + + def apply_f(f, fi, output_fileset): + outfi = f(fi, output_fileset) + return fi, outfi + + from joblib import Parallel + from joblib import delayed + files = Parallel(n_jobs=self.n_jobs)( + delayed(apply_f)(self.f, fi, output_fileset) for fi in + tqdm(input_fileset.get_files(query=self.query), unit="file")) + + for (infi, outfi) in files: if outfi is not None: - m = fi.get_metadata() + m = infi.get_metadata() outm = outfi.get_metadata() outfi.set_metadata({**m, **outm}) - # ATTEMPT to parallelize: - # def apply_f(f, fi, output_fileset): - # outfi = f(fi, output_fileset) - # return fi, outfi - # - # from joblib import Parallel - # from joblib import delayed - # files = Parallel(n_jobs=self.n_jobs)( - # delayed(apply_f)(self.f, fi, output_fileset) for fi in - # tqdm(input_fileset.get_files(query=self.query), unit="file")) - # - # for (infi, outfi) in files: - # if outfi is not None: - # m = infi.get_metadata() - # outm = outfi.get_metadata() - # outfi.set_metadata({**m, **outm}) - - @RomiTask.event_handler(luigi.Event.FAILURE) def mourn_failure(task, exception):