Commit 092550b

[train] Clean up checkpoint config and trainer param deprecations (#58022)
Deprecate `CheckpointConfig(checkpoint_at_end, checkpoint_frequency)` and mark the `resume_from_checkpoint` and `metadata` Trainer constructor arguments as deprecated in the docstrings.

Also update the "Inspecting Results" user guide doc code to show how to catch and inspect errors raised by `trainer.fit()`. The previous recommendation to check `result.error` is unusable because we always raise the error, which prevents the user from accessing the result object.

---------

Signed-off-by: Justin Yu <[email protected]>
1 parent 4b0a268 commit 092550b
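
For reference, the error-handling pattern this commit recommends looks roughly like the following sketch, adapted from the updated doc code below; the `error_train_fn` name and the single-worker scaling config are illustrative only.

import ray.train
from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer


def error_train_fn(config):
    # Illustrative failure inside the training function.
    raise RuntimeError("Simulated training error")


trainer = DataParallelTrainer(
    error_train_fn, scaling_config=ray.train.ScalingConfig(num_workers=1)
)

try:
    result = trainer.fit()
except ray.train.TrainingFailedError as e:
    # WorkerGroupError carries the per-worker exceptions.
    if isinstance(e, ray.train.WorkerGroupError):
        print(e.worker_failures)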

File tree: 10 files changed (+95, −63 lines)

doc/source/train/doc_code/key_concepts.py

Lines changed: 18 additions & 25 deletions
@@ -4,22 +4,22 @@
 from pathlib import Path
 import tempfile

-from ray import train
-from ray.train import Checkpoint
+import ray.train
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer


 def train_fn(config):
     for i in range(3):
         with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
             Path(temp_checkpoint_dir).joinpath("model.pt").touch()
-            train.report(
-                {"loss": i}, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir)
+            ray.train.report(
+                {"loss": i},
+                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
             )


 trainer = DataParallelTrainer(
-    train_fn, scaling_config=train.ScalingConfig(num_workers=2)
+    train_fn, scaling_config=ray.train.ScalingConfig(num_workers=2)
 )


@@ -57,22 +57,6 @@ def train_fn(config):
 )
 # __checkpoint_config_end__

-# __checkpoint_config_ckpt_freq_start__
-from ray.train import RunConfig, CheckpointConfig
-
-run_config = RunConfig(
-    checkpoint_config=CheckpointConfig(
-        # Checkpoint every iteration.
-        checkpoint_frequency=1,
-        # Only keep the latest checkpoint and delete the others.
-        num_to_keep=1,
-    )
-)
-
-# from ray.train.xgboost import XGBoostTrainer
-# trainer = XGBoostTrainer(..., run_config=run_config)
-# __checkpoint_config_ckpt_freq_end__
-

 # __result_metrics_start__
 result = trainer.fit()

@@ -129,9 +113,18 @@ def train_fn(config):
 # __result_restore_end__


-# __result_error_start__
-if result.error:
-    assert isinstance(result.error, Exception)
+def error_train_fn(config):
+    raise RuntimeError("Simulated training error")
+
+
+trainer = DataParallelTrainer(
+    error_train_fn, scaling_config=ray.train.ScalingConfig(num_workers=1)
+)

-    print("Got exception:", result.error)
+# __result_error_start__
+try:
+    result = trainer.fit()
+except ray.train.TrainingFailedError as e:
+    if isinstance(e, ray.train.WorkerGroupError):
+        print(e.worker_failures)
 # __result_error_end__

doc/source/train/user-guides/results.rst

Lines changed: 3 additions & 3 deletions
@@ -124,8 +124,8 @@ access the storage location, which is useful if the path is on cloud storage.



-Viewing Errors
---------------
+Catching Errors
+---------------
 If an error occurred during training,
 :attr:`Result.error <ray.train.Result>` will be set and contain the exception
 that was raised.

@@ -138,7 +138,7 @@ that was raised.

 Finding results on persistent storage
 -------------------------------------
-All training results, including reported metrics, checkpoints, and error files,
+All training results including reported metrics and checkpoints
 are stored on the configured :ref:`persistent storage <train-log-dir>`.

 See :ref:`the persistent storage guide <train-log-dir>` to configure this location
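
As a rough companion to this guide change (not part of the diff), a successful run can still be inspected through the existing `ray.train.Result` attributes; the sketch below assumes a `trainer` built with a `RunConfig` that points at persistent storage.

# Assuming `trainer` was constructed with RunConfig(storage_path=...).
result = trainer.fit()

print(result.metrics)     # last reported metrics
print(result.checkpoint)  # latest persisted checkpoint, if any
print(result.path)        # where this run lives on the configured persistent storage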

python/ray/train/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 ) from exc
 from ray.train.v2.api.callback import UserCallback  # noqa: F811
 from ray.train.v2.api.config import (  # noqa: F811
+    CheckpointConfig,
     FailureConfig,
     RunConfig,
     ScalingConfig,

python/ray/train/v2/api/config.py

Lines changed: 63 additions & 5 deletions
@@ -2,12 +2,11 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, List, Literal, Optional, Union

 import pyarrow.fs

 from ray.air.config import (
-    CheckpointConfig,
     FailureConfig as FailureConfigV1,
     ScalingConfig as ScalingConfigV1,
 )

@@ -132,6 +131,66 @@ def num_tpus_per_worker(self):
         return self._resources_per_worker_not_none.get("TPU", 0)


+@dataclass
+@PublicAPI(stability="stable")
+class CheckpointConfig:
+    """Configuration for checkpointing.
+
+    Default behavior is to persist all checkpoints reported with
+    :meth:`ray.train.report` to disk. If ``num_to_keep`` is set,
+    the default retention policy is to keep the most recent checkpoints.
+
+    Args:
+        num_to_keep: The maximum number of checkpoints to keep.
+            If you report more checkpoints than this, the oldest
+            (or lowest-scoring, if ``checkpoint_score_attribute`` is set)
+            checkpoint will be deleted.
+            If this is ``None`` then all checkpoints will be kept. Must be >= 1.
+        checkpoint_score_attribute: The attribute that will be used to
+            score checkpoints to determine which checkpoints should be kept.
+            This attribute must be a key from the metrics dictionary
+            attached to the checkpoint. This attribute must have a numerical value.
+        checkpoint_score_order: Either "max" or "min".
+            If "max"/"min", then checkpoints with highest/lowest values of
+            the ``checkpoint_score_attribute`` will be kept. Defaults to "max".
+        checkpoint_frequency: [Deprecated]
+        checkpoint_at_end: [Deprecated]
+    """
+
+    num_to_keep: Optional[int] = None
+    checkpoint_score_attribute: Optional[str] = None
+    checkpoint_score_order: Literal["max", "min"] = "max"
+    checkpoint_frequency: Union[Optional[int], Literal[_DEPRECATED]] = _DEPRECATED
+    checkpoint_at_end: Union[Optional[bool], Literal[_DEPRECATED]] = _DEPRECATED
+
+    def __post_init__(self):
+        if self.checkpoint_frequency != _DEPRECATED:
+            raise DeprecationWarning(
+                "`checkpoint_frequency` is deprecated since it does not "
+                "apply to user-defined training functions. "
+                "Please remove this argument from your CheckpointConfig."
+            )
+
+        if self.checkpoint_at_end != _DEPRECATED:
+            raise DeprecationWarning(
+                "`checkpoint_at_end` is deprecated since it does not "
+                "apply to user-defined training functions. "
+                "Please remove this argument from your CheckpointConfig."
+            )
+
+        if self.num_to_keep is not None and self.num_to_keep <= 0:
+            raise ValueError(
+                f"Received invalid num_to_keep: {self.num_to_keep}. "
+                "Must be None or an integer >= 1."
+            )
+
+        if self.checkpoint_score_order not in ("max", "min"):
+            raise ValueError(
+                f"Received invalid checkpoint_score_order: {self.checkpoint_score_order}. "
+                "Must be 'max' or 'min'."
+            )
+
+
 @dataclass
 class FailureConfig(FailureConfigV1):
     """Configuration related to failure handling of each training run.

@@ -162,12 +221,12 @@ class RunConfig:
     Args:
         name: Name of the trial or experiment. If not provided, will be deduced
             from the Trainable.
-        storage_path: [Beta] Path where all results and checkpoints are persisted.
+        storage_path: Path where all results and checkpoints are persisted.
             Can be a local directory or a destination on cloud storage.
             For multi-node training/tuning runs, this must be set to a
             shared storage location (e.g., S3, NFS).
             This defaults to the local ``~/ray_results`` directory.
-        storage_filesystem: [Beta] A custom filesystem to use for storage.
+        storage_filesystem: A custom filesystem to use for storage.
             If this is provided, `storage_path` should be a path with its
             prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`).
         failure_config: Failure mode configuration.

@@ -244,7 +303,6 @@ def __post_init__(self):
                 "https://github.com/ray-project/ray/issues/49454"
             )

-        # TODO: Create a separate V2 CheckpointConfig class.
         if not isinstance(self.checkpoint_config, CheckpointConfig):
             raise ValueError(
                 f"Invalid `CheckpointConfig` type: {self.checkpoint_config.__class__}. "

python/ray/train/v2/jax/jax_trainer.py

Lines changed: 1 addition & 6 deletions
@@ -2,7 +2,7 @@
 from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

 from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
-from ray.train import Checkpoint, DataConfig
+from ray.train import DataConfig
 from ray.train.trainer import GenDataset
 from ray.train.v2.api.config import RunConfig, ScalingConfig
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer

@@ -116,9 +116,6 @@ def main(argv: Sequence[str]):
             by calling ``ray.train.get_dataset_shard(name)``.
             Sharding and additional configuration can be done by
             passing in a ``dataset_config``.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
     """

     def __init__(

@@ -131,7 +128,6 @@ def __init__(
         dataset_config: Optional[Dict[str, DataConfig]] = None,
         run_config: Optional[RunConfig] = None,
         datasets: Optional[Dict[str, GenDataset]] = None,
-        resume_from_checkpoint: Optional[Checkpoint] = None,
     ):
         if not jax_config:
             jax_config = JaxConfig(

@@ -145,7 +141,6 @@
             dataset_config=dataset_config,
             run_config=run_config,
             datasets=datasets,
-            resume_from_checkpoint=resume_from_checkpoint,
         )

     @classmethod

python/ray/train/v2/lightgbm/lightgbm_trainer.py

Lines changed: 2 additions & 6 deletions
@@ -105,12 +105,8 @@ def train_fn_per_worker(config: dict):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """

     def __init__(

python/ray/train/v2/tensorflow/tensorflow_trainer.py

Lines changed: 2 additions & 4 deletions
@@ -156,10 +156,8 @@ def train_loop_per_worker(config):
             by calling ``ray.train.get_dataset_shard(name)``.
             Sharding and additional configuration can be done by
             passing in a ``dataset_config``.
-        resume_from_checkpoint: A checkpoint to resume training from.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """

     def __init__(

python/ray/train/v2/tests/test_async_checkpointing_validation.py

Lines changed: 1 addition & 2 deletions
@@ -6,8 +6,7 @@

 import ray
 import ray.cloudpickle as ray_pickle
-from ray.air.config import CheckpointConfig
-from ray.train import Checkpoint, RunConfig, ScalingConfig
+from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
 from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 from ray.train.v2.api.exceptions import WorkerGroupError

python/ray/train/v2/torch/torch_trainer.py

Lines changed: 2 additions & 6 deletions
@@ -163,12 +163,8 @@ def train_fn_per_worker(config):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """

     def __init__(

python/ray/train/v2/xgboost/xgboost_trainer.py

Lines changed: 2 additions & 6 deletions
@@ -103,12 +103,8 @@ def train_fn_per_worker(config: dict):
         dataset_config: The configuration for ingesting the input ``datasets``.
             By default, all the Ray Dataset are split equally across workers.
             See :class:`~ray.train.DataConfig` for more details.
-        resume_from_checkpoint: A checkpoint to resume training from.
-            This checkpoint can be accessed from within ``train_loop_per_worker``
-            by calling ``ray.train.get_checkpoint()``.
-        metadata: Dict that should be made available via
-            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
-            for checkpoints saved from this Trainer. Must be JSON-serializable.
+        resume_from_checkpoint: [Deprecated]
+        metadata: [Deprecated]
     """

     def __init__(
