43 changes: 18 additions & 25 deletions doc/source/train/doc_code/key_concepts.py
@@ -4,22 +4,22 @@
 from pathlib import Path
 import tempfile
 
-from ray import train
-from ray.train import Checkpoint
+import ray.train
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 
 
 def train_fn(config):
     for i in range(3):
         with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
             Path(temp_checkpoint_dir).joinpath("model.pt").touch()
-            train.report(
-                {"loss": i}, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir)
+            ray.train.report(
+                {"loss": i},
+                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
             )
 
 
 trainer = DataParallelTrainer(
-    train_fn, scaling_config=train.ScalingConfig(num_workers=2)
+    train_fn, scaling_config=ray.train.ScalingConfig(num_workers=2)
 )


@@ -57,22 +57,6 @@ def train_fn(config):
 )
 # __checkpoint_config_end__
 
-# __checkpoint_config_ckpt_freq_start__
-from ray.train import RunConfig, CheckpointConfig
-
-run_config = RunConfig(
-    checkpoint_config=CheckpointConfig(
-        # Checkpoint every iteration.
-        checkpoint_frequency=1,
-        # Only keep the latest checkpoint and delete the others.
-        num_to_keep=1,
-    )
-)
-
-# from ray.train.xgboost import XGBoostTrainer
-# trainer = XGBoostTrainer(..., run_config=run_config)
-# __checkpoint_config_ckpt_freq_end__
 
 
 # __result_metrics_start__
 result = trainer.fit()
@@ -129,9 +113,18 @@ def train_fn(config):
 # __result_restore_end__
 
 
-# __result_error_start__
-if result.error:
-    assert isinstance(result.error, Exception)
-
-    print("Got exception:", result.error)
+def error_train_fn(config):
+    raise RuntimeError("Simulated training error")
+
+
+trainer = DataParallelTrainer(
+    error_train_fn, scaling_config=ray.train.ScalingConfig(num_workers=1)
+)
+
+# __result_error_start__
+try:
+    result = trainer.fit()
+except ray.train.TrainingFailedError as e:
+    if isinstance(e, ray.train.WorkerGroupError):
+        print(e.worker_failures)
 # __result_error_end__
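
Since the ``checkpoint_frequency`` snippet above is deleted rather than replaced in this file, here is a minimal sketch of the equivalent retention setup under the new API (illustrative only, not part of the patch; ``train_fn`` reuses the function defined above). Checkpoint frequency is now determined by how often the training function calls ``ray.train.report``, which is why the config flag is deprecated:

# Sketch: keep only the latest reported checkpoint with the v2 CheckpointConfig.
run_config = ray.train.RunConfig(
    checkpoint_config=ray.train.CheckpointConfig(
        # Only keep the latest checkpoint and delete the others.
        num_to_keep=1,
    )
)

trainer = DataParallelTrainer(
    train_fn,
    scaling_config=ray.train.ScalingConfig(num_workers=2),
    run_config=run_config,
)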
6 changes: 3 additions & 3 deletions doc/source/train/user-guides/results.rst
@@ -124,8 +124,8 @@ access the storage location, which is useful if the path is on cloud storage.
-Viewing Errors
---------------
+Catching Errors
+---------------
 
 If an error occurred during training,
 :attr:`Result.error <ray.train.Result>` will be set and contain the exception
 that was raised.
@@ -138,7 +138,7 @@ that was raised.
 
 Finding results on persistent storage
 -------------------------------------
-All training results, including reported metrics, checkpoints, and error files,
+All training results, including reported metrics and checkpoints,
 are stored on the configured :ref:`persistent storage <train-log-dir>`.
 
 See :ref:`the persistent storage guide <train-log-dir>` to configure this location
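
For reference, the error-handling pattern the renamed "Catching Errors" section now points to, mirroring the ``doc_code`` change earlier in this PR (a sketch; ``trainer`` is assumed to be an already-constructed Ray Train trainer):

# Sketch: trainer.fit() raises on failure in the v2 API; Result.error still
# carries the exception on a returned or restored Result, as described above.
import ray.train

try:
    result = trainer.fit()
except ray.train.TrainingFailedError as e:
    # WorkerGroupError exposes the per-worker exceptions.
    if isinstance(e, ray.train.WorkerGroupError):
        print(e.worker_failures)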
1 change: 1 addition & 0 deletions python/ray/train/__init__.py
@@ -37,6 +37,7 @@
 ) from exc
 from ray.train.v2.api.callback import UserCallback  # noqa: F811
 from ray.train.v2.api.config import (  # noqa: F811
+    CheckpointConfig,
     FailureConfig,
     RunConfig,
     ScalingConfig,
68 changes: 63 additions & 5 deletions python/ray/train/v2/api/config.py
@@ -2,12 +2,11 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, List, Literal, Optional, Union
 
 import pyarrow.fs
 
 from ray.air.config import (
-    CheckpointConfig,
     FailureConfig as FailureConfigV1,
     ScalingConfig as ScalingConfigV1,
 )
@@ -132,6 +131,66 @@ def num_tpus_per_worker(self):
         return self._resources_per_worker_not_none.get("TPU", 0)
 
 
+@dataclass
+@PublicAPI(stability="stable")
+class CheckpointConfig:
+    """Configuration for checkpointing.
+
+    Default behavior is to persist all checkpoints reported with
+    :meth:`ray.train.report` to disk. If ``num_to_keep`` is set,
+    the default retention policy is to keep the most recent checkpoints.
+
+    Args:
+        num_to_keep: The maximum number of checkpoints to keep.
+            If you report more checkpoints than this, the oldest
+            (or lowest-scoring, if ``checkpoint_score_attribute`` is set)
+            checkpoint will be deleted.
+            If this is ``None`` then all checkpoints will be kept. Must be >= 1.
+        checkpoint_score_attribute: The attribute that will be used to
+            score checkpoints to determine which checkpoints should be kept.
+            This attribute must be a key from the metrics dictionary
+            attached to the checkpoint. This attribute must have a numerical value.
+        checkpoint_score_order: Either "max" or "min".
+            If "max"/"min", then checkpoints with highest/lowest values of
+            the ``checkpoint_score_attribute`` will be kept. Defaults to "max".
+        checkpoint_frequency: [Deprecated]
+        checkpoint_at_end: [Deprecated]
+    """
+
+    num_to_keep: Optional[int] = None
+    checkpoint_score_attribute: Optional[str] = None
+    checkpoint_score_order: Literal["max", "min"] = "max"
+    checkpoint_frequency: Union[Optional[int], Literal[_DEPRECATED]] = _DEPRECATED
+    checkpoint_at_end: Union[Optional[bool], Literal[_DEPRECATED]] = _DEPRECATED
+
+    def __post_init__(self):
+        if self.checkpoint_frequency != _DEPRECATED:
+            raise DeprecationWarning(
+                "`checkpoint_frequency` is deprecated since it does not "
+                "apply to user-defined training functions. "
+                "Please remove this argument from your CheckpointConfig."
+            )
+
+        if self.checkpoint_at_end != _DEPRECATED:
+            raise DeprecationWarning(
+                "`checkpoint_at_end` is deprecated since it does not "
+                "apply to user-defined training functions. "
+                "Please remove this argument from your CheckpointConfig."
+            )
+
+        if self.num_to_keep is not None and self.num_to_keep <= 0:
+            raise ValueError(
+                f"Received invalid num_to_keep: {self.num_to_keep}. "
+                "Must be None or an integer >= 1."
+            )
+
+        if self.checkpoint_score_order not in ("max", "min"):
+            raise ValueError(
+                f"Received invalid checkpoint_score_order: {self.checkpoint_score_order}. "
+                "Must be 'max' or 'min'."
+            )
+
+
 @dataclass
 class FailureConfig(FailureConfigV1):
     """Configuration related to failure handling of each training run.
@@ -162,12 +221,12 @@ class RunConfig:
     Args:
         name: Name of the trial or experiment. If not provided, will be deduced
             from the Trainable.
-        storage_path: [Beta] Path where all results and checkpoints are persisted.
+        storage_path: Path where all results and checkpoints are persisted.
            Can be a local directory or a destination on cloud storage.
            For multi-node training/tuning runs, this must be set to a
            shared storage location (e.g., S3, NFS).
            This defaults to the local ``~/ray_results`` directory.
-        storage_filesystem: [Beta] A custom filesystem to use for storage.
+        storage_filesystem: A custom filesystem to use for storage.
            If this is provided, `storage_path` should be a path with its
            prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`).
        failure_config: Failure mode configuration.
@@ -244,7 +303,6 @@ def __post_init__(self):
                 "https://github.com/ray-project/ray/issues/49454"
             )
 
-        # TODO: Create a separate V2 CheckpointConfig class.
         if not isinstance(self.checkpoint_config, CheckpointConfig):
             raise ValueError(
                 f"Invalid `CheckpointConfig` type: {self.checkpoint_config.__class__}. "
7 changes: 1 addition & 6 deletions python/ray/train/v2/jax/jax_trainer.py
@@ -2,7 +2,7 @@
 from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
 
 from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
-from ray.train import Checkpoint, DataConfig
+from ray.train import DataConfig
 from ray.train.trainer import GenDataset
 from ray.train.v2.api.config import RunConfig, ScalingConfig
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
@@ -116,9 +116,6 @@ def main(argv: Sequence[str]):
             by calling ``ray.train.get_dataset_shard(name)``.
             Sharding and additional configuration can be done by
             passing in a ``dataset_config``.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
     """
 
     def __init__(
@@ -131,7 +128,6 @@ def __init__(
         dataset_config: Optional[Dict[str, DataConfig]] = None,
         run_config: Optional[RunConfig] = None,
         datasets: Optional[Dict[str, GenDataset]] = None,
-        resume_from_checkpoint: Optional[Checkpoint] = None,
     ):
         if not jax_config:
             jax_config = JaxConfig(
@@ -145,7 +141,6 @@ def __init__(
             dataset_config=dataset_config,
             run_config=run_config,
             datasets=datasets,
-            resume_from_checkpoint=resume_from_checkpoint,
         )
 
     @classmethod
8 changes: 2 additions & 6 deletions python/ray/train/v2/lightgbm/lightgbm_trainer.py
@@ -105,12 +105,8 @@ def train_fn_per_worker(config: dict):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """
 
     def __init__(
6 changes: 2 additions & 4 deletions python/ray/train/v2/tensorflow/tensorflow_trainer.py
@@ -156,10 +156,8 @@ def train_loop_per_worker(config):
             by calling ``ray.train.get_dataset_shard(name)``.
             Sharding and additional configuration can be done by
             passing in a ``dataset_config``.
-        resume_from_checkpoint: A checkpoint to resume training from.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """
 
     def __init__(
3 changes: 1 addition & 2 deletions
@@ -6,8 +6,7 @@
 
 import ray
 import ray.cloudpickle as ray_pickle
-from ray.air.config import CheckpointConfig
-from ray.train import Checkpoint, RunConfig, ScalingConfig
+from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
 from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 from ray.train.v2.api.exceptions import WorkerGroupError
8 changes: 2 additions & 6 deletions python/ray/train/v2/torch/torch_trainer.py
@@ -163,12 +163,8 @@ def train_fn_per_worker(config):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """
 
     def __init__(
8 changes: 2 additions & 6 deletions python/ray/train/v2/xgboost/xgboost_trainer.py
@@ -103,12 +103,8 @@ def train_fn_per_worker(config: dict):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """
 
     def __init__(
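
With ``resume_from_checkpoint`` removed across the v2 trainers, the checkpoint-loading step the deleted docstrings referred to lives inside the training function. A minimal sketch of that worker-side pattern, assuming the ``model.pt`` file name from the doc_code example above (whether this covers every use of the removed argument is not established by this diff):

# Sketch: pick up previously reported state via ray.train.get_checkpoint().
from pathlib import Path

import ray.train


def train_fn(config):
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load whatever an earlier ray.train.report(...) call saved.
            state_path = Path(checkpoint_dir) / "model.pt"
            print("Resuming from", state_path)
    # ...continue training and report new checkpoints with ray.train.report...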