3 changes: 2 additions & 1 deletion data_pipeline.py
@@ -58,7 +58,8 @@ def input_fn(tf_records,
     dataset = tf.data.TFRecordDataset(tf_records, buffer_size=10000)
     dataset = dataset.shuffle(buffer_size=buffer_size)
 
-    dataset = dataset.map(parse_example,num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    dataset = dataset.map(lambda
+        x:tf.py_function(func=parse_example,inp=[x],Tout=(tf.int32,tf.int32)))
     dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes)
     dataset = dataset.repeat(epoch)
     dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
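Review note: the new map call wraps parse_example in tf.py_function, which runs the Python function eagerly and therefore needs the output dtypes declared explicitly via Tout. A minimal, self-contained sketch of how this wrapper can still be combined with num_parallel_calls; the make_dataset name and its full parameter list are assumptions for illustration, not part of this PR:

# Sketch only: tf.py_function wrapper composed with AUTOTUNE, assuming
# parse_example takes one serialized record and returns two int32 tensors.
import tensorflow as tf

def make_dataset(tf_records, parse_example, buffer_size, batch_size, padded_shapes, epoch):
    dataset = tf.data.TFRecordDataset(tf_records, buffer_size=10000)
    dataset = dataset.shuffle(buffer_size=buffer_size)
    # Tout declares the dtypes parse_example returns, since py_function
    # bypasses graph tracing; num_parallel_calls can still be requested,
    # though the Python body itself is limited by the GIL.
    dataset = dataset.map(
        lambda x: tf.py_function(func=parse_example, inp=[x], Tout=(tf.int32, tf.int32)),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes)
    dataset = dataset.repeat(epoch)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset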
2 changes: 1 addition & 1 deletion setup.sh
@@ -2,4 +2,4 @@

PYTHON="python3.10"

$PYTHON pre_process.py
# $PYTHON pre_process.py
10 changes: 9 additions & 1 deletion train_gpt2.py
@@ -5,7 +5,8 @@

 from data_pipeline import input_fn
 from gpt2_model import *
-
+from scripts.utils import write_csv
+import timeit
 _ROOT = os.path.abspath(os.path.dirname(__file__))
 LOG_DIR = _ROOT + "/log"
 MODEL_DIR = _ROOT + "/model"
@@ -44,7 +45,10 @@ def train(num_layers, embedding_size, num_heads, dff, max_seq_len, vocab_size,
     train_tf_records = tf_records[:train_percent]
     test_tf_records = tf_records[train_percent:]
 
+    print("Train dataset:")
     train_dataset = input_fn(train_tf_records, batch_size=batch_size)
+
+    print("Test dataset:")
     test_dataset = input_fn(test_tf_records, batch_size=batch_size)
 
     if distributed:
@@ -74,4 +78,8 @@ def train(num_layers, embedding_size, num_heads, dff, max_seq_len, vocab_size,


if __name__ == "__main__":
start_time = timeit.default_timer()
skipped_time = 0
train()
time = timeit.default_timer() - skipped_time -start_time
write_csv(__file,time=time)
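Review note: scripts/utils.py is not included in this diff, so the write_csv signature is inferred from the call site above. A minimal sketch of a compatible helper; the output path, column layout, and every name beyond the call site are assumptions, not the PR's actual implementation:

# Hypothetical helper, not from this PR: appends one timing row per run.
import csv
import os

def write_csv(script_name, time, path="benchmark.csv"):
    # Write a header the first time, then append one row per call.
    new_file = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if new_file:
            writer.writerow(["script", "seconds"])
        writer.writerow([script_name, time])

As written in this diff, skipped_time is initialized to 0 and never updated, so the recorded value is simply the wall-clock duration of the train() call.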