Skip to content

Commit 8b6e300

Browse files
committed
Follow the original example in specifying input/output shape
1 parent af9ba92 commit 8b6e300

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

examples/ml_perf/configs/v6e_8.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@
196196
# Set `num_steps` in the main config file instead of num_epochs, because we are
197197
# using a Python generator.
198198
training_config.num_steps = 2
199+
training_config.eval_freq = 1
199200

200201
# === Assign all configs to the root config ===
201202
config.dataset = dataset_config

examples/ml_perf/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def main(
8484
name=feature_name,
8585
table=table_config,
8686
# TODO: Verify whether it should be `(bsz, 1)` or
87-
# `(bsz, feature_list_length)`.
88-
input_shape=(per_host_batch_size, feature_list_length),
89-
output_shape=(per_host_batch_size, embedding_dim),
87+
# `(bsz, feature_list_length)`. The original example uses 1.
88+
input_shape=(global_batch_size, 1),
89+
output_shape=(global_batch_size, embedding_dim),
9090
)
9191

9292
# === Instantiate model ===

0 commit comments

Comments
 (0)