diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000..0a55f5a --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,21 @@ +wandb: + project_name: "board-representation-experiments-5th-feb" + id: "cross-entropy-gpt-warmup-scheduler-whole-seq-a" + +training: + mode: 2 + num_layers: 8 + num_heads: 4 + batch_size: 64 + seq_len: 100 + train_ratio: 0.8 + val_ratio: 0.1 + data_path: "info.txt" + epochs: 15 + lr: 0.00001 + weight_decay: 0.09 + save_directory: "./save" + embedding_size: 128 + seed: 1243 + loss_type: 0 + seq_type: 0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9d36c8a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,61 @@ +accelerate==1.2.1 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +certifi==2024.12.14 +charset-normalizer==3.4.1 +click==8.1.8 +cloudpickle==3.1.1 +colorama==0.4.6 +docker-pycreds==0.4.0 +filelock==3.16.1 +fsspec==2024.12.0 +gitdb==4.0.12 +GitPython==3.1.44 +gym==0.26.2 +gym-notices==0.0.8 +gym_connect4 @ git+https://github.com/Danielhp95/gym-connect4.git@bfc12d659308dfcf1132a31aee9b52eceb8901b5 +huggingface-hub==0.27.1 +hydra-core==1.3.2 +idna==3.10 +Jinja2==3.1.5 +MarkupSafe==3.0.2 +mcts @ git+https://github.com/metric-space/mcts.git@6028ada55d9690238c2db14d423c34d98698999a +mpmath==1.3.0 +networkx==3.4.2 +numpy==2.2.1 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +omegaconf==2.3.0 +packaging==24.2 +pillow==11.1.0 +platformdirs==4.3.6 +protobuf==5.29.3 +psutil==6.1.1 +pydantic==2.10.5 +pydantic_core==2.27.2 +PyYAML==6.0.2 +requests==2.32.3 +safetensors==0.5.2 +sentry-sdk==2.20.0 +setproctitle==1.3.4 +six==1.17.0 +smmap==5.0.2 +sympy==1.13.1 +torch==2.5.1 +torchvision==0.20.1 +tqdm==4.67.1 +-e git+ssh://git@github.com/llm-engineering/transformers-learn-MDP.git@cafe152c60c4ddef960c1f5a066f235071e24fcd#egg=transformers_learn_mdp +triton==3.1.0 +typing_extensions==4.12.2 +urllib3==2.3.0 +wandb==0.19.4 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..d8fd836 --- /dev/null +++ b/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +def read_requirements(): + with open("requirements.txt") as f: + return f.read().splitlines() + +setup( + name="transformers_learn_mdp", + version="0.1.0", + package_dir={"": "src"}, + packages=find_packages(where="src"), + install_requires=read_requirements() + ["mcts@git+https://github.com/metric-space/mcts.git"] +) diff --git a/src/transformers_learn_mdp/__init__.py b/src/transformers_learn_mdp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/transformers_learn_mdp/connect4_train_mcts.py b/src/transformers_learn_mdp/connect4_train_mcts.py new file mode 100644 index 0000000..47d45f1 --- /dev/null +++ b/src/transformers_learn_mdp/connect4_train_mcts.py @@ -0,0 +1,218 @@ +import os +import sys +import pickle +import shutil +import torch +import wandb +from tqdm import tqdm +import hydra +import itertools +from omegaconf import DictConfig, OmegaConf, open_dict + +from accelerate import Accelerator +from .dataset import EpisodeDataset, collate_fn +from .model import Config, GPTModel +from .trainer import train_model, validate_model, Loss, Mode, SeqSubSet +from torch.utils.data import 
DataLoader
+
+from .data_utils import information_parser, actions_to_col_row
+from enum import Enum
+
+def get_lr_scheduler(optimizer, warmup_epochs, total_epochs, base_lr, max_lr):
+    """
+    Learning rate schedule with a linear warmup followed by a steeper linear ramp
+    (the LambdaLR multiplier is 2*epoch during warmup and 10*epoch afterwards).
+
+    Args:
+        optimizer: PyTorch optimizer
+        warmup_epochs: Number of warmup epochs
+        total_epochs: Total number of training epochs (currently unused)
+        base_lr: Starting learning rate during warmup (currently unused)
+        max_lr: Peak learning rate after warmup (currently unused)
+
+    Returns:
+        scheduler: torch.optim.lr_scheduler.LambdaLR instance
+    """
+    def lr_lambda(epoch):
+        if epoch < warmup_epochs:
+            return 2*epoch  # Linear warmup
+        else:
+            return 10*epoch
+
+    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+
+def train(training_config, training_dataset, validation_dataset, token_to_idx, wandb):
+
+    train_dataset = EpisodeDataset(training_dataset, token_to_idx)
+    valid_dataset = EpisodeDataset(validation_dataset, token_to_idx)
+
+    accelerator = Accelerator()
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=training_config.batch_size,
+        shuffle=True,
+        collate_fn=collate_fn,
+    )
+    valid_loader = DataLoader(
+        valid_dataset,
+        batch_size=training_config.batch_size,
+        shuffle=True,
+        collate_fn=collate_fn,
+    )
+
+    config = Config(
+        training_config.vocab_size,
+        training_config.seq_len,
+        n_layer=training_config.num_layers,
+        n_head=training_config.num_heads,
+        n_embd=training_config.embedding_size,
+    )
+    model = GPTModel(config)
+
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=training_config.lr,
+        weight_decay=training_config.weight_decay,
+    )
+    #optimizer = torch.optim.SGD(
+    #    model.parameters(),
+    #    lr=training_config.lr,
+    #    weight_decay=training_config.weight_decay,
+    #)
+    #scheduler = torch.optim.lr_scheduler.OneCycleLR(
+    #    optimizer,
+    #    max_lr=0.0005,
+    #    steps_per_epoch=len(train_loader),
+    #    epochs=training_config.epochs,
+    #)
+    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)
+
+    scheduler = get_lr_scheduler(optimizer, 5, training_config.epochs, 0.00001, 0.001)
+
+    train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(
+        train_loader, valid_loader, model, scheduler, optimizer
+    )
+
+    epoch = 0
+
+    model_path = None
+    min_loss = 1e10
+
+    train_losses = []
+    valid_losses = []
+
+
+
+    # TODO: this is just pulling things out from a config
+    mode = Mode(training_config.mode)
+    loss_type = Loss(training_config.loss_type)
+    seq_type = SeqSubSet(training_config.seq_type)
+
+    for epoch in tqdm(range(training_config.epochs), desc="Epoch"):
+        accelerator.print(f"Epoch {epoch}")
+        wandb.log({"Epoch": epoch})
+
+        train_loss = train_model(
+            model, train_loader, optimizer, accelerator, None, wandb, mode, loss_type, seq_type
+        )
+        valid_loss, p1_acc, p2_acc, total_acc = validate_model(model, valid_loader, accelerator, mode, loss_type, seq_type)
+        train_losses.append(train_loss)
+        valid_losses.append(valid_loss)
+        scheduler.step()
+        accelerator.print({"Learning Rate": scheduler.get_last_lr()[0]})
+
+        # print("Learning Rate: ", scheduler.get_last_lr())
+
+        mode = training_config.mode
+        seed = training_config.seed
+
+        if accelerator.is_main_process:
+            val_loss_str = f"Validation loss {valid_loss:.8f}"
+            wandb.log({"Validation Loss": valid_loss, "Training Loss": train_loss, "P1 Acc": p1_acc, "P2 Acc": p2_acc, "Total accuracy": total_acc})
+            accelerator.print(val_loss_str)
+
+            model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth"
+            accelerator.save(
+                accelerator.unwrap_model(model).state_dict(),
model_save_path + ) + + if valid_loss < min_loss: + min_loss = valid_loss + model_path = model_save_path + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + shutil.copy(model_path, training_config.save_directory) + + with open(f"train_losses_mode_{mode}_seed_{seed}.pkl", "wb") as f: + pickle.dump(train_losses, f) + with open(f"valid_losses_mode_{mode}_seed_{seed}.pkl", "wb") as f: + pickle.dump(valid_losses, f) + + wandb.finish() + + +def split_dataset(data, train_ratio, valid_ratio): + train = data[: int(train_ratio * len(data))] + valid = data[ + int(train_ratio * len(data)) : int((train_ratio + valid_ratio) * len(data)) + ] + test = data[int((train_ratio + valid_ratio) * len(data)) :] + return train, valid, test + + +def mode_to_token_to_idx(mode): + if mode == 0: + token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} + vocab_size = 43 + transformation = actions_to_col_row + elif mode == 1: + token_to_idx = {i: i + 1 for i in range(7)} + vocab_size = 8 + transformation = lambda x: x + elif mode == 2: + token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | { + i: i + 44 for i in range(7) + } + vocab_size = 51 + transformation = lambda x: list(itertools.chain(*zip(x,actions_to_col_row(x)))) + token_to_idx[""] = 0 # Padding token + + token_to_idx[51] = 51 + vocab_size += 1 + + return token_to_idx, vocab_size, transformation + + +@hydra.main(version_base=None, config_path="../../conf", config_name="config") +def main(cfg: DictConfig) -> None: + + training_config = cfg.training + + mode = training_config.mode + token_to_idx, vocab_size, transformation = mode_to_token_to_idx(mode) + + # Make this a function + with open(training_config.data_path, "r") as f: + data = f.readlines() + data = information_parser(data) + raw_dataset = [transformation([action for (_, action) in x]) for x in data] + + + training_dataset, validation_dataset, test_dataset = split_dataset( + raw_dataset, training_config.train_ratio, training_config.val_ratio + ) + + with open_dict(training_config): + training_config["vocab_size"] = vocab_size + training_config["dataset_length"] = len(raw_dataset) + + wandb.init(project=cfg.wandb.project_name, config=dict(training_config), id=cfg.wandb.id) + + train(training_config, training_dataset, validation_dataset, token_to_idx, wandb) + + +if __name__ == "__main__": + main() diff --git a/src/transformers_learn_mdp/data_utils.py b/src/transformers_learn_mdp/data_utils.py new file mode 100644 index 0000000..b2aa957 --- /dev/null +++ b/src/transformers_learn_mdp/data_utils.py @@ -0,0 +1,58 @@ +from typing import List + + +def actions_to_col_row(actions, board_height=6): + """ + Converts a sequence of Connect4 column moves into (column, row) pairs. + + Args: + actions (list): List of column indices (0-6) representing moves. + board_height (int): Number of rows in Connect4 (default: 6). + + Returns: + list of tuples: [(col, row), ...] where row is where the piece lands. 
+ """ + heights = [0] * 7 # Track how filled each column is + col_row_sequence = [] + + for col in actions: + row = board_height - 1 - heights[col] # Compute the landing row + if row < 0: + raise ValueError(f"Invalid move: Column {col} is full!") + + col_row_sequence.append((row, col)) + heights[col] += 1 # Update column height + + return col_row_sequence + + +def information_parser(info: List[str]): + """ + + + """ + # + parsed_info = [] + + for line in info: + temp = [] + raw = line.split(",") + counter = 0 + while counter < len(raw): + + leap_steps = int(raw[counter]) * 2 + counter += 1 + + q_values = {} + fragment = raw[counter:counter + leap_steps ] + zip_object = zip(fragment[::2], fragment[1::2]) + for key, value in zip_object: + q_values[int(key)] = float(value) + counter += leap_steps + + temp.append((q_values, int(raw[counter]))) + counter += 1 + + parsed_info.append(temp) + + return parsed_info \ No newline at end of file diff --git a/src/transformers_learn_mdp/dataset.py b/src/transformers_learn_mdp/dataset.py new file mode 100644 index 0000000..c5965b0 --- /dev/null +++ b/src/transformers_learn_mdp/dataset.py @@ -0,0 +1,44 @@ +import torch + +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +import tqdm + + +class EpisodeDataset(Dataset): + + def __init__(self, data, token_to_idx, packing_length=30,padding_value=0): + self.token_to_idx = token_to_idx + print("Tokenizing and packing the dataset") + self.packed_data = [] + + self.tokenized_data = [[51] + [self.token_to_idx[token] for token in sequence] + [51] for sequence in data] + # flatten the list and insert padding value at the end of each sequence + #self.data = [] + #for sequence in self.tokenized_data: + # self.data.extend(sequence) + # self.data.append(padding_value) + #del self.tokenized_data + #self.data = [self.data[i:i+packing_length] for i in range(0, len(self.data), packing_length)] + self.data = self.tokenized_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + sequence = self.data[idx] + + X_indices = sequence[:-1] + Y_indices = sequence[1:] + + return torch.tensor(X_indices, dtype=torch.long), torch.tensor(Y_indices, dtype=torch.long) + +def collate_fn(batch): + + Xs, Ys = zip(*batch) + + Xs_padded = pad_sequence(Xs, batch_first=True, padding_value=0) + Ys_padded = pad_sequence(Ys, batch_first=True, padding_value=0) + + return Xs_padded, Ys_padded \ No newline at end of file diff --git a/transformer_training_mcts/model.py b/src/transformers_learn_mdp/model.py similarity index 100% rename from transformer_training_mcts/model.py rename to src/transformers_learn_mdp/model.py diff --git a/src/transformers_learn_mdp/trainer.py b/src/transformers_learn_mdp/trainer.py new file mode 100644 index 0000000..cadba8b --- /dev/null +++ b/src/transformers_learn_mdp/trainer.py @@ -0,0 +1,257 @@ +import torch +import torch.nn as nn +import itertools + +from tqdm import tqdm +import torch.nn.functional as F +from enum import Enum + + +class Loss(Enum): + CrossEntropy = 0 + KLDivergence = 1 + + +class Mode(Enum): + STATE = 0 + ACTION = 1 + STATE_ACTION = 2 + + +# NOTE: as in whether to include just player 1 or the whole sequence +class SeqSubSet(Enum): + WHOLE = 0 + PLAYER_1 = 1 + + +def batch_one_hot(batch, number_of_classes): + """ + One-hot encode a batch of sequences. 
+ """ + batch_size, seq_length = batch.size() + one_hot = torch.zeros(batch_size, seq_length, number_of_classes).to(batch.device) + one_hot.scatter_(2, batch.unsqueeze(-1), 1) + return one_hot + + +def loss_calc(loss_type, loss_fn, logits, target, indices, padding_token=0): + + assert loss_fn.reduction == "none" + + vocab_length = logits.shape[-1] + + target = ( + target[:, indices].contiguous().view(-1) + ) # expect this to be one dimensional + logits = logits[:, indices, :].contiguous().view(-1, logits.shape[-1]) + + if loss_type == Loss.KLDivergence: + """ + If loss is KL Divergence, the target needs to be a probability distribution over the + vocabulary + """ + logits = F.log_softmax(logits, dim=-1) + target = F.one_hot( + target, num_classes=vocab_length + ) # expect batch size is (batch_size x seq_length, vocab_length) + + loss = loss_fn(logits, target) + + if loss_type == Loss.KLDivergence: + """ + KLDivergence without reduction shoots out a seq of dim (seq_length ,vocab) + """ + loss = loss.mean(dim=-1) + + assert len(loss.shape) == 1 + + # ----- mask making ----------- + + mask = (target != padding_token).float() + + return (loss * mask, mask) + + +def logit_selection(length, mode, seq_type): + """ + + If mode is 0, select all the odd indices from 1 to length-1, because player 2 is randomly selecting the column. + + For mode 2 it's action state + + (a_0,s_0,a_1,s_1,a_0,s_1 ..... -> (a_0,s_0 ....), (s_0, a_1, ....) + + + """ + + whole_seq = range(length) + + if seq_type == SeqSubSet.WHOLE: + return list(whole_seq), list(whole_seq) + + if mode == 0 or mode == 1: + player_1 = list(range(0, length, 2)) + player_2 = [x for x in whole_seq if x not in player_1] + return (player_1, player_2) + + player_1 = [0, 1] # Why is 0 here? because action can be used to predict state + for i in range(3, length, 4): + player_1.extend(list(range(i, min(i + 3, length)))) + + if player_1[-1] != length -1: + player_1.append(length-1) + + return (player_1, [x for x in whole_seq if x not in player_1]) + + +def criterion_f(loss_type): + if loss_type == Loss.CrossEntropy: + return nn.CrossEntropyLoss( + ignore_index=0, reduction="none", label_smoothing=0.0 + ) + else: + return nn.KLDivLoss(reduction="none") + + +# TODO: mode +def validate_model(model, valid_loader, accelerator, mode, loss_type, seq_type): + + model.eval() + criterion = criterion_f(loss_type) + + valid_loss = torch.tensor(0.0).to(accelerator.device) + valid_data = torch.tensor(0.0).to(accelerator.device) + player_1_accuracy = torch.tensor(0.0).to(accelerator.device) + player_2_accuracy = torch.tensor(0.0).to(accelerator.device) + total_accuracy = torch.tensor(0.0).to(accelerator.device) + player_1_total = torch.tensor(0.0).to(accelerator.device) + player_2_total = torch.tensor(0.0).to(accelerator.device) + + with torch.no_grad(): + + for X_batch, Y_batch in valid_loader: + + p1_indices, p2_indices = logit_selection(X_batch.size(1), mode, seq_type) + + logits = model(X_batch) # Shape: [batch_size, seq_length, vocab_size] + + logits = F.log_softmax(logits, dim=-1) + + p1_indices_, p2_indices_ = logit_selection(X_batch.size(1), mode, SeqSubSet.PLAYER_1) + + player_1_accuracy_ = ( + logits[:, p1_indices_].argmax(dim=-1) == Y_batch[:, p1_indices_] + ).float() + player_2_accuracy_ = ( + logits[:, p2_indices_].argmax(dim=-1) == Y_batch[:, p2_indices_] + ).float() + total_accuracy_ = (logits.argmax(dim=-1) == Y_batch).float() + + masked_loss, mask = loss_calc( + loss_type, criterion, logits, Y_batch, p1_indices + ) + + valid_loss += masked_loss.sum() 
# Sum the losses at valid positions + valid_data += mask.sum() # Count valid positions + player_1_accuracy += player_1_accuracy_.sum() + player_2_accuracy += player_2_accuracy_.sum() + total_accuracy += total_accuracy_.sum() + player_1_total += player_1_accuracy_.numel() + player_2_total += player_2_accuracy_.numel() + + accelerator.wait_for_everyone() + + valid_loss = accelerator.gather(valid_loss).sum() + valid_data = accelerator.gather(valid_data).sum() + player_1_accuracy = accelerator.gather(player_1_accuracy).sum() + player_2_accuracy = accelerator.gather(player_2_accuracy).sum() + total_accuracy = accelerator.gather(total_accuracy).sum() + player_1_total = accelerator.gather(player_1_total).sum() + player_2_total = accelerator.gather(player_2_total).sum() + + if accelerator.is_main_process: + return ( + (valid_loss / valid_data).item(), + (player_1_accuracy / player_1_total).item(), + (player_2_accuracy / player_2_total).item(), + (total_accuracy / (player_1_total + player_2_total)).item(), + ) + else: + return None + + +def train_model( + model, + train_loader, + optimizer, + accelerator, + scheduler, + wandb, + mode, + loss_type, + seq_type, +): + + model.train() + + criterion = criterion_f(loss_type) + + train_loss = torch.tensor(0.0).to(accelerator.device) + train_data = torch.tensor(0.0).to(accelerator.device) + + for X_batch, Y_batch in tqdm(train_loader, desc="Training"): + + optimizer.zero_grad() + + p1_indices, p2_indices = logit_selection(X_batch.size(1), mode, seq_type) + + logits = model(X_batch) # Shape: [batch_size, seq_length, vocab_size] + + masked_loss, mask = loss_calc(None, criterion, logits, Y_batch, p1_indices) + + loss_sum = masked_loss.sum() + data_sum = mask.sum() + + loss = loss_sum / data_sum + accelerator.backward(loss) + nn.utils.clip_grad_norm_(model.parameters(), 5.0) + optimizer.step() + if scheduler is not None: + scheduler.step() + wandb.log({"Learning Rate": scheduler.get_last_lr()[0]}) + + train_loss += loss_sum.item() + train_data += mask.sum().item() + + grad_norms = [] + + for param in model.parameters(): + if param.grad is not None: + grad_norms.append(param.grad.norm().item()) + + if len(grad_norms) == 0: + return None # No gradients yet (e.g., before first backward pass) + + grad_tensor = torch.tensor(grad_norms) + + stats = { + "grad_mean": grad_tensor.mean().item(), + "grad_max": grad_tensor.max().item(), + "grad_std": grad_tensor.std().item(), + "grad_p95": grad_tensor.quantile(0.95).item(), + } + + # accelerator.print('Gradient norm:', stats) + wandb.log(stats) + + accelerator.wait_for_everyone() + + train_loss = accelerator.gather(train_loss).sum() + train_data = accelerator.gather(train_data).sum() + + accelerator.print("Training Loss:", (train_loss / train_data).item()) + + if accelerator.is_main_process: + return (train_loss / train_data).item() + else: + return None diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..82eb16c --- /dev/null +++ b/train.sh @@ -0,0 +1 @@ +python -m transformers_learn_mdp.connect4_train_mcts diff --git a/transformer_training_mcts/connect4_train_mcts.py b/transformer_training_mcts/connect4_train_mcts.py deleted file mode 100644 index ef40eed..0000000 --- a/transformer_training_mcts/connect4_train_mcts.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import sys -import pickle -import shutil -import torch -import argparse -from tqdm import tqdm - -sys.path.append('../') - -from accelerate import Accelerator -from dataset import EpisodeDataset, collate_fn -from model import Config, GPTModel 
-from trainer import train_model, validate_model -from torch.utils.data import DataLoader - -""" -Training pipeline for transformer on Connect-4 data generated through MCTS. -""" - -def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, mode, seed, save_directory = None, epochs = 15): - - accelerator = Accelerator() - - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) - valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) - - config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size) - model = GPTModel(config) - - optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) - - train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(train_loader, valid_loader, model, scheduler, optimizer) - - epoch = 0 - - model_path = None - min_loss = 1e10 - - train_losses = [] - valid_losses = [] - - for epoch in tqdm(range(epochs)): - accelerator.print(f'Epoch {epoch}') - - train_loss = train_model(model, train_loader, optimizer, accelerator) - valid_loss = validate_model(model, valid_loader, accelerator) - train_losses.append(train_loss) - valid_losses.append(valid_loss) - scheduler.step() - - if accelerator.is_main_process: - print(f'Validation Loss: {valid_loss:.8f}') - - model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth" - accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path) - - if valid_loss < min_loss: - min_loss = valid_loss - model_path = model_save_path - - accelerator.wait_for_everyone() - - if accelerator.is_main_process: - shutil.copy(model_path, save_directory) - - with open(f'train_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: - pickle.dump(train_losses, f) - with open(f'valid_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: - pickle.dump(valid_losses, f) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-m', type=int, default=0, choices=[0, 1, 2], help='Data Mode (state, action, state-action)') - parser.add_argument('-s', type=int, default=0, choices=[0, 1, 2], help='Seed') - args = parser.parse_args() - if args.m == 0: - token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} - vocab_size = 43 - elif args.m == 1: - token_to_idx = {i: i + 1 for i in range(7)} - vocab_size = 8 - elif args.m == 2: - token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | {i: i + 44 for i in range(7)} - vocab_size = 51 - token_to_idx[''] = 0 # Padding token - block_size = 42 - embed_size = 512 - num_layers = 8 - - path = '' - - with open(os.path.join(path, rf'training_data/mcts/training_games_mode_{args.m}.pkl'), 'rb') as f: - agent1 = pickle.load(f) - - train_ratio = 0.8 - valid_ratio = 0.1 - - d1 = len(agent1) - - train = agent1[:int(train_ratio * d1)] - valid = agent1[int(train_ratio * d1):int((train_ratio + valid_ratio) * d1) ] - test = agent1[int((train_ratio + valid_ratio) * d1): ] - - print(len(train)) - print(len(valid)) - print(len(test)) - - train_dataset = EpisodeDataset(train, token_to_idx) - valid_dataset = EpisodeDataset(valid, token_to_idx) - - train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, args.m, args.s, "best_model") - -if __name__ == "__main__": - main() - diff --git a/transformer_training_mcts/dataset.py b/transformer_training_mcts/dataset.py deleted file 
mode 100644 index ff5ff4a..0000000 --- a/transformer_training_mcts/dataset.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from torch.utils.data import Dataset -from torch.nn.utils.rnn import pad_sequence - -class EpisodeDataset(Dataset): - - def __init__(self, data, token_to_idx): - self.data = data - self.token_to_idx = token_to_idx - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - - X_sequence, Y_sequence = self.data[idx] - - X_indices = [self.token_to_idx[token] for token in X_sequence] - Y_indices = [self.token_to_idx[token] for token in Y_sequence] - - return torch.tensor(X_indices, dtype=torch.long), torch.tensor(Y_indices, dtype=torch.long) - -def collate_fn(batch): - - Xs, Ys = zip(*batch) - - Xs_padded = pad_sequence(Xs, batch_first=True, padding_value=0) - Ys_padded = pad_sequence(Ys, batch_first=True, padding_value=0) - - return Xs_padded, Ys_padded \ No newline at end of file diff --git a/transformer_training_mcts/trainer.py b/transformer_training_mcts/trainer.py deleted file mode 100644 index 28144f5..0000000 --- a/transformer_training_mcts/trainer.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.nn as nn - -from tqdm import tqdm - -def validate_model(model, valid_loader, accelerator): - - model.eval() - criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') - - valid_loss = torch.tensor(0.0).to(accelerator.device) - valid_data = torch.tensor(0.0).to(accelerator.device) - - with torch.no_grad(): - - for X_batch, Y_batch in valid_loader: - - logits = model(X_batch) - logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] - - # Assuming the padding token index is 0 - padding_token_index = 0 - mask = (Y_batch != padding_token_index).float() # Create a mask for valid positions - - loss = criterion(logits, Y_batch) # Calculate loss without reduction - masked_loss = loss * mask # Apply mask - - valid_loss += masked_loss.sum().item() # Sum the losses at valid positions - valid_data += mask.sum().item() # Count valid positions - - accelerator.wait_for_everyone() - - valid_loss = accelerator.gather(valid_loss).sum() - valid_data = accelerator.gather(valid_data).sum() - - if accelerator.is_main_process: - return (valid_loss / valid_data).item() - else: - return None - -def train_model(model, train_loader, optimizer, accelerator): - - model.train() - criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') - - train_loss = torch.tensor(0.0).to(accelerator.device) - train_data = torch.tensor(0.0).to(accelerator.device) - - for X_batch, Y_batch in tqdm(train_loader, desc="Training"): - - optimizer.zero_grad() - logits = model(X_batch) - - logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] - - padding_token_index = 0 # Assuming the padding token index is 0 - mask = (Y_batch != padding_token_index).float() - - loss = criterion(logits, Y_batch) - masked_loss = loss * mask - - loss_sum = masked_loss.sum() - data_sum = mask.sum() - - loss = loss_sum / data_sum - accelerator.backward(loss) - optimizer.step() - - train_loss += loss_sum.item() - train_data += mask.sum().item() - - accelerator.wait_for_everyone() - - train_loss = accelerator.gather(train_loss).sum() - train_data = accelerator.gather(train_data).sum() - - accelerator.print('Training Loss:', (train_loss / train_data).item()) - - return train_loss - -
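A quick sanity check of the warmup schedule built by get_lr_scheduler in connect4_train_mcts.py: LambdaLR multiplies the optimizer's base lr by lr_lambda(epoch), so with the factors used there the first epoch runs at lr 0 and the rate keeps growing linearly afterwards. This is a minimal sketch, assuming the package and its requirements (torch, accelerate, hydra-core, wandb) are installed so the module imports cleanly.

import torch
from transformers_learn_mdp.connect4_train_mcts import get_lr_scheduler

# Dummy parameter so the optimizer has something to track
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1e-5)
sched = get_lr_scheduler(opt, warmup_epochs=5, total_epochs=15, base_lr=1e-5, max_lr=1e-3)

for epoch in range(15):
    print(epoch, sched.get_last_lr()[0])  # epoch 0 -> 0.0, epoch 1 -> 2e-5, ..., epoch 5 -> 5e-4
    opt.step()    # no gradients, so this is a no-op; it only keeps the scheduler ordering warning quiet
    sched.step()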
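Worked example of the mode-2 tokenisation used by mode_to_token_to_idx and EpisodeDataset: each column action is interleaved with the (row, col) cell it lands on, both are mapped to ids, and the sequence is wrapped in the 51 start/end marker. The three-move opening below is invented for illustration.

import itertools
from transformers_learn_mdp.data_utils import actions_to_col_row

# mode-2 vocabulary from mode_to_token_to_idx: board cells (row, col) -> 1..42,
# column actions 0..6 -> 44..50, "" -> 0 (padding), 51 -> 51 (start/end marker)
token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)}
token_to_idx.update({i: i + 44 for i in range(7)})
token_to_idx[""] = 0
token_to_idx[51] = 51

actions = [3, 3, 0]  # hypothetical opening: two drops in column 3, one in column 0
interleaved = list(itertools.chain(*zip(actions, actions_to_col_row(actions))))
print(interleaved)   # [3, (5, 3), 3, (4, 3), 0, (5, 0)]

# EpisodeDataset then maps tokens to ids and adds the 51 marker at both ends
tokens = [51] + [token_to_idx[t] for t in interleaved] + [51]
print(tokens)        # [51, 47, 39, 47, 32, 44, 36, 51]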
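A hedged illustration of the line format that information_parser expects in info.txt: each comma-separated line is a run of per-move records of the form <number of scored actions>, then that many <action, q-value> pairs, then the chosen action. The concrete numbers below are invented.

from transformers_learn_mdp.data_utils import information_parser

# Two moves: the first scored actions 3 and 4 and played 3; the second scored only 2 and played 2.
line = "2,3,0.61,4,0.39,3,1,2,1.0,2"
parsed = information_parser([line])
print(parsed)  # [[({3: 0.61, 4: 0.39}, 3), ({2: 1.0}, 2)]]

# connect4_train_mcts.main() keeps only the chosen actions before tokenising:
actions = [action for (_, action) in parsed[0]]
print(actions)  # [3, 2]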
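A small check of which positions logit_selection in trainer.py scores, under the conventions used there: mode 2 is the interleaved action/state sequence, SeqSubSet.PLAYER_1 restricts the loss to player 1's part of the sequence, and SeqSubSet.WHOLE keeps everything. mode is passed as a plain int here because the function compares it against 0 and 1.

from transformers_learn_mdp.trainer import logit_selection, SeqSubSet

p1, p2 = logit_selection(12, 2, SeqSubSet.PLAYER_1)
print(p1)  # [0, 1, 3, 4, 5, 7, 8, 9, 11]
print(p2)  # [2, 6, 10]

whole, _ = logit_selection(12, 2, SeqSubSet.WHOLE)
print(whole)  # [0, 1, 2, ..., 11] -- every position is scored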