First changes #1

Draft · wants to merge 6 commits into main
21 changes: 21 additions & 0 deletions conf/config.yaml
@@ -0,0 +1,21 @@
wandb:
project_name: "board-representation-experiments-5th-feb"
id: "cross-entropy-gpt-warmup-scheduler-whole-seq-a"

training:
mode: 2
num_layers: 8
num_heads: 4
batch_size: 64
seq_len: 100
train_ratio: 0.8
val_ratio: 0.1
data_path: "info.txt"
epochs: 15
lr: 0.00001
weight_decay: 0.09
save_directory: "./save"
embedding_size: 128
seed: 1243
loss_type: 0
seq_type: 0
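# Hydra loads this file through the training script's @hydra.main decorator, so any
# value here can also be overridden from the command line with dotted keys.
# Illustrative sketch only; the exact entry point is an assumption:
#   python -m transformers_learn_mdp.connect4_train_mcts training.lr=0.0001 training.epochs=30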
61 changes: 61 additions & 0 deletions requirements.txt
@@ -0,0 +1,61 @@
accelerate==1.2.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
certifi==2024.12.14
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==3.1.1
colorama==0.4.6
docker-pycreds==0.4.0
filelock==3.16.1
fsspec==2024.12.0
gitdb==4.0.12
GitPython==3.1.44
gym==0.26.2
gym-notices==0.0.8
gym_connect4 @ git+https://github.com/Danielhp95/gym-connect4.git@bfc12d659308dfcf1132a31aee9b52eceb8901b5
huggingface-hub==0.27.1
hydra-core==1.3.2
idna==3.10
Jinja2==3.1.5
MarkupSafe==3.0.2
mcts @ git+https://github.com/metric-space/mcts.git@6028ada55d9690238c2db14d423c34d98698999a
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.1
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
packaging==24.2
pillow==11.1.0
platformdirs==4.3.6
protobuf==5.29.3
psutil==6.1.1
pydantic==2.10.5
pydantic_core==2.27.2
PyYAML==6.0.2
requests==2.32.3
safetensors==0.5.2
sentry-sdk==2.20.0
setproctitle==1.3.4
six==1.17.0
smmap==5.0.2
sympy==1.13.1
torch==2.5.1
torchvision==0.20.1
tqdm==4.67.1
-e git+ssh://[email protected]/llm-engineering/transformers-learn-MDP.git@cafe152c60c4ddef960c1f5a066f235071e24fcd#egg=transformers_learn_mdp
triton==3.1.0
typing_extensions==4.12.2
urllib3==2.3.0
wandb==0.19.4
13 changes: 13 additions & 0 deletions setup.py
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages

def read_requirements():
with open("requirements.txt") as f:
return f.read().splitlines()

setup(
name="transformers_learn_mdp",
version="0.1.0",
package_dir={"": "src"},
packages=find_packages(where="src"),
install_requires=read_requirements() + ["mcts@git+https://github.com/metric-space/mcts.git"]
)
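# Minimal usage sketch (standard setuptools/pip workflow; the editable install below is
# an assumption about the intended development setup, not something stated in this PR):
#
#   pip install -e .
#
# With package_dir={"": "src"} and find_packages(where="src"), only packages that live
# under src/ and contain an __init__.py are picked up.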
Empty file.
218 changes: 218 additions & 0 deletions src/transformers_learn_mdp/connect4_train_mcts.py
@@ -0,0 +1,218 @@
import os
import sys
import pickle
import shutil
import torch
import wandb
from tqdm import tqdm
import hydra
import itertools
from omegaconf import DictConfig, OmegaConf, open_dict

from accelerate import Accelerator
from .dataset import EpisodeDataset, collate_fn
from .model import Config, GPTModel
from .trainer import train_model, validate_model, Loss, Mode, SeqSubSet
from torch.utils.data import DataLoader

from .data_utils import information_parser, actions_to_col_row
from enum import Enum

def get_lr_scheduler(optimizer, warmup_epochs, total_epochs, base_lr, max_lr):
"""
Combines warmup and cosine annealing for learning rate scheduling.

Args:
optimizer: PyTorch optimizer
warmup_epochs: Number of warmup epochs
total_epochs: Total number of training epochs
base_lr: Starting learning rate (during warmup)
max_lr: Peak learning rate (after warmup)

Returns:
scheduler: Learning rate scheduler
"""
def lr_lambda(epoch):
if epoch < warmup_epochs:
return 2*epoch # Linear warmup
else:
return 10*epoch

return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
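
# Sketch of what a combined warmup + cosine-annealing lambda could look like if the
# schedule above is ever swapped out (illustrative only, kept commented out like the
# alternatives further below; `math` would need to be imported):
#
# import math
#
# def warmup_cosine_lambda(epoch, warmup_epochs=5, total_epochs=15):
#     if epoch < warmup_epochs:
#         return (epoch + 1) / warmup_epochs                     # linear ramp up to 1.0
#     progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
#     return 0.5 * (1.0 + math.cos(math.pi * progress))          # cosine decay towards 0.0
#
# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer, lambda e: warmup_cosine_lambda(e, warmup_epochs, total_epochs)
# )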


def train(training_config, training_dataset, validation_dataset, token_to_idx, wandb):

train_dataset = EpisodeDataset(training_dataset, token_to_idx)
valid_dataset = EpisodeDataset(validation_dataset, token_to_idx)

accelerator = Accelerator()

train_loader = DataLoader(
train_dataset,
batch_size=training_config.batch_size,
shuffle=True,
collate_fn=collate_fn,
)
valid_loader = DataLoader(
valid_dataset,
batch_size=training_config.batch_size,
shuffle=True,
collate_fn=collate_fn,
)

config = Config(
training_config.vocab_size,
training_config.seq_len,
n_layer=training_config.num_layers,
n_head=training_config.num_heads,
n_embd=training_config.embedding_size,
)
model = GPTModel(config)

optimizer = torch.optim.AdamW(
model.parameters(),
lr=training_config.lr,
weight_decay=training_config.weight_decay,
)
#optimizer = torch.optim.SGD(
# model.parameters(),
# lr=training_config.lr,
# weight_decay=training_config.weight_decay,
#)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(
# optimizer,
# max_lr=0.0005,
# steps_per_epoch=len(train_loader),
# epochs=training_config.epochs,
#)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)

scheduler = get_lr_scheduler(optimizer,5, training_config.epochs, 0.00001, 0.001)

train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(
train_loader, valid_loader, model, scheduler, optimizer
)

epoch = 0

model_path = None
min_loss = 1e10

train_losses = []
valid_losses = []



# TODO: this is just pulling things out from a config
mode = Mode(training_config.mode)
loss_type = Loss(training_config.loss_type)
seq_type = SeqSubSet(training_config.seq_type)

for epoch in tqdm(range(training_config.epochs), desc="Epoch"):
accelerator.print(f"Epoch {epoch}")
wandb.log({"Epoch": epoch})

train_loss = train_model(
model, train_loader, optimizer, accelerator, None, wandb, mode, loss_type, seq_type
)
valid_loss, p1_acc, p2_acc, total_acc = validate_model(model, valid_loader, accelerator, mode, loss_type, seq_type)
train_losses.append(train_loss)
valid_losses.append(valid_loss)
scheduler.step()
accelerator.print({"Learning Rate": scheduler.get_last_lr()[0]})

# print("Learning Rate: ", scheduler.get_last_lr())

mode = training_config.mode
seed = training_config.seed

if accelerator.is_main_process:
val_loss_str = f"Validation loss {valid_loss:.8f}"
wandb.log({"Validation Loss": valid_loss, "Training Loss": train_loss, "P1 Acc": p1_acc, "P2 Acc": p2_acc, "Total accuracy": total_acc})
accelerator.print(val_loss_str)

model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth"
accelerator.save(
accelerator.unwrap_model(model).state_dict(), model_save_path
)

if valid_loss < min_loss:
min_loss = valid_loss
model_path = model_save_path

accelerator.wait_for_everyone()

if accelerator.is_main_process:
shutil.copy(model_path, training_config.save_directory)

with open(f"train_losses_mode_{mode}_seed_{seed}.pkl", "wb") as f:
pickle.dump(train_losses, f)
with open(f"valid_losses_mode_{mode}_seed_{seed}.pkl", "wb") as f:
pickle.dump(valid_losses, f)

wandb.finish()


def split_dataset(data, train_ratio, valid_ratio):
train = data[: int(train_ratio * len(data))]
valid = data[
int(train_ratio * len(data)) : int((train_ratio + valid_ratio) * len(data))
]
test = data[int((train_ratio + valid_ratio) * len(data)) :]
return train, valid, test
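# Worked example: with the ratios in conf/config.yaml (train_ratio=0.8, val_ratio=0.1)
# and 100 episodes, this yields splits of 80, 10 and 10 episodes.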


def mode_to_token_to_idx(mode):
if mode == 0:
token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)}
vocab_size = 43
transformation = actions_to_col_row
elif mode == 1:
token_to_idx = {i: i + 1 for i in range(7)}
vocab_size = 8
transformation = lambda x: x
elif mode == 2:
token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | {
i: i + 44 for i in range(7)
}
vocab_size = 51
transformation = lambda x: list(itertools.chain(*zip(x,actions_to_col_row(x))))
token_to_idx["<pad>"] = 0 # Padding token

token_to_idx[51] = 51
vocab_size += 1

return token_to_idx, vocab_size, transformation
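# Token layout implied by the mappings above, for reference:
#   mode 0: (row, col) board cells -> 1..42, "<pad>" -> 0
#   mode 1: column index i -> i + 1 (1..7), "<pad>" -> 0
#   mode 2: board cells -> 1..42, column actions -> 44..50, "<pad>" -> 0,
#           with each action interleaved with its (row, col) landing cell
# In every mode an extra token with id 51 is then added and vocab_size is incremented.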


@hydra.main(version_base=None, config_path="../../conf", config_name="config")
def main(cfg: DictConfig) -> None:

training_config = cfg.training

mode = training_config.mode
token_to_idx, vocab_size, transformation = mode_to_token_to_idx(mode)

# Make this a function
with open(training_config.data_path, "r") as f:
data = f.readlines()
data = information_parser(data)
raw_dataset = [transformation([action for (_, action) in x]) for x in data]


training_dataset, validation_dataset, test_dataset = split_dataset(
raw_dataset, training_config.train_ratio, training_config.val_ratio
)

with open_dict(training_config):
training_config["vocab_size"] = vocab_size
training_config["dataset_length"] = len(raw_dataset)

wandb.init(project=cfg.wandb.project_name, config=dict(training_config), id=cfg.wandb.id)

train(training_config, training_dataset, validation_dataset, token_to_idx, wandb)


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions src/transformers_learn_mdp/data_utils.py
@@ -0,0 +1,58 @@
from typing import List


def actions_to_col_row(actions, board_height=6):
"""
    Converts a sequence of Connect4 column moves into (row, column) pairs.

Args:
actions (list): List of column indices (0-6) representing moves.
board_height (int): Number of rows in Connect4 (default: 6).

Returns:
        list of tuples: [(row, col), ...] where row is the row the piece lands in.
"""
heights = [0] * 7 # Track how filled each column is
col_row_sequence = []

for col in actions:
row = board_height - 1 - heights[col] # Compute the landing row
if row < 0:
raise ValueError(f"Invalid move: Column {col} is full!")

col_row_sequence.append((row, col))
heights[col] += 1 # Update column height

return col_row_sequence
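
# Worked example (values follow directly from the logic above):
#   actions_to_col_row([3, 3, 0]) == [(5, 3), (4, 3), (5, 0)]
#   The first piece in column 3 lands on the bottom row (5), the next piece in
#   column 3 stacks one row higher (4), and the piece in column 0 lands on row 5.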


def information_parser(info: List[str]):
"""


"""
#
parsed_info = []

for line in info:
temp = []
raw = line.split(",")
counter = 0
while counter < len(raw):

leap_steps = int(raw[counter]) * 2
counter += 1

q_values = {}
fragment = raw[counter:counter + leap_steps ]
zip_object = zip(fragment[::2], fragment[1::2])
for key, value in zip_object:
q_values[int(key)] = float(value)
counter += leap_steps

temp.append((q_values, int(raw[counter])))
counter += 1

parsed_info.append(temp)

return parsed_info
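
# Worked example with hypothetical values, derived from the parsing logic above:
# the line "2,3,0.5,4,0.25,3,1,2,0.9,2" is parsed into two steps,
#   [({3: 0.5, 4: 0.25}, 3), ({2: 0.9}, 2)]
# i.e. each step is "<number of (action, Q-value) pairs>, <the pairs>, <chosen action>".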