From 67ffa90e2730b71936f0bac329eba67916ccc61a Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Wed, 22 Oct 2025 15:13:51 -0700
Subject: [PATCH 1/7] new ckpt config with deprecations

Signed-off-by: Justin Yu
---
 python/ray/train/__init__.py      |  1 +
 python/ray/train/v2/api/config.py | 63 ++++++++++++++++++++++++++++++-
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/python/ray/train/__init__.py b/python/ray/train/__init__.py
index 9051eded3ef8..2d3dd6bd328a 100644
--- a/python/ray/train/__init__.py
+++ b/python/ray/train/__init__.py
@@ -37,6 +37,7 @@
     ) from exc
 from ray.train.v2.api.callback import UserCallback  # noqa: F811
 from ray.train.v2.api.config import (  # noqa: F811
+    CheckpointConfig,
     FailureConfig,
     RunConfig,
     ScalingConfig,
diff --git a/python/ray/train/v2/api/config.py b/python/ray/train/v2/api/config.py
index 2b8a82f41701..5842449b184b 100644
--- a/python/ray/train/v2/api/config.py
+++ b/python/ray/train/v2/api/config.py
@@ -2,12 +2,11 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, List, Literal, Optional, Union
 
 import pyarrow.fs
 
 from ray.air.config import (
-    CheckpointConfig,
     FailureConfig as FailureConfigV1,
     ScalingConfig as ScalingConfigV1,
 )
@@ -132,6 +131,66 @@ def num_tpus_per_worker(self):
         return self._resources_per_worker_not_none.get("TPU", 0)
 
 
+@dataclass
+@PublicAPI(stability="stable")
+class CheckpointConfig:
+    """Configuration for checkpointing.
+
+    Default behavior is to persist all checkpoints reported with
+    :meth:`ray.train.report` to disk. If ``num_to_keep`` is set,
+    the default retention policy is to keep the most recent checkpoints.
+
+    Args:
+        num_to_keep: The maximum number of checkpoints to keep.
+            If you report more checkpoints than this, the oldest
+            (or lowest-scoring, if ``checkpoint_score_attribute`` is set)
+            checkpoint will be deleted.
+            If this is ``None`` then all checkpoints will be kept. Must be >= 1.
+        checkpoint_score_attribute: The attribute that will be used to
+            score checkpoints to determine which checkpoints should be kept.
+            This attribute must be a key from the metrics dictionary
+            attached to the checkpoint. This attribute must have a numerical value.
+        checkpoint_score_order: Either "max" or "min".
+            If "max"/"min", then checkpoints with highest/lowest values of
+            the ``checkpoint_score_attribute`` will be kept. Defaults to "max".
+        checkpoint_frequency: [Deprecated]
+        checkpoint_at_end: [Deprecated]
+    """
+
+    num_to_keep: Optional[int] = None
+    checkpoint_score_attribute: Optional[str] = None
+    checkpoint_score_order: Literal["max", "min"] = "max"
+    checkpoint_frequency: Union[Optional[int], Literal[_DEPRECATED]] = _DEPRECATED
+    checkpoint_at_end: Union[Optional[bool], Literal[_DEPRECATED]] = _DEPRECATED
+
+    def __post_init__(self):
+        if self.checkpoint_frequency != _DEPRECATED:
+            raise DeprecationWarning(
+                "`checkpoint_frequency` is deprecated since it does not "
+                "apply to user-defined training functions. "
+                "Please remove this argument from your CheckpointConfig."
+            )
+
+        if self.checkpoint_at_end != _DEPRECATED:
+            raise DeprecationWarning(
+                "`checkpoint_at_end` is deprecated since it does not "
+                "apply to user-defined training functions. "
+                "Please remove this argument from your CheckpointConfig."
+            )
+
+        if self.num_to_keep is not None and self.num_to_keep <= 0:
+            raise ValueError(
+                f"Received invalid num_to_keep: {self.num_to_keep}. "
" + "Must be None or an integer >= 1." + ) + + if self.checkpoint_score_order not in ("max", "min"): + raise ValueError( + f"Received invalid checkpoint_score_order: {self.checkpoint_score_order}. " + "Must be 'max' or 'min'." + ) + + @dataclass class FailureConfig(FailureConfigV1): """Configuration related to failure handling of each training run. From bd468baf5983b3196e896cfcabb6cae63d12a604 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 22 Oct 2025 15:14:29 -0700 Subject: [PATCH 2/7] remove beta for storage path Signed-off-by: Justin Yu --- python/ray/train/v2/api/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/train/v2/api/config.py b/python/ray/train/v2/api/config.py index 5842449b184b..d6c4f7823021 100644 --- a/python/ray/train/v2/api/config.py +++ b/python/ray/train/v2/api/config.py @@ -221,12 +221,12 @@ class RunConfig: Args: name: Name of the trial or experiment. If not provided, will be deduced from the Trainable. - storage_path: [Beta] Path where all results and checkpoints are persisted. + storage_path: Path where all results and checkpoints are persisted. Can be a local directory or a destination on cloud storage. For multi-node training/tuning runs, this must be set to a shared storage location (e.g., S3, NFS). This defaults to the local ``~/ray_results`` directory. - storage_filesystem: [Beta] A custom filesystem to use for storage. + storage_filesystem: A custom filesystem to use for storage. If this is provided, `storage_path` should be a path with its prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`). failure_config: Failure mode configuration. From 4e698580e4da015c6edda55ad9ccb3215e661262 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 22 Oct 2025 15:14:53 -0700 Subject: [PATCH 3/7] remove todo Signed-off-by: Justin Yu --- python/ray/train/v2/api/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/train/v2/api/config.py b/python/ray/train/v2/api/config.py index d6c4f7823021..5c0c6672ecf5 100644 --- a/python/ray/train/v2/api/config.py +++ b/python/ray/train/v2/api/config.py @@ -303,7 +303,6 @@ def __post_init__(self): "https://github.com/ray-project/ray/issues/49454" ) - # TODO: Create a separate V2 CheckpointConfig class. if not isinstance(self.checkpoint_config, CheckpointConfig): raise ValueError( f"Invalid `CheckpointConfig` type: {self.checkpoint_config.__class__}. 
" From 0c1b6d9a086759d3a5a13e59240d77ba3f3694a8 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 22 Oct 2025 15:21:55 -0700 Subject: [PATCH 4/7] deprecated in docstrings Signed-off-by: Justin Yu --- python/ray/train/v2/jax/jax_trainer.py | 7 +------ python/ray/train/v2/lightgbm/lightgbm_trainer.py | 8 ++------ python/ray/train/v2/tensorflow/tensorflow_trainer.py | 6 ++---- python/ray/train/v2/torch/torch_trainer.py | 8 ++------ python/ray/train/v2/xgboost/xgboost_trainer.py | 8 ++------ 5 files changed, 9 insertions(+), 28 deletions(-) diff --git a/python/ray/train/v2/jax/jax_trainer.py b/python/ray/train/v2/jax/jax_trainer.py index 1ea2eb8072d9..04d8c7f076e3 100644 --- a/python/ray/train/v2/jax/jax_trainer.py +++ b/python/ray/train/v2/jax/jax_trainer.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated -from ray.train import Checkpoint, DataConfig +from ray.train import DataConfig from ray.train.trainer import GenDataset from ray.train.v2.api.config import RunConfig, ScalingConfig from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer @@ -116,9 +116,6 @@ def main(argv: Sequence[str]): by calling ``ray.train.get_dataset_shard(name)``. Sharding and additional configuration can be done by passing in a ``dataset_config``. - resume_from_checkpoint: A checkpoint to resume training from. - This checkpoint can be accessed from within ``train_loop_per_worker`` - by calling ``ray.train.get_checkpoint()``. """ def __init__( @@ -131,7 +128,6 @@ def __init__( dataset_config: Optional[Dict[str, DataConfig]] = None, run_config: Optional[RunConfig] = None, datasets: Optional[Dict[str, GenDataset]] = None, - resume_from_checkpoint: Optional[Checkpoint] = None, ): if not jax_config: jax_config = JaxConfig( @@ -145,7 +141,6 @@ def __init__( dataset_config=dataset_config, run_config=run_config, datasets=datasets, - resume_from_checkpoint=resume_from_checkpoint, ) @classmethod diff --git a/python/ray/train/v2/lightgbm/lightgbm_trainer.py b/python/ray/train/v2/lightgbm/lightgbm_trainer.py index b7626fe3007d..dd3c30acf1ba 100644 --- a/python/ray/train/v2/lightgbm/lightgbm_trainer.py +++ b/python/ray/train/v2/lightgbm/lightgbm_trainer.py @@ -105,12 +105,8 @@ def train_fn_per_worker(config: dict): dataset_config: The configuration for ingesting the input ``datasets``. By default, all the Ray Dataset are split equally across workers. See :class:`~ray.train.DataConfig` for more details. - resume_from_checkpoint: A checkpoint to resume training from. - This checkpoint can be accessed from within ``train_loop_per_worker`` - by calling ``ray.train.get_checkpoint()``. - metadata: Dict that should be made available via - `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()` - for checkpoints saved from this Trainer. Must be JSON-serializable. + resume_from_checkpoint: [Deprecated] + metadata: [Deprecated] """ def __init__( diff --git a/python/ray/train/v2/tensorflow/tensorflow_trainer.py b/python/ray/train/v2/tensorflow/tensorflow_trainer.py index 43a79b458f2a..44e7628bf9ea 100644 --- a/python/ray/train/v2/tensorflow/tensorflow_trainer.py +++ b/python/ray/train/v2/tensorflow/tensorflow_trainer.py @@ -156,10 +156,8 @@ def train_loop_per_worker(config): by calling ``ray.train.get_dataset_shard(name)``. Sharding and additional configuration can be done by passing in a ``dataset_config``. - resume_from_checkpoint: A checkpoint to resume training from. 
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """
 
     def __init__(
diff --git a/python/ray/train/v2/torch/torch_trainer.py b/python/ray/train/v2/torch/torch_trainer.py
index ded4ddb17450..454cfcec4601 100644
--- a/python/ray/train/v2/torch/torch_trainer.py
+++ b/python/ray/train/v2/torch/torch_trainer.py
@@ -163,12 +163,8 @@ def train_fn_per_worker(config):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """
 
     def __init__(
diff --git a/python/ray/train/v2/xgboost/xgboost_trainer.py b/python/ray/train/v2/xgboost/xgboost_trainer.py
index 065ca078df2f..61de1212a8c1 100644
--- a/python/ray/train/v2/xgboost/xgboost_trainer.py
+++ b/python/ray/train/v2/xgboost/xgboost_trainer.py
@@ -103,12 +103,8 @@ def train_fn_per_worker(config: dict):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """
 
     def __init__(

From 3e26689e735f4cb1da31064620f6157a2945990f Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Thu, 23 Oct 2025 14:13:55 -0700
Subject: [PATCH 5/7] fix checkpoint config import

Signed-off-by: Justin Yu
---
 .../ray/train/v2/tests/test_async_checkpointing_validation.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/ray/train/v2/tests/test_async_checkpointing_validation.py b/python/ray/train/v2/tests/test_async_checkpointing_validation.py
index de3ec6b599b0..fb8f8a8a3a01 100644
--- a/python/ray/train/v2/tests/test_async_checkpointing_validation.py
+++ b/python/ray/train/v2/tests/test_async_checkpointing_validation.py
@@ -6,8 +6,7 @@
 
 import ray
 import ray.cloudpickle as ray_pickle
-from ray.air.config import CheckpointConfig
-from ray.train import Checkpoint, RunConfig, ScalingConfig
+from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
 from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 from ray.train.v2.api.exceptions import WorkerGroupError

From f91b524ae8f87ba99345b2f699ded79dca7b8060 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Mon, 27 Oct 2025 13:43:52 -0700
Subject: [PATCH 6/7] remove unused deprecated doc code

Signed-off-by: Justin Yu
---
 doc/source/train/doc_code/key_concepts.py | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/doc/source/train/doc_code/key_concepts.py b/doc/source/train/doc_code/key_concepts.py
index f35a79cf8f84..3fee6b5c4bec 100644
--- a/doc/source/train/doc_code/key_concepts.py
+++ b/doc/source/train/doc_code/key_concepts.py
@@ -57,22 +57,6 @@ def train_fn(config):
 )
 # __checkpoint_config_end__
 
-# __checkpoint_config_ckpt_freq_start__
-from ray.train import RunConfig, CheckpointConfig
-
-run_config = RunConfig(
-    checkpoint_config=CheckpointConfig(
-        # Checkpoint every iteration.
-        checkpoint_frequency=1,
-        # Only keep the latest checkpoint and delete the others.
-        num_to_keep=1,
-    )
-)
-
-# from ray.train.xgboost import XGBoostTrainer
-# trainer = XGBoostTrainer(..., run_config=run_config)
-# __checkpoint_config_ckpt_freq_end__
-
 # __result_metrics_start__
 result = trainer.fit()
 

From d112149ce2390ad6cbe8ff479477778df6ec0529 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Mon, 27 Oct 2025 14:34:10 -0700
Subject: [PATCH 7/7] update result error doc code

Signed-off-by: Justin Yu
---
 doc/source/train/doc_code/key_concepts.py | 27 +++++++++++++++--------
 doc/source/train/user-guides/results.rst  |  6 ++---
 2 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/doc/source/train/doc_code/key_concepts.py b/doc/source/train/doc_code/key_concepts.py
index 3fee6b5c4bec..e1ce8dad273f 100644
--- a/doc/source/train/doc_code/key_concepts.py
+++ b/doc/source/train/doc_code/key_concepts.py
@@ -4,8 +4,7 @@
 from pathlib import Path
 import tempfile
 
-from ray import train
-from ray.train import Checkpoint
+import ray.train
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 
 
@@ -13,13 +12,14 @@ def train_fn(config):
     for i in range(3):
         with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
             Path(temp_checkpoint_dir).joinpath("model.pt").touch()
-            train.report(
-                {"loss": i}, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir)
+            ray.train.report(
+                {"loss": i},
+                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
             )
 
 
 trainer = DataParallelTrainer(
-    train_fn, scaling_config=train.ScalingConfig(num_workers=2)
+    train_fn, scaling_config=ray.train.ScalingConfig(num_workers=2)
 )
 
 
@@ -113,9 +113,18 @@ def train_fn(config):
 # __result_restore_end__
 
 
+def error_train_fn(config):
+    raise RuntimeError("Simulated training error")
+
+
+trainer = DataParallelTrainer(
+    error_train_fn, scaling_config=ray.train.ScalingConfig(num_workers=1)
+)
+
 # __result_error_start__
-if result.error:
-    assert isinstance(result.error, Exception)
-
-    print("Got exception:", result.error)
+try:
+    result = trainer.fit()
+except ray.train.TrainingFailedError as e:
+    if isinstance(e, ray.train.WorkerGroupError):
+        print(e.worker_failures)
 # __result_error_end__
diff --git a/doc/source/train/user-guides/results.rst b/doc/source/train/user-guides/results.rst
index 703b45166441..63d6985645dd 100644
--- a/doc/source/train/user-guides/results.rst
+++ b/doc/source/train/user-guides/results.rst
@@ -124,8 +124,8 @@ access the storage location, which is useful if the path is on cloud storage.
 
 
-Viewing Errors
---------------
+Catching Errors
+---------------
 
 If an error occurred during training, :attr:`Result.error `
 will be set and contain the exception that was raised.
 
@@ -138,7 +138,7 @@ that was raised.
 
 Finding results on persistent storage
 -------------------------------------
 
-All training results, including reported metrics, checkpoints, and error files,
+All training results, including reported metrics and checkpoints,
 are stored on the configured :ref:`persistent storage `.
 See :ref:`the persistent storage guide ` to configure this location
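
Usage note (a sketch for reviewers, not part of the patches above): with `checkpoint_frequency` and `checkpoint_at_end` deprecated, checkpoint cadence is decided inside the training function via `ray.train.report`, and `CheckpointConfig` only controls retention. A minimal sketch of that flow, assuming the v2 API names used in this series (`ray.train.CheckpointConfig`, `ray.train.RunConfig`, and the `DataParallelTrainer` shown in doc_code/key_concepts.py):

# Sketch only: retention is configured on CheckpointConfig, while the checkpoint
# cadence lives in the training function (mirrors doc_code/key_concepts.py above).
import tempfile
from pathlib import Path

import ray.train
from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer


def train_fn(config):
    for i in range(3):
        with tempfile.TemporaryDirectory() as tmp:
            Path(tmp).joinpath("model.pt").touch()
            # Report a checkpoint whenever you want one persisted.
            ray.train.report(
                {"loss": i},
                checkpoint=ray.train.Checkpoint.from_directory(tmp),
            )


trainer = DataParallelTrainer(
    train_fn,
    scaling_config=ray.train.ScalingConfig(num_workers=2),
    run_config=ray.train.RunConfig(
        checkpoint_config=ray.train.CheckpointConfig(
            # Keep only the best checkpoint, scored by the reported "loss".
            num_to_keep=1,
            checkpoint_score_attribute="loss",
            checkpoint_score_order="min",
        )
    ),
)
result = trainer.fit()

The per-iteration `ray.train.report` call here plays the role of the `checkpoint_frequency=1` snippet removed in PATCH 6.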