
Update dask4dvc arguments (#25)
* update readme

* allow targets

* cleanup tests

* add force option

* bugfix

* allow -f/--force

* bugfix

* bugfix

* use dvc '_get_steps'

* name stages

* poetry update
PythonFZ authored Apr 21, 2023
1 parent 556449c commit fb2bddc
Showing 8 changed files with 292 additions and 197 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -9,6 +9,8 @@ The `dask4dvc` package combines [Dask Distributed](https://distributed.dask.org/

The `dask4dvc` package will try to run the DVC graph in parallel.

> :warning: dask4dvc will disable a few of the checks that DVC implements. Do not make changes to your workspace during the runtime of `dask4dvc repro`.
## Usage
Dask4DVC provides a CLI similar to DVC.
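
The new positional `targets` argument and the `-f/--force` flag added in this commit can also be exercised programmatically. A minimal sketch using `typer.testing.CliRunner`, assuming the `app` object from `dask4dvc/cli/main.py` shown in the diff below; the stage names are hypothetical:

```python
from typer.testing import CliRunner

from dask4dvc.cli.main import app  # Typer app from the diff below

runner = CliRunner()
# equivalent to: dask4dvc repro preprocess train --force
# note: this really launches a Dask client and needs a DVC repo with these stages
result = runner.invoke(app, ["repro", "preprocess", "train", "--force"])
print(result.exit_code, result.output)
```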

6 changes: 4 additions & 2 deletions dask4dvc/cli/main.py
@@ -5,7 +5,7 @@

import dask.distributed
import typer

import typing
from dask4dvc import methods, utils


@@ -35,11 +35,13 @@ class Help:

@app.command()
def repro(
targets: typing.List[str] = typer.Argument(None),
address: str = typer.Option(None, help=Help.address),
leave: bool = typer.Option(True, help=Help.leave),
config: str = typer.Option(None, help=Help.config),
max_workers: int = typer.Option(None, help=Help.max_workers),
retries: int = typer.Option(10, help=Help.retries),
force: bool = typer.Option(False, "--force/", "-f/", help="use `dvc repro --force`"),
) -> None:
"""Replicate 'dvc repro' command using dask."""
utils.CONFIG.retries = retries
@@ -52,7 +54,7 @@ def repro(
if max_workers is not None:
client.cluster.adapt(minimum=1, maximum=max_workers)
log.info(client)
results = methods.parallel_submit(client)
results = methods.parallel_submit(client, targets=targets, force=force)

utils.dask.wait_for_futures(results)
if not leave:
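The hunk above adds a positional list argument for stage targets and a paired `-f/--force` boolean flag, both forwarded to `methods.parallel_submit`. A self-contained sketch of that Typer pattern with illustrative names (not the actual dask4dvc module):

```python
import typing

import typer

app = typer.Typer()


@app.command()
def repro(
    targets: typing.List[str] = typer.Argument(None),  # zero or more stage names
    force: bool = typer.Option(False, "--force", "-f", help="use `dvc repro --force`"),
) -> None:
    """Echo what would be forwarded to parallel_submit()."""
    typer.echo(f"targets={list(targets or [])} force={force}")


if __name__ == "__main__":
    app()  # e.g. `python sketch.py stage_a stage_b --force`
```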
60 changes: 44 additions & 16 deletions dask4dvc/methods.py
@@ -1,12 +1,15 @@
"""Some general 'dask4dvc' methods."""


import contextlib
import typing
import logging

import dask.distributed
import dvc.lock
import dvc.exceptions
import dvc.repo
from dvc.repo.reproduce import _get_steps
import dvc.utils.strictyaml
import dvc.stage
from dvc.stage.cache import RunCacheNotFoundError
@@ -68,14 +71,17 @@ def _load_run_cache(repo: dvc.repo.Repo, stage: dvc.stage.Stage) -> None:
)


def submit_stage(name: str, successors: list) -> str:
def submit_stage(name: str, force: bool, successors: list) -> str:
"""Submit a stage to the Dask cluster."""
repo = dvc.repo.Repo()

# dvc reproduce returns the stages that are not checked out
stages = _run_locked_cmd(repo, repo.reproduce, name, dry=True, single_item=True)
if force:
stages = [repo.stage.get_target(name)]
else:
# dvc reproduce returns the stages that are not checked out
stages = _run_locked_cmd(repo, repo.reproduce, name, dry=True, single_item=True)

if len(stages) == 0:
if len(stages) == 0 and not force:
# if the stage is already checked out, we don't need to run it
log.info(f"Stage '{name}' didn't change, skipping")

@@ -85,33 +91,55 @@ def submit_stage(name: str, successors: list) -> str:
raise ValueError("Something went wrong")

for stage in stages:
try:
# check if the stage is already in the run cache
_run_locked_cmd(repo, _load_run_cache, repo, stages[0])
except RunCacheNotFoundError:
# if not, run the stage
log.info(f"Running stage '{name}': \n > {stage.cmd}")
subprocess.check_call(stage.cmd, shell=True)
# add the stage to the run cache
_run_locked_cmd(repo, repo.commit, name, force=True)
if not force:
with contextlib.suppress(RunCacheNotFoundError):
# check if the stage is already in the run cache
_run_locked_cmd(repo, _load_run_cache, repo, stages[0])
return name
# if not, run the stage
log.info(f"Running stage '{name}': \n > {stage.cmd}")
subprocess.check_call(stage.cmd, shell=True)
# add the stage to the run cache
_run_locked_cmd(repo, repo.commit, name, force=True)

return name
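
The rewritten loop above tries the DVC run cache first and only executes the stage command on a cache miss, or unconditionally when `--force` is given. A condensed sketch of that control flow with stand-in helpers rather than the real DVC calls:

```python
import contextlib
import subprocess


class RunCacheNotFoundError(Exception):
    """Stand-in for dvc.stage.cache.RunCacheNotFoundError."""


def load_run_cache(name: str) -> None:
    """Pretend to restore outputs; always report a cache miss here."""
    raise RunCacheNotFoundError(name)


def submit_stage(name: str, cmd: str, force: bool) -> str:
    if not force:
        with contextlib.suppress(RunCacheNotFoundError):
            load_run_cache(name)  # cache hit: outputs restored, nothing to run
            return name
    subprocess.check_call(cmd, shell=True)  # cache miss or --force: run the command
    # a real implementation would now commit the stage back to the run cache
    return name
```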


def parallel_submit(
client: dask.distributed.Client,
client: dask.distributed.Client, targets: list[str], force: bool
) -> typing.Dict[str, dask.distributed.Future]:
"""Submit all stages to the Dask cluster."""
mapping = {}
repo = dvc.repo.Repo()

for node in repo.index.graph.nodes:
if len(targets) == 0:
targets = repo.index.graph.nodes
else:
targets = [repo.stage.get_target(x) for x in targets]

nodes = _get_steps(repo.index.graph, targets, downstream=False, single_item=False)

for node in nodes:
if node.cmd is None:
# if the stage doesn't have a command, e.g. a dvc tracked file
# we don't need to run it
mapping[node] = None
continue
successors = [
mapping[successor] for successor in repo.index.graph.successors(node)
]

mapping[node] = client.submit(
submit_stage, node.name, successors=successors, pure=False
submit_stage,
node.addressing,
force=force,
successors=successors,
pure=False,
key=node.addressing,
)

mapping = {
node.addressing: future for node, future in mapping.items() if future is not None
}

return mapping
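
`parallel_submit` expresses the stage ordering by handing each submitted task the futures of its graph successors: Dask will not start a task until every future passed in its arguments has finished. A minimal standalone sketch of that pattern (the stage names and the `run_stage` helper are illustrative, not dask4dvc API):

```python
import dask.distributed


def run_stage(name: str, deps: list) -> str:
    # Dask resolves any futures inside `deps` before calling this function,
    # so all dependency stages have already completed here.
    print(f"running {name} after {deps}")
    return name


if __name__ == "__main__":
    client = dask.distributed.Client(processes=False)
    prepare = client.submit(run_stage, "prepare", [], pure=False, key="prepare")
    train = client.submit(run_stage, "train", [prepare], pure=False, key="train")
    print(client.gather(train))  # 'prepare' is guaranteed to finish before 'train'
    client.close()
```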
