Logger branch #35

Open · wants to merge 11 commits into main

2 changes: 1 addition & 1 deletion README.md
@@ -80,7 +80,7 @@ cd pipegoose/examples
torchrun --standalone --nnodes=1 --nproc-per-node=4 hybrid_parallelism.py
```

We did a small scale correctness test by comparing the validation losses between a paralleized transformer and one kept by default, starting at identical checkpoints and training data. We will conduct rigorous large scale convergence and weak scaling law benchmarks against Megatron and DeepSpeed in the near future if we manage to make it.
We did a small scale correctness test by comparing the validation losses between a parallelized transformer and one kept by default, starting at identical checkpoints and training data. We will conduct rigorous large scale convergence and weak scaling law benchmarks against Megatron and DeepSpeed in the near future if we manage to make it.
- Data Parallelism [[link]](https://wandb.ai/xariusdrake/pipegoose/runs/t5cr56xd?workspace)
- ~~Tensor Parallelism [[link]](https://wandb.ai/xariusdrake/pipegoose/runs/iz17f50n)~~ (We've found a bug in convergence, and we are fixing it)
- ~~Hybrid 2D Parallelism (TP+DP) [[link]](https://wandb.ai/xariusdrake/pipegoose/runs/us31p3q1)~~
1 change: 1 addition & 0 deletions logs/latency_logger.txt
@@ -0,0 +1 @@
[INFO] hello, [WARNING] hello, [DEBUG] hello, [ERROR] hello, [INFO] hello, [INFO] hello, [INFO] hello, [WARNING] hello, [DEBUG] hello, [ERROR] hello, [INFO] hello, [WARNING] hello, [DEBUG] hello, [ERROR] hello, [INFO] hello, [WARNING] hello, [DEBUG] hello, [ERROR] hello
2 changes: 1 addition & 1 deletion pipegoose/constants.py
@@ -21,7 +21,7 @@
# ==================================================


# NOTE: the minimum number of cocurrent worker threads that execute jobs
# NOTE: the minimum number of concurrent worker threads that execute jobs
# in the background of pipeline parallelism
PIPELINE_MIN_WORKERS = 1
PIPELINE_MAX_WORKERS = 1
1 change: 0 additions & 1 deletion pipegoose/core/bucket/dist.py
@@ -56,7 +56,6 @@ def execute(self, tensor: torch.Tensor, parallel_mode: ParallelMode):
        # then empty and refill the bucket with the tensor
        key = (tensor.dtype, parallel_mode)
        if key not in self.buckets:

            self.buckets[key] = Bucket(self.bucket_size, tensor.dtype, self.parallel_context)
        else:
            bucket = self.buckets[key]
2 changes: 0 additions & 2 deletions pipegoose/distributed/__init__.py
@@ -1,2 +0,0 @@
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
59 changes: 59 additions & 0 deletions pipegoose/distributed/logger/DistributedLogger.py
@@ -0,0 +1,59 @@
import os

from pipegoose.distributed import ParallelMode


class DistributedLogger:
    def __init__(self, name, parallel_context):
        self.name = name
        self.parallel_context = parallel_context
        # Initialize file handling and logging configurations

    def _should_log(self, rank, parallel_mode):
        current_rank = self.parallel_context.get_global_rank()
        rank_check = rank is None or rank == current_rank

        # Check if the current parallel mode is initialized and if the current process is part of it
        mode_check = self.parallel_context.is_initialized(parallel_mode)

        return rank_check and mode_check

    def _save_log(self, path, log):
        # Append the log entry to "<name>.txt" under the given directory
        log_name = self.name + ".txt"

        # Create the directory if it does not exist yet
        if not os.path.exists(path):
            os.makedirs(path)

        # Write the first entry as-is; prepend a separator for later entries
        if not os.path.isfile(os.path.join(path, log_name)):
            with open(os.path.join(path, log_name), "a") as f:
                f.write(log)
        else:
            with open(os.path.join(path, log_name), "a") as f:
                f.write(", " + log)

    def _log_message(self, message, level, rank=None, parallel_mode=ParallelMode.GLOBAL):
        if self._should_log(rank, parallel_mode):
            # Print the message and persist it to the log file
            log = f"[{level}] {message}"
            print(log)
            self._save_log("logs/", log)
        # else:
        #     print(f"Process {self.parallel_context.get_global_rank()} is not part of the {parallel_mode} parallel mode")

    # Public logging helpers, one per level

    def info(self, message, rank=None, parallel_mode=ParallelMode.GLOBAL):
        self._log_message(message, "INFO", rank, parallel_mode)

    def warning(self, message, rank=None, parallel_mode=ParallelMode.GLOBAL):
        self._log_message(message, "WARNING", rank, parallel_mode)

    def debug(self, message, rank=None, parallel_mode=ParallelMode.GLOBAL):
        self._log_message(message, "DEBUG", rank, parallel_mode)

    def error(self, message, rank=None, parallel_mode=ParallelMode.GLOBAL):
        self._log_message(message, "ERROR", rank, parallel_mode)
2 changes: 1 addition & 1 deletion pipegoose/nn/pipeline_parallel/pipeline_context.py
@@ -75,7 +75,7 @@ def clock_idx(self) -> int:
        return self._clock_idx

    def increase_a_clock_cycle(self):
        """Increase the current clock cycle in the pipline by 1."""
        """Increase the current clock cycle in the pipeline by 1."""
        # TODO: add assert maximum clock cycles
        with self._wait_new_clock_cycle:
            self._clock_idx += 1
1 change: 0 additions & 1 deletion pipegoose/optim/__init__.py
@@ -1 +0,0 @@
from pipegoose.optim.zero.optim import DistributedOptimizer
133 changes: 133 additions & 0 deletions tests/distributed/test_DistributedLogger.py
@@ -0,0 +1,133 @@
import os
from multiprocessing import Process
from unittest.mock import mock_open, patch

import pytest
import torch.distributed as dist

from pipegoose.distributed import ParallelMode
from pipegoose.distributed.logger import DistributedLogger
from pipegoose.testing.utils import find_free_port, init_parallel_context


@pytest.fixture
def logger(parallel_context):
    return DistributedLogger("test_logger", parallel_context)


@pytest.fixture(params=[1, 1])
def tensor_parallel_size(request):
    return request.param


def should_log_test(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, logger_name):
    context = init_parallel_context(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size)
    logger = DistributedLogger(logger_name, context)
    assert logger._should_log(rank, ParallelMode.GLOBAL) is True
    dist.destroy_process_group()


@pytest.mark.parametrize("tensor_parallel_size", [1, 1])
@pytest.mark.parametrize("data_parallel_size", [1, 1])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 1])
def test_should_log(tensor_parallel_size, data_parallel_size, pipeline_parallel_size):
    world_size = tensor_parallel_size * pipeline_parallel_size * data_parallel_size
    port = find_free_port()
    processes = []

    for rank in range(world_size):
        p = Process(
            target=should_log_test,
            args=(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, "test_logger"),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


def test_save_log(logger):
    # Test when path directory does not exist
    with patch.object(os.path, "exists", return_value=False):
        with patch.object(os, "makedirs") as mock_makedirs:
            with patch("builtins.open", mock_open()) as mock_file:
                logger._save_log("logs/", "test log")
                mock_makedirs.assert_called_once_with("logs/")
                mock_file.assert_called_once_with("logs/test_logger.txt", "a")
                mock_file().write.assert_called_once_with("test log")

    # Test when log file does not exist
    with patch.object(os.path, "exists", return_value=True):
        with patch("builtins.open", mock_open()) as mock_file:
            logger._save_log("logs/", "test log")
            mock_file.assert_called_once_with("logs/test_logger.txt", "a")
            mock_file().write.assert_called_once_with("test log")

    # Test when log file exists
    with patch.object(os.path, "exists", return_value=True):
        with patch("builtins.open", mock_open()) as mock_file:
            with patch.object(os.path, "isfile", return_value=True):
                logger._save_log("logs/", "test log")
                mock_file.assert_called_once_with("logs/test_logger.txt", "a")
                mock_file().write.assert_called_once_with(", test log")


def test_log_message(logger):
    # Test when should_log is True
    with patch.object(logger, "_should_log", return_value=True):
        with patch("builtins.print") as mock_print:
            with patch.object(logger, "_save_log") as mock_save_log:
                logger._log_message("test message", "INFO")
                mock_print.assert_called_once_with("[INFO] test message")
                mock_save_log.assert_called_once_with("logs/", "[INFO] test message")

    with patch.object(logger, "_should_log", return_value=False) as mock_should_log:
        with patch("builtins.print") as mock_print:
            with patch.object(logger, "_save_log") as mock_save_log:
                logger._log_message("test message", "INFO")
                mock_should_log.assert_called()  # Ensure _should_log is being called
                mock_print.assert_not_called()
                mock_save_log.assert_not_called()


def test_log_message_with_rank_zero_global(logger):
    # Test _log_message with rank=0 and parallel_mode=ParallelMode.GLOBAL
    with patch.object(logger, "_should_log", return_value=True):
        with patch("builtins.print") as mock_print:
            with patch.object(logger, "_save_log") as mock_save_log:
                logger._log_message("test message", "INFO", rank=0, parallel_mode=ParallelMode.GLOBAL)
                mock_print.assert_called_once_with("[INFO] test message")
                mock_save_log.assert_called_once_with("logs/", "[INFO] test message")

    # You can also test the behavior when _should_log returns False
    with patch.object(logger, "_should_log", return_value=False):
        with patch("builtins.print") as mock_print:
            with patch.object(logger, "_save_log") as mock_save_log:
                logger._log_message("test message", "INFO", rank=0, parallel_mode=ParallelMode.GLOBAL)
                mock_print.assert_not_called()
                mock_save_log.assert_not_called()


def test_info(logger):
    with patch.object(logger, "_log_message") as mock_log_message:
        logger.info("test message")
        mock_log_message.assert_called_once_with("test message", "INFO", None, ParallelMode.GLOBAL)


def test_warning(logger):
    with patch.object(logger, "_log_message") as mock_log_message:
        logger.warning("test message")
        mock_log_message.assert_called_once_with("test message", "WARNING", None, ParallelMode.GLOBAL)


def test_debug(logger):
    with patch.object(logger, "_log_message") as mock_log_message:
        logger.debug("test message")
        mock_log_message.assert_called_once_with("test message", "DEBUG", None, ParallelMode.GLOBAL)


def test_error(logger):
    with patch.object(logger, "_log_message") as mock_log_message:
        logger.error("test message")
        mock_log_message.assert_called_once_with("test message", "ERROR", None, ParallelMode.GLOBAL)