|
@@ -37,6 +37,7 @@
 
 import tensorflow as tf
 
+from tensorflow.python.eager import context
 from tensorflow.python.util import nest
 
 
@@ -324,7 +325,7 @@ def symbols_to_logits_fn(ids, i, cache):
     # Note: Tensor.set_shape() does not work here since it merges shape info.
     # TODO(llion); Find a more robust solution.
     # pylint: disable=protected-access
-    if not self._hparams.use_eager_mode:
+    if not context.in_eager_mode():
       for layer in cache:
         cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
         cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
@@ -452,8 +453,7 @@ def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
         common_layers.shape_list(inputs)[1])
   # Append target_space_id embedding to inputs.
   emb_target_space = common_layers.embedding(
-      target_space, 32, ishape_static[-1], name="target_space_embedding",
-      use_eager_mode=hparams.use_eager_mode)
+      target_space, 32, ishape_static[-1], name="target_space_embedding")
   emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
   encoder_input += emb_target_space
   if hparams.pos == "timing":
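
For context, a minimal standalone sketch (not part of the diff, and assuming a TF 1.x release that ships tf.enable_eager_execution alongside the private tensorflow.python.eager.context module imported above) of what the runtime check reports before and after eager execution is enabled; because the mode can be queried from the runtime directly, a use_eager_mode hparam no longer has to be threaded through every call site:

import tensorflow as tf
from tensorflow.python.eager import context

# Before eager execution is enabled, the process runs in graph mode,
# so the check used in the diff above returns False.
print(context.in_eager_mode())   # False

# Enabling eager execution flips the runtime state that in_eager_mode() reads.
tf.enable_eager_execution()
print(context.in_eager_mode())   # True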
|