
Commit 80bef0c

RecML authors authored and committed
Reverts changelist 793734230
PiperOrigin-RevId: 806692343
1 parent 847628b · commit 80bef0c

25 files changed: +4,629 −166 lines

recml/core/data/tf_dataset_factory.py

Lines changed: 21 additions & 11 deletions
@@ -24,6 +24,7 @@
 import re
 from typing import Any, Protocol

+from absl import flags
 from absl import logging
 import jax
 from recml.core.utils import types
@@ -162,12 +163,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       Defaults to False.
     seed: An optional seed to use for deterministic shuffling / preprocessing.
       Defaults to None.
-    tf_data_service_address: An optional URI of a tf.data service to offload
-      preprocessing onto during training. The URI should be in the format
-      "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
-      data service will be applied.
+    enable_tf_data_service: Whether to apply tf.data service for this dataset.
+      If True, flag `tf_data_service_address` must be set.
     tf_data_service_policy: Sharding policy to use for tf.data service when it
       is enabled.
+    tf_data_service_job_name: Job name to use for tf.data service. If None, the
+      default job name will be used.
     feature_spec: A mapping of feature keys to `FixedLenFeature`,
       `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
       used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +209,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       tensorflow.
     debug: An optional boolean indicating whether to debug input boundedness. If
       `True`, the dataset will consist of a single batch that's cached and
-      infinitely repeated
+      infinitely repeated.
   """

   cache_reading: bool = False
@@ -231,7 +232,8 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   readahead: str | None = None
   group_uris_by_dir: bool = False
   seed: int | None = None
-  tf_data_service_address: str | None = None
+  enable_tf_data_service: bool = False
+  tf_data_service_job_name: str | None = None
   tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
       tf.data.experimental.service.ShardingPolicy.OFF
   )
@@ -249,7 +251,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   debug: bool = False

   def __post_init__(self):
-    if self.tf_data_service_address is not None:
+    if self.enable_tf_data_service:
+      if flags.FLAGS.tf_data_service_address is None:
+        raise ValueError(
+            "Flag `tf_data_service_address` must be set when"
+            " `enable_tf_data_service` is True."
+        )
       if self.seed is not None:
         raise ValueError("`seed` must be None for data service.")
       if self.sharding:
@@ -533,23 +540,26 @@ def _maybe_apply_tf_data_service(
       self, dataset: tf.data.Dataset
   ) -> tf.data.Dataset:
     """Applies the tf.data service to the dataset."""
-    if self.tf_data_service_address is None:
+    if not self.enable_tf_data_service:
       return dataset

+    tf_data_service_address = flags.FLAGS.tf_data_service_address
+
     per_proc_batch_size = self.sharding_info.per_process_batch_size(
         self.global_batch_size
     )
     logging.info(
         "Applying tf.data service with address %s and per replica batch"
         " size %s",
-        self.tf_data_service_address,
+        tf_data_service_address,
         per_proc_batch_size,
     )
     return dataset.apply(
         tf.data.experimental.service.distribute(
             processing_mode=self.tf_data_service_policy,
-            service=self.tf_data_service_address,
-            job_name=f"bs_{per_proc_batch_size}",
+            service=tf_data_service_address,
+            job_name=self.tf_data_service_job_name
+            or "tf_data_service_shared_job_name",
         )
     )
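
For context, this hunk moves the tf.data service address from a dataclass field to an absl flag lookup. A minimal wiring sketch follows; the flag definition and the factory construction are illustrative assumptions (in RecML the `tf_data_service_address` flag is presumably defined once elsewhere, and the factory has other required fields omitted here). Only the `flags.FLAGS.tf_data_service_address` read matches what `__post_init__` actually does.

    from absl import app
    from absl import flags

    from recml.core.data.tf_dataset_factory import TFDatasetFactory

    # Assumption: a real binary defines this flag exactly once; it is the
    # flag read in TFDatasetFactory.__post_init__ above.
    _TF_DATA_SERVICE_ADDRESS = flags.DEFINE_string(
        "tf_data_service_address",
        None,
        "URI of the tf.data service, e.g. 'grpc://tf-data-service:5050'.",
    )

    def main(argv):
      del argv  # Unused.
      # Hypothetical construction; other factory fields are omitted.
      factory = TFDatasetFactory(
          enable_tf_data_service=True,
          # Falls back to "tf_data_service_shared_job_name" when None.
          tf_data_service_job_name="my_training_job",
      )
      _ = factory  # Build and consume the dataset here.

    if __name__ == "__main__":
      app.run(main)

Note the design change: because the job name no longer embeds the per-process batch size, distinct trainers can now share one tf.data service job by passing the same `tf_data_service_job_name`.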

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]

     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
     k_sequence = k_offset + jax.lax.broadcasted_iota(
         jnp.int32, (k_slice.size, bq), 0
     )
-    q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+    q_sequence = q_sequence_ref[:1, :]  # [1, bq]
     q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

     assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(

   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
     masks.append(q_ids == kv_ids)

   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )

     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]

     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
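
These hunks replace `pl.load` helper calls with NumPy-style indexing on Pallas refs; the two forms read the same block. A minimal standalone sketch of the equivalence (the kernel, shapes, and names below are assumptions for illustration, not code from this commit):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def _demo_kernel(x_ref, o_ref):
      rows = pl.ds(0, 4)  # dynamic slice covering 4 rows
      old_style = pl.load(x_ref, (rows, slice(None)))  # helper read (old form)
      new_style = x_ref[rows, :]  # NumPy-style indexing (new form)
      o_ref[...] = old_style + new_style  # both are the same [4, 128] block

    x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
    y = pl.pallas_call(
        _demo_kernel,
        out_shape=jax.ShapeDtypeStruct((4, 128), jnp.float32),
        interpret=True,  # interpret mode so the sketch also runs on CPU
    )(x)
    assert jnp.allclose(y, 2 * x[:4])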

recml/core/training/keras_trainer.py

Lines changed: 74 additions & 48 deletions
@@ -18,6 +18,7 @@
 import abc
 from collections.abc import Mapping
 import dataclasses
+import functools
 import gc
 import os
 import time
@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
       model: The Keras model constructed by `create_model`.
       model_dir: The model directory passed to the trainer.
     """
-    model.save(os.path.join(model_dir, core.KERAS_MODEL_SAVEFILE))


 class KerasTrainer(core.Trainer[KerasTask]):
@@ -118,6 +118,7 @@ def __init__(
       max_checkpoints_to_keep: int = 5,
       checkpoint_save_interval_epochs: int = 1,
       rng_seed: int = core.DEFAULT_RNG_SEED,
+      legacy_checkpoint_format: bool = True,
   ):
     """Initializes the instance."""

@@ -143,60 +144,77 @@ def __init__(
     self._steps_per_eval = steps_per_eval
     self._continuous_eval_timeout = continuous_eval_timeout
     self._steps_per_loop = steps_per_loop
-    self._checkpoint_manager = None
     self._marker_path = os.path.join(
         model_dir, core.TRAINING_COMPLETE_MARKER_FILE
     )
     self._checkpoint_dir = os.path.join(model_dir, core.CHECKPOINT_DIR)
+    self._max_checkpoints_to_keep = max_checkpoints_to_keep
+    self._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs
+    self._legacy_checkpoint_format = legacy_checkpoint_format

+  @functools.cached_property
+  def train_callbacks(self) -> list[keras.callbacks.Callback]:
+    """Returns the training callbacks."""
     if keras.backend.backend() == "jax":
-      self._checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
-          checkpoint_dir=self._checkpoint_dir,
-          max_to_keep=max_checkpoints_to_keep,
-          save_interval_epochs=checkpoint_save_interval_epochs,
-      )
-      self._train_callbacks = [
+      if self._legacy_checkpoint_format:
+        checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
+            checkpoint_dir=self._checkpoint_dir,
+            max_to_keep=self._max_checkpoints_to_keep,
+            save_interval_epochs=self._checkpoint_save_interval_epochs,
+        )
+      else:
+        checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2(
+            checkpoint_dir=self._checkpoint_dir,
+            max_to_keep=self._max_checkpoints_to_keep,
+            save_interval_epochs=self._checkpoint_save_interval_epochs,
+        )
+      return [
           keras_utils.EpochSummaryCallback(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              steps_per_epoch=steps_per_loop,
+              log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+              steps_per_epoch=self._steps_per_loop,
              write_steps_per_second=True,
          ),
          keras_utils.EpochOrbaxCheckpointAndRestoreCallback(
-              checkpoint_manager=self._checkpoint_manager,
+              checkpoint_manager=checkpoint_manager,
              marker_path=self._marker_path,
          ),
      ]
-      self._eval_callbacks = [
+    return [
+        keras.callbacks.TensorBoard(
+            log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+            write_steps_per_second=True,
+        ),
+        keras.callbacks.BackupAndRestore(
+            backup_dir=os.path.join(self._model_dir, core.BACKUP_DIR),
+        ),
+        keras.callbacks.ModelCheckpoint(
+            filepath=os.path.join(
+                self._model_dir,
+                core.CHECKPOINT_DIR,
+                "ckpt-{epoch:d}.weights.h5",
+            ),
+            save_weights_only=True,
+            verbose=1,
+        ),
+    ]
+
+  @functools.cached_property
+  def eval_callbacks(self) -> list[keras.callbacks.Callback]:
+    """Returns the evaluation callbacks."""
+    if keras.backend.backend() == "jax":
+      return [
          keras_utils.EpochSummaryCallback(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              steps_per_epoch=steps_per_loop,
+              log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+              steps_per_epoch=self._steps_per_loop,
              write_steps_per_second=False,
          ),
      ]
-    else:
-      self._checkpoint_manager = None
-      self._train_callbacks = [
-          keras.callbacks.TensorBoard(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              write_steps_per_second=True,
-          ),
-          keras.callbacks.BackupAndRestore(
-              backup_dir=os.path.join(model_dir, core.BACKUP_DIR),
-          ),
-          keras.callbacks.ModelCheckpoint(
-              filepath=os.path.join(
-                  model_dir, core.CHECKPOINT_DIR, "ckpt-{epoch:d}.weights.h5"
-              ),
-              save_weights_only=True,
-              verbose=1,
-          ),
-      ]
-      self._eval_callbacks = [
-          keras.callbacks.TensorBoard(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              write_steps_per_second=True,
-          ),
-      ]
+    return [
+        keras.callbacks.TensorBoard(
+            log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+            write_steps_per_second=True,
+        ),
+    ]

   def _maybe_get_model_kws(
       self, task: KerasTask, dataset: tf.data.Dataset
@@ -218,7 +236,7 @@ def train(self, task: KerasTask) -> core.Logs:
         dataset,
         epochs=self._train_epochs,
         steps_per_epoch=self._steps_per_loop,
-        callbacks=self._train_callbacks,
+        callbacks=self.train_callbacks,
     )
     model.summary(print_fn=logging.info)

@@ -237,14 +255,14 @@ def evaluate(self, task: KerasTask) -> core.Logs:
     if keras.backend.backend() == "jax":
       [tb_cbk] = [
           cbk
-          for cbk in self._eval_callbacks
+          for cbk in self.eval_callbacks
           if isinstance(cbk, keras_utils.EpochSummaryCallback)
       ]
       epoch_start_time = time.time()
       history = model.evaluate(
           dataset,
           steps=self._steps_per_eval,
-          callbacks=self._eval_callbacks,
+          callbacks=self.eval_callbacks,
           return_dict=True,
       )
       epoch_dt = time.time() - epoch_start_time
@@ -257,7 +275,7 @@ def evaluate(self, task: KerasTask) -> core.Logs:
     return model.evaluate(
         dataset,
         steps=self._steps_per_eval,
-        callbacks=self._eval_callbacks,
+        callbacks=self.eval_callbacks,
     )

   def train_and_evaluate(self, task: KerasTask) -> core.Logs:
@@ -277,7 +295,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
         steps_per_epoch=self._steps_per_loop,
         # Explicitly set to None for deterministic evaluation.
         validation_steps=None,
-        callbacks=self._train_callbacks,
+        callbacks=self.train_callbacks,
     )
     model.summary(print_fn=logging.info)

@@ -308,7 +326,10 @@ def timeout_fn() -> bool:
     else:
       steps_msg = "running complete evaluation..."

+    use_legacy_checkpoint_format = self._legacy_checkpoint_format
+
     class _RestoreCallback(keras.callbacks.Callback):
+      """Callback for restoring the model from the latest checkpoint."""

       def __init__(
           self,
@@ -319,9 +340,14 @@ def __init__(
         self._epoch = epoch

       def on_test_begin(self, logs: Mapping[str, Any] | None = None):
-        keras_utils.restore_keras_model(
-            model, self._checkpoint_dir, step=self._epoch
-        )
+        if use_legacy_checkpoint_format:
+          keras_utils.restore_keras_model(
+              model, self._checkpoint_dir, step=self._epoch
+          )
+        else:
+          keras_utils.restore_keras_checkpoint(
+              self._checkpoint_dir, model=model, epoch=self._epoch
+          )

     history = None
     for epoch in ocp.checkpoint_utils.checkpoints_iterator(
@@ -332,7 +358,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
       restore_callback = _RestoreCallback(self._checkpoint_dir, epoch)
       [tb_cbk] = [
           cbk
-          for cbk in self._eval_callbacks
+          for cbk in self.eval_callbacks
           if isinstance(cbk, keras_utils.EpochSummaryCallback)
       ]
       try:
@@ -346,7 +372,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
       history = model.evaluate(
           eval_dataset,
           steps=self._steps_per_eval,
-          callbacks=[restore_callback] + self._eval_callbacks,
+          callbacks=[restore_callback] + self.eval_callbacks,
           return_dict=True,
       )
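
This refactor turns the eagerly built `self._train_callbacks` / `self._eval_callbacks` lists into `functools.cached_property` values, so each list is constructed on first access and the same instances are then shared by `train()`, `evaluate()`, and `train_and_evaluate()`. A hypothetical usage sketch; the constructor arguments other than `legacy_checkpoint_format` are assumptions inferred from this diff, not a complete signature:

    # Assumption: KerasTrainer takes model_dir plus the checkpoint knobs
    # visible in the diff; other required arguments are omitted here.
    trainer = KerasTrainer(
        model_dir="/tmp/my_model",
        max_checkpoints_to_keep=5,
        checkpoint_save_interval_epochs=1,
        legacy_checkpoint_format=False,  # opt into KerasOrbaxCheckpointManagerV2
    )
    # First access builds the callbacks (and, on the JAX backend, the Orbax
    # checkpoint manager); later accesses return the cached list.
    callbacks = trainer.train_callbacks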
