diff --git a/trax/data/inputs.py b/trax/data/inputs.py
index de15497f4..ab814153d 100644
--- a/trax/data/inputs.py
+++ b/trax/data/inputs.py
@@ -62,7 +62,7 @@
     data.Tokenize(vocab_file='en_8k.subword'),
     lambda g: filter(lambda x: x.shape[0] < 513, g), # At most 512 tokens.
     data.Shuffle(),
-    lambda g: map(lambda x: (x, x)), # Language models have inputs = targets.
+    lambda g: map(lambda x: (x, x), g), # Language models have inputs = targets.
     data.BucketByLength(boundaries=[ 32, 64, 128, 256, 512],
                         batch_sizes=[ 32, 16, 8, 4, 2, 1]),
     data.AddLossWeights(id_to_mask=0)