|
2 | 2 | from dataclasses import dataclass |
3 | 3 | from functools import cached_property |
4 | 4 | from pathlib import Path |
5 | | -from typing import TYPE_CHECKING, List, Optional, Union |
| 5 | +from typing import TYPE_CHECKING, List, Literal, Optional, Union |
6 | 6 |
|
7 | 7 | import pyarrow.fs |
8 | 8 |
|
9 | 9 | from ray.air.config import ( |
10 | | - CheckpointConfig, |
11 | 10 | FailureConfig as FailureConfigV1, |
12 | 11 | ScalingConfig as ScalingConfigV1, |
13 | 12 | ) |
@@ -132,6 +131,66 @@ def num_tpus_per_worker(self): |
132 | 131 | return self._resources_per_worker_not_none.get("TPU", 0) |
133 | 132 |
|
134 | 133 |
|
@dataclass
@PublicAPI(stability="stable")
class CheckpointConfig:
    """Configurable options for persisting and retaining checkpoints.

    Every checkpoint reported through :meth:`ray.train.report` is persisted
    to disk by default. Setting ``num_to_keep`` bounds retention; unless a
    scoring attribute is configured, the most recent checkpoints are the
    ones retained.

    Args:
        num_to_keep: Upper bound on the number of retained checkpoints.
            Once exceeded, the oldest checkpoint is deleted — or the
            lowest/highest-scoring one when ``checkpoint_score_attribute``
            is configured. ``None`` (the default) retains everything.
            When set, it must be an integer >= 1.
        checkpoint_score_attribute: Metric key used to rank checkpoints when
            deciding which to retain. Must be present in the metrics dict
            attached to each checkpoint and must map to a numerical value.
        checkpoint_score_order: ``"max"`` keeps the highest-scoring
            checkpoints, ``"min"`` keeps the lowest-scoring ones.
            Defaults to ``"max"``.
        checkpoint_frequency: [Deprecated]
        checkpoint_at_end: [Deprecated]
    """

    num_to_keep: Optional[int] = None
    checkpoint_score_attribute: Optional[str] = None
    checkpoint_score_order: Literal["max", "min"] = "max"
    checkpoint_frequency: Union[Optional[int], Literal[_DEPRECATED]] = _DEPRECATED
    checkpoint_at_end: Union[Optional[bool], Literal[_DEPRECATED]] = _DEPRECATED

    def __post_init__(self):
        # Hard-fail if either deprecated argument was explicitly supplied
        # (i.e. no longer holds the sentinel default).
        for deprecated_field in ("checkpoint_frequency", "checkpoint_at_end"):
            if getattr(self, deprecated_field) != _DEPRECATED:
                raise DeprecationWarning(
                    f"`{deprecated_field}` is deprecated since it does not "
                    "apply to user-defined training functions. "
                    "Please remove this argument from your CheckpointConfig."
                )

        # None means "keep everything"; any explicit value must be >= 1.
        if not (self.num_to_keep is None or self.num_to_keep >= 1):
            raise ValueError(
                f"Received invalid num_to_keep: {self.num_to_keep}. "
                "Must be None or an integer >= 1."
            )

        if self.checkpoint_score_order not in {"max", "min"}:
            raise ValueError(
                f"Received invalid checkpoint_score_order: {self.checkpoint_score_order}. "
                "Must be 'max' or 'min'."
            )
| 192 | + |
| 193 | + |
135 | 194 | @dataclass |
136 | 195 | class FailureConfig(FailureConfigV1): |
137 | 196 | """Configuration related to failure handling of each training run. |
@@ -162,12 +221,12 @@ class RunConfig: |
162 | 221 | Args: |
163 | 222 | name: Name of the trial or experiment. If not provided, will be deduced |
164 | 223 | from the Trainable. |
165 | | - storage_path: [Beta] Path where all results and checkpoints are persisted. |
| 224 | + storage_path: Path where all results and checkpoints are persisted. |
166 | 225 | Can be a local directory or a destination on cloud storage. |
167 | 226 | For multi-node training/tuning runs, this must be set to a |
168 | 227 | shared storage location (e.g., S3, NFS). |
169 | 228 | This defaults to the local ``~/ray_results`` directory. |
170 | | - storage_filesystem: [Beta] A custom filesystem to use for storage. |
| 229 | + storage_filesystem: A custom filesystem to use for storage. |
171 | 230 | If this is provided, `storage_path` should be a path with its |
172 | 231 | prefix stripped (e.g., `s3://bucket/path` -> `bucket/path`). |
173 | 232 | failure_config: Failure mode configuration. |
@@ -244,7 +303,6 @@ def __post_init__(self): |
244 | 303 | "https://github.com/ray-project/ray/issues/49454" |
245 | 304 | ) |
246 | 305 |
|
247 | | - # TODO: Create a separate V2 CheckpointConfig class. |
248 | 306 | if not isinstance(self.checkpoint_config, CheckpointConfig): |
249 | 307 | raise ValueError( |
250 | 308 | f"Invalid `CheckpointConfig` type: {self.checkpoint_config.__class__}. " |
|
0 commit comments