Skip to content

Commit 5d4805b

Browse files
1vndeliahu
authored andcommitted
Simplify reviews example (#52)
(cherry picked from commit d8b55a7)
1 parent 6aa228b commit 5d4805b

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

examples/reviews/implementations/models/t2t_transformer.py renamed to examples/reviews/implementations/models/transformer.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,32 @@
1010

1111

1212
def create_estimator(run_config, model_config):
13-
# t2t expects these keys in run_config
14-
run_config.data_parallelism = None
15-
run_config.t2t_device_info = {"num_async_replicas": 1}
16-
1713
hparams = trainer_lib.create_hparams("transformer_base_single_gpu")
1814

15+
# SentimentIMDBCortex subclasses SentimentIMDB
1916
problem = SentimentIMDBCortex(list(model_config["aggregates"]["reviews_vocab"]))
20-
p_hparams = problem.get_hparams(hparams)
2117
hparams.problem = problem
22-
hparams.problem_hparams = p_hparams
18+
hparams.problem_hparams = problem.get_hparams(hparams)
2319

20+
# metrics specific to the sentiment problem
2421
problem.eval_metrics = lambda: [
2522
metrics.Metrics.ACC_TOP5,
2623
metrics.Metrics.ACC_PER_SEQ,
2724
metrics.Metrics.NEG_LOG_PERPLEXITY,
2825
]
2926

30-
# t2t expects this key
31-
hparams.warm_start_from = None
32-
3327
# reduce memory load
3428
hparams.num_hidden_layers = 2
3529
hparams.hidden_size = 32
3630
hparams.filter_size = 32
3731
hparams.num_heads = 2
3832

39-
estimator = trainer_lib.create_estimator("transformer", hparams, run_config)
40-
return estimator
33+
# t2t expects these keys
34+
hparams.warm_start_from = None
35+
run_config.data_parallelism = None
36+
run_config.t2t_device_info = {"num_async_replicas": 1}
37+
38+
return trainer_lib.create_estimator("transformer", hparams, run_config)
4139

4240

4341
def transform_tensorflow(features, labels, model_config):

examples/reviews/resources/apis.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212

1313
- kind: api
1414
name: sentiment-t2t
15-
model_name: t2t_transformer
15+
model_name: transformer
1616
compute:
1717
replicas: 1

examples/reviews/resources/models.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
num_steps: 5000
3232

3333
- kind: model
34-
name: t2t_transformer
34+
name: transformer
3535
type: classification
3636
target_column: label_indexed
3737
feature_columns:

0 commit comments

Comments
 (0)