diff --git a/.gitignore b/.gitignore index 4160575..6fe8679 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__/ *.so # Distribution / packaging +wandb/ .Python build/ develop-eggs/ diff --git a/pipegoose/distributed/parallel_context.py b/pipegoose/distributed/parallel_context.py index 652d0f3..6992e53 100644 --- a/pipegoose/distributed/parallel_context.py +++ b/pipegoose/distributed/parallel_context.py @@ -125,8 +125,8 @@ def __init__( self.init_global_dist(rank, world_size, backend, host, port) self.init_parallel_groups() - # if torch.cuda.is_available(): - # self.set_device() + if torch.cuda.is_available() and backend == "nccl": + self.set_device() self.map_rank_to_device() @@ -261,17 +261,20 @@ def set_seed(self, seed: int): def map_rank_to_device(self): """Map global rank to device.""" + rank_tensor = torch.zeros(len(self._local_ranks), dtype=torch.long) + rank_tensor = rank_tensor.cuda() if torch.cuda.is_available() else rank_tensor for idx, local_rank in enumerate(self._local_ranks.values()): rank_tensor[idx] = local_rank rank_tensor_list = [ - torch.zeros(rank_tensor.size(), dtype=torch.long) for _ in range(self.get_world_size(ParallelMode.GLOBAL)) + torch.zeros(rank_tensor.size(), dtype=torch.long).cuda() if torch.cuda.is_available() else torch.zeros(rank_tensor.size(), dtype=torch.long) + for _ in range(self.get_world_size(ParallelMode.GLOBAL)) ] dist.all_gather(tensor_list=rank_tensor_list, tensor=rank_tensor) - + for _rank, _rank_tensor in enumerate(rank_tensor_list): modes_and_ranks = {mode: rank for mode, rank in zip(self._local_ranks.keys(), _rank_tensor.tolist())} self._ranks_to_device[tuple(modes_and_ranks.items())] = _rank diff --git a/pipegoose/nn/tensor_parallel/_functional.py b/pipegoose/nn/tensor_parallel/_functional.py index 93cdd5f..409486e 100644 --- a/pipegoose/nn/tensor_parallel/_functional.py +++ b/pipegoose/nn/tensor_parallel/_functional.py @@ -25,7 +25,10 @@ def backward(ctx: Any, grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]: all_reduce(grad, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR) - return (grad, None, None) + return ( + grad, + None, + ) class _Gather(Function): diff --git a/pipegoose/nn/tensor_parallel/linear.py b/pipegoose/nn/tensor_parallel/linear.py index 53a5d68..c38e3b9 100644 --- a/pipegoose/nn/tensor_parallel/linear.py +++ b/pipegoose/nn/tensor_parallel/linear.py @@ -32,17 +32,16 @@ def __init__( if bias is True: self.bias = nn.Parameter(torch.randn(out_per_partition)) - + else: + self.bias = None + def _get_output_per_partition(self, out_features: int, parallel_context: ParallelContext) -> int: local_world_size = parallel_context.get_world_size(ParallelMode.TENSOR) return out_features // local_world_size def forward(self, input: torch.Tensor) -> torch.Tensor: input_parallel = broadcast_to_tensor_group(input, self.parallel_context) - outputs = F.linear(input_parallel, self.weight) - - if self.bias is not None: - outputs = outputs + self.bias + outputs = F.linear(input_parallel, self.weight, self.bias) if self.gather_output: outputs = gather_to_tensor_group(outputs, dim=-1, parallel_context=self.parallel_context) diff --git a/pipegoose/nn/tensor_parallel/parallel_mapping.py b/pipegoose/nn/tensor_parallel/parallel_mapping.py index 875981c..8f1228f 100644 --- a/pipegoose/nn/tensor_parallel/parallel_mapping.py +++ b/pipegoose/nn/tensor_parallel/parallel_mapping.py @@ -34,6 +34,8 @@ class ParallelMapping: Row(("mlp.dense_4h_to_h", "self_attention.dense")), LMHead(("lm_head",)), ], + 
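+        # NOTE: maps the single-linear debug model used by tests/convergence to column parallelism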
"debug_single_mlp": [Column(("debug_single_mlp",))], + } @staticmethod diff --git a/pipegoose/nn/tensor_parallel/tensor_parallel.py b/pipegoose/nn/tensor_parallel/tensor_parallel.py index b0130bd..fb2cbca 100644 --- a/pipegoose/nn/tensor_parallel/tensor_parallel.py +++ b/pipegoose/nn/tensor_parallel/tensor_parallel.py @@ -6,10 +6,7 @@ from pipegoose.distributed.parallel_context import ParallelContext from pipegoose.nn.parallel import Parallel from pipegoose.nn.tensor_parallel.parallelizer import ( - EmbeddingParallelizer, - LayerNormParallelizer, LinearParallelizer, - LMHeadParallelizer, ModuleParallelizer, ) @@ -17,7 +14,8 @@ class TensorParallel(Parallel): """Turn a 🤗 transformers model into a tensor parallel model.""" - PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer] + # PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer] + PARALLELIZERS = [LinearParallelizer] def __init__(self, module: nn.Module, parallel_context: ParallelContext): self.module = module @@ -33,6 +31,11 @@ def parallelize(self) -> nn.Module: # multiple times. so we filter out and retain the non-repetitive modules (leaf modules) leaf_modules = self._get_leaf_modules(module) for module_name, leaf_module in leaf_modules: + # NOTE: just skip parallelizing query_key_value in attention + # for debugging purposes + if "query_key_value" in module_name: + continue + parallelizer = self._find_parallelizer(module_name, leaf_module) if parallelizer is not None: parallelizer(module_name, leaf_module, module, self.parallel_context).parallelize() diff --git a/pipegoose/utils/logger.py b/pipegoose/utils/logger.py new file mode 100644 index 0000000..32d45ef --- /dev/null +++ b/pipegoose/utils/logger.py @@ -0,0 +1,154 @@ +import datetime +import inspect +import sys +import os +import wandb +import glob +import re +import os + +class Logger: + # https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/lib/logger.py + """ The Logger class is a singleton. It contains all the utilities + for logging variables in a key-value dictionary. + It can also be considered as a replacement for the print function. + + .. 
code-block:: python + + Logger(dir_logs='logs/mnist') + Logger().flush() # write the logs.json + Logger()("Launching training procedures") # written to logs.txt + > [I 2018-07-23 18:58:31] ...trap/engines/engine.py.80: Launching training procedures + """ + + DEBUG = -1 + INFO = 0 + SUMMARY = 1 + WARNING = 2 + ERROR = 3 + SYSTEM = 4 + _instance = None + indicator = {DEBUG: 'D', INFO: 'I', SUMMARY: 'S', WARNING: 'W', ERROR: 'E', SYSTEM: 'S'} + + class Colors: + END = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + GREY = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + PURPLE = 35 + SKY = 36 + WHITE = 37 + BACKGROUND = 10 + LIGHT = 60 + + @staticmethod + def code(value): + return '\033[{}m'.format(value) + + colorcode = { + DEBUG: Colors.code(Colors.GREEN), + INFO: Colors.code(Colors.GREY + Colors.LIGHT), + SUMMARY: Colors.code(Colors.BLUE + Colors.LIGHT), + WARNING: Colors.code(Colors.YELLOW + Colors.LIGHT), + ERROR: Colors.code(Colors.RED + Colors.LIGHT), + SYSTEM: Colors.code(Colors.WHITE + Colors.LIGHT) + } + + compactjson = True + log_level = None # log level + dir_logs = None + path_json = None + path_txt = None + file_txt = None + name = None + max_lineno_width = 3 + + def __new__(cls, dir_logs=None, name='logs'): + if Logger._instance is None: + Logger._instance = object.__new__(Logger) + + if dir_logs: + Logger._instance.name = name + Logger._instance.dir_logs = dir_logs + Logger._instance.path_txt = os.path.join(dir_logs, '{}.txt'.format(name)) + Logger._instance.file_txt = open(os.path.join(dir_logs, '{}.txt'.format(name)), 'a+') + # NOTE: Support json or CSV ? + # Logger._instance.path_json = os.path.join(dir_logs, '{}.json'.format(name)) + # Logger._instance.reload_json() + else: + Logger._instance.log_message('No logs files will be created (dir_logs attribute is empty)', + log_level=Logger.WARNING) + + return Logger._instance + + def __call__(self, *args, **kwargs): + return self.log_message(*args, **kwargs, stack_displacement=2) + + def log_message(self, *message, log_level=INFO, break_line=True, print_header=True, stack_displacement=1, + raise_error=True, adaptive_width=True): + + if self.dir_logs and not self.file_txt: + raise Exception('Critical: Log file not defined. 
Do you have write permissions for {}?'.format(self.dir_logs)) + + caller_info = inspect.getframeinfo(inspect.stack()[stack_displacement][0]) + message = ' '.join([str(m) for m in list(message)]) + + if print_header: + message_header = '[{} {:%Y-%m-%d %H:%M:%S}]'.format(self.indicator[log_level], + datetime.datetime.now()) + filename = caller_info.filename + if adaptive_width: + # allows the lineno_width to grow when necessary + lineno_width = len(str(caller_info.lineno)) + self.max_lineno_width = max(lineno_width, self.max_lineno_width) + else: + # manually fix it to 3 numbers + lineno_width = 3 + + if len(filename) > 28 - self.max_lineno_width: + filename = '...{}'.format(filename[-22 - (self.max_lineno_width - lineno_width):]) + + message_locate = '{}.{}:'.format(filename, caller_info.lineno) + message_logger = '{} {} {}'.format(message_header, message_locate, message) + message_screen = '{}{}{}{} {} {}'.format(self.Colors.BOLD, + self.colorcode[log_level], + message_header, + self.Colors.END, + message_locate, + message) + else: + message_logger = message + message_screen = message + + if break_line: + print(message_screen) + if self.dir_logs: + self.file_txt.write('%s\n' % message_logger) + else: + print(message_screen, end='') + sys.stdout.flush() + if self.dir_logs: + self.file_txt.write(message_logger) + + if self.dir_logs: + self.file_txt.flush() + if log_level == self.ERROR and raise_error: + raise Exception(message) + + def update_log_file(self, path_src, path_dst): + """ + Append content of file at path_src to file at path_dst + """ + + with open(path_src, 'r') as f: + lines_src = f.readlines() + + with open(path_dst, 'r') as f: + lines_dst = f.readlines() + + with open(path_dst, 'w') as f: + f.writelines(lines_src + ["\n"] + lines_dst) \ No newline at end of file diff --git a/tests/convergence/debug_batch.pt b/tests/convergence/debug_batch.pt new file mode 100644 index 0000000..1f696bb Binary files /dev/null and b/tests/convergence/debug_batch.pt differ diff --git a/tests/convergence/debug_target.pt b/tests/convergence/debug_target.pt new file mode 100644 index 0000000..076bb06 Binary files /dev/null and b/tests/convergence/debug_target.pt differ diff --git a/tests/convergence/model.pt b/tests/convergence/model.pt new file mode 100644 index 0000000..24e9582 Binary files /dev/null and b/tests/convergence/model.pt differ diff --git a/tests/convergence/run_tp.py b/tests/convergence/run_tp.py new file mode 100644 index 0000000..b3c5bf5 --- /dev/null +++ b/tests/convergence/run_tp.py @@ -0,0 +1,210 @@ +from copy import deepcopy + +import torch +import torch.distributed as dist +from datasets import load_dataset +from torch.optim import SGD +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.distributed.parallel_mode import ParallelMode +from pipegoose.nn import TensorParallel +from pipegoose.utils.logger import Logger + +def get_model_params_size(model, fp_bytes=4): + params_size = 0 + for p in model.parameters(): + params_size += p.numel() + params_gb = params_size * fp_bytes / 2**30 + return params_gb + + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +if __name__ == "__main__": + import wandb + + DATA_PARALLEL_SIZE = 1 + TENSOR_PARALLEL_SIZE = 2 + PIPELINE_PARALLEL_SIZE = 1 + MODEL = "bigscience/bloom-560m" + DATASET = "imdb" + NUM_EPOCHS = 4 + LR 
= 1e-3 + SEED = 69 + BATCH_SIZE = 4 + CONTEXT_LENGTH = 1024 + + torch.cuda.empty_cache() + set_seed(SEED) + + Logger()(f"device_count: {torch.cuda.device_count()}") + Logger()(f"is available: {torch.cuda.is_available()}") + + parallel_context = ParallelContext.from_torch( + data_parallel_size=DATA_PARALLEL_SIZE, + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, + ) + rank = parallel_context.get_global_rank() + + Logger()(f"rank={rank}, initialized parallel_context") + + train_dataset = load_dataset("imdb", split="train[:130]") + train_dataset = train_dataset.map(lambda x: {"text": x["text"][:10]}) # for demonstration purposes + + dp_rank = parallel_context.get_local_rank(ParallelMode.DATA) + train_sampler = DistributedSampler(train_dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=SEED) + train_dataloader = DataLoader( + train_dataset, + batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, + shuffle=False, + sampler=train_sampler, + ) + + val_dataset = load_dataset("imdb", split="test[:130]") + val_dataset = val_dataset.map(lambda x: {"text": x["text"][:10]}) # for demonstration purposes + val_sampler = DistributedSampler(val_dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=SEED) + val_dataloader = DataLoader( + val_dataset, + batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, + shuffle=False, + sampler=val_sampler, + ) + + model = AutoModelForCausalLM.from_pretrained(MODEL) + ref_model = deepcopy(model) + tokenizer = AutoTokenizer.from_pretrained(MODEL) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + Logger()(f"rank={rank}, model size before parallelizing: {round(get_model_params_size(model), 3)} GB") + + dist.barrier() + + model = TensorParallel(model, parallel_context).parallelize() + # model = DataParallel(model, parallel_context).parallelize() + optim = SGD(model.parameters(), lr=LR) + # optim = DistributedOptimizer(optim, parallel_context) + model.to("cuda") + device = next(model.parameters()).device + + Logger()(f"rank={rank}, model size after parallelizing: {round(get_model_params_size(model), 3)} GB") + Logger()(f"rank={rank}, model is moved to device: {device}") + + ref_model.to(device) + # if DATA_PARALLEL_SIZE > 1: + # ref_model = torch.nn.parallel.DistributedDataParallel(ref_model) + + ref_optim = SGD(ref_model.parameters(), lr=LR) + + model.train() + ref_model.train() + step = 0 + dist.barrier() + + if rank == 0: + + def get_time_name(): + import datetime + + today = datetime.datetime.now() + return today.strftime("%d/%m/%Y_%H:%M:%S") + + wandb.init( + project="pipegoose", + name=f"{get_time_name()}.test_dp_tp_zero1_converegence", + config={ + "data_parallel_size": DATA_PARALLEL_SIZE, + "tensor_parallel_size": TENSOR_PARALLEL_SIZE, + "pipeline_parallel_size": PIPELINE_PARALLEL_SIZE, + "model": MODEL, + "dataset": DATASET, + "epochs": NUM_EPOCHS, + "learning_rate": LR, + "seed": SEED, + "batch_size": BATCH_SIZE, + }, + ) + + for epoch in range(NUM_EPOCHS): + train_sampler.set_epoch(epoch) + Logger()(f"rank={rank}, epoch={epoch}") + + for batch in train_dataloader: + inputs = tokenizer( + batch["text"], + padding=True, + truncation=True, + max_length=CONTEXT_LENGTH, + return_tensors="pt", + ) + inputs = {name: tensor.to(device) for name, tensor in inputs.items()} + labels = inputs["input_ids"] + + outputs = model(**inputs, labels=labels) + ref_outputs = ref_model(**inputs, labels=labels) + + optim.zero_grad() + outputs.loss.backward() + optim.step() + + 
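+            # NOTE: step the unparallelized reference model on the same batch so that
+            # train_loss and ref_train_loss stay directly comparable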
ref_optim.zero_grad() + ref_outputs.loss.backward() + ref_optim.step() + + Logger()(f"epoch={epoch}, step={step}, rank={rank}, train_loss={outputs.loss}, ref_train_loss={ref_outputs.loss}") + + if rank == 0: + wandb.log( + { + "train_loss": outputs.loss, + "ref_train_loss": ref_outputs.loss, + "step": step, + "epoch": epoch, + } + ) + + step += 1 + + model.eval() + ref_model.eval() + dist.barrier() + + step = 0 + val_sampler.set_epoch(1) + + for batch in val_dataloader: + inputs = tokenizer( + batch["text"], + padding=True, + truncation=True, + max_length=CONTEXT_LENGTH, + return_tensors="pt", + ) + inputs = {name: tensor.to(device) for name, tensor in inputs.items()} + labels = inputs["input_ids"] + + outputs = model(**inputs, labels=labels) + ref_outputs = ref_model(**inputs, labels=labels) + + Logger()(f"rank={rank}, val_loss={outputs.loss}, ref_val_loss={ref_outputs.loss}, step={step}") + + if rank == 0: + wandb.log( + { + "val_loss": outputs.loss, + "ref_val_loss": ref_outputs.loss, + "step": step, + } + ) + + step += 1 + + wandb.finish() + model.cpu() diff --git a/tests/convergence/run_tp_mnist.py b/tests/convergence/run_tp_mnist.py new file mode 100644 index 0000000..1918526 --- /dev/null +++ b/tests/convergence/run_tp_mnist.py @@ -0,0 +1,232 @@ +from copy import deepcopy + +import torch +import torch.distributed as dist +from torch.optim import SGD +from torch.utils.data import DataLoader, random_split +import torch.nn as nn +from torchvision import datasets, transforms + +from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.distributed.parallel_mode import ParallelMode +from pipegoose.nn import TensorParallel +from pipegoose.utils.logger import Logger + +class NN(nn.Module): + def __init__(self, input_size, output_size): + super(NN, self).__init__() + self.debug_single_mlp = nn.Linear(input_size, output_size) + + def forward(self, x): + x = torch.flatten(x, 1) + x = self.debug_single_mlp(x) + return x + +class MNISTloader: + def __init__( + self, + batch_size: int = 64, + data_dir: str = "./data/", + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = False, + train_val_split: float = 0.1, + ): + self.batch_size = batch_size + self.data_dir = data_dir + self.num_workers = num_workers + self.pin_memory = pin_memory + self.shuffle = shuffle + self.train_val_split = train_val_split + + self.setup() + + def setup(self): + transform = transforms.Compose( + [ + transforms.Resize((32, 32)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5], std=[0.5]), + ] + ) + + self.train_dataset = datasets.MNIST( + self.data_dir, train=True, download=True, transform=transform + ) + val_split = int(len(self.train_dataset) * self.train_val_split) + train_split = len(self.train_dataset) - val_split + + self.train_dataset, self.val_dataset = random_split( + self.train_dataset, [train_split, val_split] + ) + self.test_dataset = datasets.MNIST( + self.data_dir, train=False, download=True, transform=transform + ) + + print( + "Image Shape: {}".format(self.train_dataset[0][0].numpy().shape), + end="\n\n", + ) + print("Training Set: {} samples".format(len(self.train_dataset))) + print("Validation Set: {} samples".format(len(self.val_dataset))) + print("Test Set: {} samples".format(len(self.test_dataset))) + + def load(self): + train_loader = DataLoader( + dataset=self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + val_loader = DataLoader( + 
dataset=self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + test_loader = DataLoader( + dataset=self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + return train_loader, val_loader, test_loader + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +if __name__ == "__main__": + import wandb + + DATA_PARALLEL_SIZE = 1 + TENSOR_PARALLEL_SIZE = 2 + PIPELINE_PARALLEL_SIZE = 1 + NUM_EPOCHS = 30 + LR = 2e-1 + SEED = 42 + BATCH_SIZE = 1024 + + torch.cuda.empty_cache() + set_seed(SEED) + + Logger()(f"device_count: {torch.cuda.device_count()}") + Logger()(f"is available: {torch.cuda.is_available()}") + + parallel_context = ParallelContext.from_torch( + data_parallel_size=DATA_PARALLEL_SIZE, + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, + ) + rank = parallel_context.get_global_rank() + + Logger()(f"rank={rank}, initialized parallel_context") + + dp_rank = parallel_context.get_local_rank(ParallelMode.DATA) + # train_dataloader, _, _ = MNISTloader(batch_size=BATCH_SIZE).load() + # for batch_idx, (debug_batch, debug_target) in enumerate(train_dataloader): + # break + # Dump batch of data to reload later + # torch.save(debug_batch, "debug_batch.pt") + # torch.save(debug_target, "debug_target.pt") + + # Load batch of data + debug_batch = torch.load("debug_batch.pt") + debug_target = torch.load("debug_target.pt") + + model = NN(input_size=32 * 32, output_size=10) + model.load_state_dict(torch.load("model.pt")) + ref_model = deepcopy(model) + + dist.barrier() + + model = TensorParallel(model, parallel_context).parallelize() + optim = SGD(model.parameters(), lr=LR) + criterion = nn.CrossEntropyLoss() + + model.to("cuda") + device = next(model.parameters()).device + + Logger()(f"rank={rank}, model is moved to device: {device}") + + ref_model.to(device) + ref_optim = SGD(ref_model.parameters(), lr=LR) + ref_criterion = nn.CrossEntropyLoss() + + + model.train() + ref_model.train() + step = 0 + dist.barrier() + + if rank == 0: + + def get_time_name(): + import datetime + + today = datetime.datetime.now() + return today.strftime("%d/%m/%Y_%H:%M:%S") + + wandb.init( + project="pipegoose", + name=f"{get_time_name()}.test_tp_mnist_converegence", + config={ + "data_parallel_size": DATA_PARALLEL_SIZE, + "tensor_parallel_size": TENSOR_PARALLEL_SIZE, + "pipeline_parallel_size": PIPELINE_PARALLEL_SIZE, + "model": "NN", + "dataset": "MNIST", + "epochs": NUM_EPOCHS, + "learning_rate": LR, + "seed": SEED, + "batch_size": BATCH_SIZE, + }, + ) + + # wandb log image + # wandb.log({"examples": [wandb.Image(img.numpy()) for img in debug_batch]}) + + for epoch in range(NUM_EPOCHS): + Logger()(f"rank={rank}, epoch={epoch}") + + train_loss_running, train_acc_running = 0, 0 + + inputs, labels = debug_batch.to(device), debug_target.to(device) + + outputs = model(inputs) + _, predictions = torch.max(outputs, dim=1) + loss = criterion(outputs, labels) + + ref_outputs = ref_model(inputs) + _, ref_predictions = torch.max(ref_outputs, dim=1) + ref_loss = ref_criterion(ref_outputs, labels) + + optim.zero_grad() + loss.backward() + optim.step() + + ref_optim.zero_grad() + ref_loss.backward() + ref_optim.step() + + Logger()(f"epoch={epoch}, rank={rank}, train_loss={loss}, ref_train_loss={ref_loss}") + + if rank == 0: + wandb.log( + { + "train_loss": loss, + "ref_train_loss": ref_loss, + 
"epoch": epoch, + } + ) + + + dist.barrier() + wandb.finish() + model.cpu() diff --git a/tests/convergence/run_tp_small.py b/tests/convergence/run_tp_small.py new file mode 100644 index 0000000..360e76f --- /dev/null +++ b/tests/convergence/run_tp_small.py @@ -0,0 +1,117 @@ +from copy import deepcopy + +import torch +import torch.distributed as dist +from torch.optim import SGD +from torch.utils.data import DataLoader, random_split +import torch.nn as nn +from torchvision import datasets, transforms + +from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.distributed.parallel_mode import ParallelMode +from pipegoose.nn import TensorParallel +from pipegoose.utils.logger import Logger +import torch.nn.functional as F +import numpy as np +import random + +class NN(nn.Module): + def __init__(self, input_size, output_size): + super(NN, self).__init__() + self.debug_single_mlp = nn.Linear(input_size, output_size) + + def forward(self, x): + x = self.debug_single_mlp(x) + return x + +def set_random_seed(seed: int): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + +if __name__ == "__main__": + DATA_PARALLEL_SIZE = 1 + TENSOR_PARALLEL_SIZE = 2 + PIPELINE_PARALLEL_SIZE = 1 + NUM_EPOCHS = 30 + LR = 2e-1 + SEED = 42 + + torch.cuda.empty_cache() + + Logger()(f"device_count: {torch.cuda.device_count()}") + Logger()(f"is available: {torch.cuda.is_available()}") + + parallel_context = ParallelContext.from_torch( + data_parallel_size=DATA_PARALLEL_SIZE, + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, + seed=SEED, + backend="nccl" + ) + + rank = parallel_context.get_local_rank(ParallelMode.TENSOR) + set_random_seed(SEED + rank) + + Logger()(f"rank={rank}, initialized parallel_context") + + BATCH_SIZE = 1 + IN_FEATURES = 4 + OUT_FEATURES = 6 + + X = torch.randn(BATCH_SIZE, IN_FEATURES, device="cuda", requires_grad=True) + L_weight = torch.randn(BATCH_SIZE, OUT_FEATURES, device="cuda") + + # Rank 0 brodcast X and W to other rank + dist.broadcast(X, src=0) + dist.broadcast(L_weight, src=0) + + Logger()(f"[rank {rank}]: {X}") + Logger()(f"[rank {rank}]: {L_weight}") + + model = NN(input_size=IN_FEATURES, output_size=OUT_FEATURES) + model_ref = deepcopy(model) + + dist.barrier() + + model = TensorParallel(model, parallel_context).parallelize() + model.to("cuda") + device = next(model.parameters()).device + model_ref.to(device) + Logger()(f"[rank {rank}]: model is moved to device: {device}") + + # Reference + Y_ref = model_ref(X) + L_ref = torch.mul(Y_ref, L_weight).sum() + # Manually compute the gradient + dLdW_ref = torch.matmul(L_weight.t(), X) + dLdX_ref = torch.matmul(L_weight, model_ref.debug_single_mlp.weight) + + dist.barrier() + + # Distributed + Logger()("===========FORWARD===========") + Y = model(X) + L = torch.mul(Y, L_weight).sum() + Y.retain_grad() + + Logger()("===========BACKWARD===========") + L.backward() + + #HACK: we need to divide by world size because we are calling L.backward() on rank 0 and 1 + # Too lazy to find a way to merge into a single matrix + dLdX = X.grad / dist.get_world_size() + + if rank == 0: + #NOTE: tests inspired from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/tests/test_layers.py#L173 + Logger()(f"error Y_ref - Y: {Y_ref.sub(Y).abs().max()}") + Logger()(f"error L_ref - L: {L_ref.sub(L).abs().max()}") + Logger()(f"error dLdX_ref - dLdX: {dLdX_ref.sub(dLdX).abs().max()}") + + dist.barrier() + + 
dLdW_ref = torch.split(dLdW_ref, OUT_FEATURES // dist.get_world_size(), dim=0)[rank].contiguous() + dLdW = model.debug_single_mlp.weight.grad + Logger()(f"error dLdW_ref - dLdW (rank {rank}): {dLdW_ref.sub(dLdW).abs().max()}") \ No newline at end of file diff --git a/tests/convergence/sandbox.ipynb b/tests/convergence/sandbox.ipynb new file mode 100644 index 0000000..5f5282e --- /dev/null +++ b/tests/convergence/sandbox.ipynb @@ -0,0 +1,250 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " NN(\n", + " (out): Linear(in_features=10, out_features=5, bias=True)\n", + ")\n", + "out Linear(in_features=10, out_features=5, bias=True)\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class NN(nn.Module):\n", + " def __init__(self, input_size, output_size):\n", + " super(NN, self).__init__()\n", + " self.out = nn.Linear(input_size, output_size)\n", + "\n", + " def forward(self, x):\n", + " x = torch.flatten(x, 1)\n", + " x = self.out(x)\n", + " return x\n", + "\n", + "# Example of using named_children\n", + "model = NN(input_size=10, output_size=5)\n", + "for name, module in model.named_modules():\n", + " print(name, module)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image Shape: (1, 32, 32)\n", + "\n", + "Training Set: 54000 samples\n", + "Validation Set: 6000 samples\n", + "Test Set: 10000 samples\n", + "torch.Size([1, 1, 32, 32])\n" + ] + } + ], + "source": [ + "from torchvision import datasets, transforms\n", + "from torch.utils.data import DataLoader, random_split\n", + "\n", + "\n", + "class MNISTloader:\n", + " def __init__(\n", + " self,\n", + " batch_size: int = 64,\n", + " data_dir: str = \"./data/\",\n", + " num_workers: int = 0,\n", + " pin_memory: bool = False,\n", + " shuffle: bool = False,\n", + " train_val_split: float = 0.1,\n", + " ):\n", + " self.batch_size = batch_size\n", + " self.data_dir = data_dir\n", + " self.num_workers = num_workers\n", + " self.pin_memory = pin_memory\n", + " self.shuffle = shuffle\n", + " self.train_val_split = train_val_split\n", + "\n", + " self.setup()\n", + "\n", + " def setup(self):\n", + " transform = transforms.Compose(\n", + " [\n", + " transforms.Resize((32, 32)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.5], std=[0.5]),\n", + " ]\n", + " )\n", + "\n", + " self.train_dataset = datasets.MNIST(\n", + " self.data_dir, train=True, download=True, transform=transform\n", + " )\n", + " val_split = int(len(self.train_dataset) * self.train_val_split)\n", + " train_split = len(self.train_dataset) - val_split\n", + "\n", + " self.train_dataset, self.val_dataset = random_split(\n", + " self.train_dataset, [train_split, val_split]\n", + " )\n", + " self.test_dataset = datasets.MNIST(\n", + " self.data_dir, train=False, download=True, transform=transform\n", + " )\n", + "\n", + " print(\n", + " \"Image Shape: {}\".format(self.train_dataset[0][0].numpy().shape),\n", + " end=\"\\n\\n\",\n", + " )\n", + " print(\"Training Set: {} samples\".format(len(self.train_dataset)))\n", + " print(\"Validation Set: {} samples\".format(len(self.val_dataset)))\n", + " print(\"Test Set: {} samples\".format(len(self.test_dataset)))\n", + "\n", + " def load(self):\n", + " train_loader = DataLoader(\n", + " dataset=self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " 
num_workers=self.num_workers,\n", + " pin_memory=self.pin_memory,\n", + " shuffle=self.shuffle,\n", + " )\n", + "\n", + " val_loader = DataLoader(\n", + " dataset=self.val_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " pin_memory=self.pin_memory,\n", + " shuffle=self.shuffle,\n", + " )\n", + "\n", + " test_loader = DataLoader(\n", + " dataset=self.test_dataset,\n", + " batch_size=self.batch_size,\n", + " num_workers=self.num_workers,\n", + " pin_memory=self.pin_memory,\n", + " shuffle=self.shuffle,\n", + " )\n", + "\n", + " return train_loader, val_loader, test_loader\n", + "\n", + "\n", + "# Load only 1 image\n", + "train_loader, val_loader, test_loader = MNISTloader(batch_size=1).load()\n", + "for batch_idx, (data, target) in enumerate(train_loader):\n", + " print(data.shape)\n", + " break\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 1, 32, 32])\n" + ] + } + ], + "source": [ + "print(data.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgLElEQVR4nO3de3BU9f3/8VcCyQKSbAyQbFKSEC5CFYJThJhBESUNRMeCZFq8zBRaRgYanEJq1XRUqu1M/NoZbx3EaW1BZ0QsrUC1I1bRhGrDLRARLylk0oLNBaRmEwLZxOTz+6Pj/rpy20+y4cOG52PmzLB73nnnfTxjXjm7J5+NMcYYAQBwgcW6HgAAcGkigAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4MdD1AF/X3d2t+vp6JSQkKCYmxvU4AABLxhi1trYqPT1dsbFnv8656AKovr5eGRkZrscAAPTSkSNHNHLkyLPu77OX4FavXq1Ro0Zp0KBBys3N1a5du8L6uoSEhL4aCQBwAZ3v53mfBNArr7yikpISrVq1Snv37tXkyZM1e/ZsHT169Lxfy8tuANA/nPfnuekD06ZNM8XFxcHHXV1dJj093ZSVlZ33a/1+v5HExsbGxhblm9/vP+fP+4hfAXV0dKiqqkr5+fnB52JjY5Wfn6/KysrT6gOBgFpaWkI2AED/F/EA+vzzz9XV1aXU1NSQ51NTU9XY2HhafVlZmbxeb3DjBgQAuDQ4/zug0tJS+f3+4HbkyBHXIwEALoCI34Y9fPhwDRgwQE1NTSHPNzU1yefznVbv8Xjk8XgiPQYA4CIX8Sug+Ph4TZkyRdu2bQs+193drW3btikvLy/S3w4AEKX65A9RS0pKtHDhQl1zzTWaNm2annrqKbW1tekHP/hBX3w7AEAU6pMAWrBggY4dO6aHH35YjY2Nuvrqq7V169bTbkwAAFy6YowxxvUQ/6ulpUVer9f1GACAXvL7/UpMTDzrfud3wQEALk0EEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4MdD1AAD6r5iYmLBrY2Ptfh+Oj4+3HSdsGRkZVvU2x3n06FGr3s3NzWHXGmOservGFRAAwImIB9DPf/5zxcTEhGwTJkyI9LcBAES5PnkJ7qqrrtLbb7/9/7/JQF7pAwCE6pNkGDhwoHw+X1+0BgD0E33yHtDBgweVnp6u0aNH66677tLhw4fPWhsIBNTS0hKyAQD6v4gHUG5urtatW6etW7dqzZo1qqur0/XXX6/W1tYz1peVlcnr9QY327tPAADRKcb08X17zc3NysrK0hNPPKHFixeftj8QCCgQCAQft7S0EEJAP8Ft2Ke7lG7D9vv9SkxMPOv+Pr87ICkpSVdccYUOHTp0xv0ej0cej6evxwAAXGT6/O+ATpw4odraWqWlpfX1twIARJGIB9C9996riooK/fOf/9Tf//533XbbbRowYIDuuOOOSH8rAEAUi/hLcJ999pnuuOMOHT9+XCNGjNB1112nHTt2aMSIEZH+VgAiIC4urk9qJWnQoEFh16akpFj1zsrKsqq3ccstt1jVd3d3h127efNmq97vv/9+2LWdnZ1WvV2LeABt2LAh0i0BAP0Qa8EBAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAATvT5xzEA6L2BA8P/X9X2c3IyMzPDrrVd1X7UqFFh186dO9eq93e+8x2rehs2n+8j/XcNzHB9+umnVr137doVdm20rQXHFRAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBEvxAFFg+vTpYdfecccdVr1vuOGGsGt9Pp9Vb5tlgWyWG7rYnDx5Muza5uZmq96nTp2ynCZ6cAUEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBB
[base64 PNG image data elided]
Mf89d3FxcWbjxo3Bmk8++cRIMpWVla7G7LWvH6cxxtxwww3mxz/+sbuh+sjll19unn/++Qt6Li/6K6COjg5VVVUpPz8/+FxsbKzy8/NVWVnpcLLIO3jwoNLT0zV69GjdddddOnz4sOuR+kxdXZ0aGxtDzqvX61Vubm6/O6+SVF5erpSUFI0fP17Lli3T8ePHXY/UK36/X5KUnJwsSaqqqlJnZ2fI+ZwwYYIyMzOj+nx+/Ti/8tJLL2n48OGaOHGiSktLdfLkSRfjRURXV5c2bNigtrY25eXlXdBzedEtRvp1n3/+ubq6upSamhryfGpqqj799FNHU0Vebm6u1q1bp/Hjx6uhoUGPPPKIrr/+eh04cEAJCQmux4u4xsZGSTrjef1qX38xZ84czZ8/X9nZ2aqtrdXPfvYzFRYWqrKyUgMGDHA9nrXu7m6tWLFC06dP18SJEyX993zGx8crKSkppDaaz+eZjlOS7rzzTmVlZSk9PV379+/X/fffr5qaGr366qsOp7X34YcfKi8vT+3t7Ro6dKg2bdqkK6+8UtXV1RfsXF70AXSpKCwsDP47JydHubm5ysrK0h/+8ActXrzY4WTordtvvz3470mTJiknJ0djxoxReXm5Zs2a5XCynikuLtaBAwei/j3K8znbcS5ZsiT470mTJiktLU2zZs1SbW2txowZc6HH7LHx48erurpafr9ff/zjH7Vw4UJVVFRc0Bku+pfghg8frgEDBpx2B0ZTU5N8Pp+jqfpeUlKSrrjiCh06dMj1KH3iq3N3qZ1XSRo9erSGDx8eled2+fLlev311/Xuu++GfGyKz+dTR0eHmpubQ+qj9Xye7TjPJDc3V5Ki7nzGx8dr7NixmjJlisrKyjR58mQ9/fTTF/RcXvQBFB8frylTpmjbtm3B57q7u7Vt2zbl5eU5nKxvnThxQrW1tUpLS3M9Sp/Izs6Wz+cLOa8tLS3auXNnvz6vkvTZZ5/p+PHjUXVujTFavny5Nm3apHfeeUfZ2dkh+6dMmaK4uLiQ81lTU6PDhw9H1fk833GeSXV1tSRF1fk8k+7ubgUCgQt7LiN6S0Mf2bBhg/F4PGbdunXm448/NkuWLDFJSUmmsbHR9WgR85Of/MSUl5eburo68/7775v8/HwzfPhwc/ToUdej9Vhra6vZt2+f2bdvn5FknnjiCbNv3z7zr3/9yxhjzGOPPWaSkpLMli1bzP79+83cuXNNdna2OXXqlOPJ7ZzrOFtbW829995rKisrTV1dnXn77bfNt771LTNu3DjT3t7uevSwLVu2zHi9XlNeXm4aGhqC28mTJ4M1S5cuNZmZmeadd94xe/bsMXl5eSYvL8/h1PbOd5yHDh0yjz76qNmzZ4+pq6szW7ZsMaNHjzYzZsxwPLmdBx54wFRUVJi6ujqzf/9+88ADD5iYmBjz17/+1Rhz4c5lVASQMcb8+te/NpmZmSY+Pt5MmzbN7Nixw/VIEbVgwQKTlpZm4uPjzTe+8Q2zYMECc+jQIddj9cq7775rJJ22LVy40Bjz31uxH3roIZOammo8Ho+ZNWuWqampcTt0D5zrOE+ePGkKCgrMiBEjTFxcnMnKyjJ333131P3ydKbjk2TWrl0brDl16pT50Y9+ZC6//HIzZMgQc9ttt5mGhgZ3Q/fA+Y7z8OHDZsaMGSY5Odl4PB4zduxY89Of/tT4/X63g1v64Q9/aLKyskx8fLwZMWKEmTVrVjB8jLlw55KPYwAAOHHRvwcEAOifCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAODE/wOKg1B1pGd9kgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# Plot the first image\n", + "plt.imshow(data[0][0], cmap=\"gray\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "n = 1024\n", + "m = 1024\n", + "\n", + "X = torch.randn(m,m)\n", + "W = nn.Parameter(torch.randn(n, m))\n", + "b = nn.Parameter(torch.randn(n))\n", + "\n", + "\n", + "ref = F.linear(X, W, b)\n", + "out = F.linear(X, W) + b\n", + "\n", + "torch.testing.assert_close(ref, out, msg=lambda msg: f\"{name}:\\n{msg}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env-pipegoose", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.2" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/convergence/train_mlp.py b/tests/convergence/train_mlp.py new file mode 100644 index 0000000..0d9078a --- /dev/null +++ b/tests/convergence/train_mlp.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import datasets, transforms +from torch.utils.data import DataLoader, random_split + +from pipegoose.utils.logger import Logger + +def seed_everything(seed: int): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +class NN(nn.Module): + def __init__(self, input_size, output_size): + super(NN, self).__init__() + self.debug_single_mlp = nn.Linear(input_size, output_size) + + def forward(self, x): + x = torch.flatten(x, 1) + x = self.debug_single_mlp(x) + return x + +class MNISTloader: + def __init__( + self, + batch_size: int = 64, + data_dir: str = "./data/", + num_workers: int = 0, + pin_memory: bool = False, + shuffle: bool = False, + train_val_split: float = 0.1, + ): + self.batch_size = batch_size + self.data_dir = data_dir + self.num_workers = num_workers + self.pin_memory = pin_memory + self.shuffle = shuffle + self.train_val_split = train_val_split + + self.setup() + + def setup(self): + transform = transforms.Compose( + [ + transforms.Resize((32, 32)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5], std=[0.5]), + ] + ) + + self.train_dataset = datasets.MNIST( + self.data_dir, train=True, download=True, transform=transform + ) + val_split = int(len(self.train_dataset) * self.train_val_split) + train_split = len(self.train_dataset) - val_split + + self.train_dataset, self.val_dataset = random_split( + self.train_dataset, [train_split, val_split] + ) + self.test_dataset = datasets.MNIST( + self.data_dir, train=False, download=True, transform=transform + ) + + print( + "Image Shape: {}".format(self.train_dataset[0][0].numpy().shape), + end="\n\n", + ) + print("Training Set: {} samples".format(len(self.train_dataset))) + print("Validation Set: {} samples".format(len(self.val_dataset))) + print("Test Set: {} samples".format(len(self.test_dataset))) + + def load(self): + train_loader = DataLoader( + dataset=self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + 
shuffle=self.shuffle, + ) + + val_loader = DataLoader( + dataset=self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + test_loader = DataLoader( + dataset=self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + ) + + return train_loader, val_loader, test_loader + +if __name__ == "__main__": + seed_everything(42) + LR = 0.001 + EPOCHS = 30 + + model = NN(input_size=32 * 32, output_size=10) + device = torch.device("cuda") + optimizer = optim.SGD(model.parameters(), LR) + criterion = nn.CrossEntropyLoss() + train_loader, _, _ = MNISTloader(train_val_split=0.).load() + + model = model.to(device) + + for epoch in range(EPOCHS): + + train_loss_running, train_acc_running = 0, 0 + + for inputs, labels in train_loader: + + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + + outputs = model(inputs) + + _, predictions = torch.max(outputs, dim=1) + loss = criterion(outputs, labels) + + loss.backward() + optimizer.step() + + train_loss_running += loss.item() * inputs.shape[0] + train_acc_running += torch.sum(predictions == labels.data) + + train_loss = train_loss_running / len(train_loader.sampler) + train_acc = train_acc_running / len(train_loader.sampler) + + info = "Epoch: {:3}/{} \t train_loss: {:.3f} \t train_acc: {:.3f}" + Logger()(info.format(epoch + 1, EPOCHS, train_loss, train_acc)) + torch.save(model.state_dict(), "model.pt") \ No newline at end of file