Bug/fix hybrid tp dp #38

Draft · wants to merge 8 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ __pycache__/
*.so

# Distribution / packaging
wandb/
.Python
build/
develop-eggs/
11 changes: 7 additions & 4 deletions pipegoose/distributed/parallel_context.py
@@ -125,8 +125,8 @@ def __init__(
self.init_global_dist(rank, world_size, backend, host, port)
self.init_parallel_groups()

# if torch.cuda.is_available():
# self.set_device()
if torch.cuda.is_available() and backend == "nccl":
self.set_device()

self.map_rank_to_device()

@@ -261,17 +261,20 @@ def set_seed(self, seed: int):

def map_rank_to_device(self):
"""Map global rank to device."""

rank_tensor = torch.zeros(len(self._local_ranks), dtype=torch.long)
rank_tensor = rank_tensor.cuda() if torch.cuda.is_available() else rank_tensor

for idx, local_rank in enumerate(self._local_ranks.values()):
rank_tensor[idx] = local_rank

rank_tensor_list = [
torch.zeros(rank_tensor.size(), dtype=torch.long) for _ in range(self.get_world_size(ParallelMode.GLOBAL))
torch.zeros(rank_tensor.size(), dtype=torch.long).cuda() if torch.cuda.is_available() else torch.zeros(rank_tensor.size(), dtype=torch.long)
for _ in range(self.get_world_size(ParallelMode.GLOBAL))
]

dist.all_gather(tensor_list=rank_tensor_list, tensor=rank_tensor)

for _rank, _rank_tensor in enumerate(rank_tensor_list):
modes_and_ranks = {mode: rank for mode, rank in zip(self._local_ranks.keys(), _rank_tensor.tolist())}
self._ranks_to_device[tuple(modes_and_ranks.items())] = _rank
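Both hunks in this file deal with the same constraint: the NCCL backend only operates on CUDA tensors, so the device is set only when the backend is `nccl`, and the buffers passed to `all_gather` are moved to the GPU when CUDA is available. A minimal sketch of that pattern, assuming the default process group is already initialized (the helper name `gather_local_ranks` is illustrative, not part of pipegoose):

```python
import torch
import torch.distributed as dist


def gather_local_ranks(local_ranks: list) -> list:
    """All-gather this process's local ranks from every rank.

    NCCL collectives require CUDA tensors, so the buffers are placed on the
    current device when CUDA is available; the gloo backend accepts CPU tensors.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rank_tensor = torch.tensor(local_ranks, dtype=torch.long, device=device)
    gathered = [torch.zeros_like(rank_tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, rank_tensor)
    return gathered
```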
5 changes: 4 additions & 1 deletion pipegoose/nn/tensor_parallel/_functional.py
@@ -25,7 +25,10 @@ def backward(ctx: Any, grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]:

all_reduce(grad, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR)

return (grad, None, None)
return (
grad,
None,
)


class _Gather(Function):
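For context on the return value being changed above: `torch.autograd.Function.backward` must return exactly one value per argument that was passed to `forward`, using `None` for arguments that need no gradient. A standalone sketch of that contract (the `_Scale` class is a toy example, not pipegoose code):

```python
import torch
from torch.autograd import Function


class _Scale(Function):
    """Toy function whose forward takes (input, factor, parallel_context)."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, factor: float, parallel_context=None) -> torch.Tensor:
        ctx.factor = factor
        return input * factor

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        # One entry per forward argument: a gradient for `input`,
        # None for the non-tensor arguments `factor` and `parallel_context`.
        return grad * ctx.factor, None, None
```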
9 changes: 4 additions & 5 deletions pipegoose/nn/tensor_parallel/linear.py
@@ -32,17 +32,16 @@ def __init__(

if bias is True:
self.bias = nn.Parameter(torch.randn(out_per_partition))

else:
self.bias = None

def _get_output_per_partition(self, out_features: int, parallel_context: ParallelContext) -> int:
local_world_size = parallel_context.get_world_size(ParallelMode.TENSOR)
return out_features // local_world_size

def forward(self, input: torch.Tensor) -> torch.Tensor:
input_parallel = broadcast_to_tensor_group(input, self.parallel_context)
outputs = F.linear(input_parallel, self.weight)

if self.bias is not None:
outputs = outputs + self.bias
outputs = F.linear(input_parallel, self.weight, self.bias)

if self.gather_output:
outputs = gather_to_tensor_group(outputs, dim=-1, parallel_context=self.parallel_context)
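The rewritten forward folds the optional bias into the `F.linear` call, which accepts `bias=None` and is equivalent to multiplying first and adding the bias afterwards. A standalone check of that equivalence (shapes are arbitrary, not taken from pipegoose):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)       # (batch, in_features)
weight = torch.randn(8, 16)  # (out_features, in_features)
bias = torch.randn(8)

fused = F.linear(x, weight, bias)        # x @ weight.T + bias in one call
two_step = F.linear(x, weight) + bias    # the previous two-step form
assert torch.allclose(fused, two_step)

# With bias=None, F.linear simply skips the addition.
assert torch.allclose(F.linear(x, weight, None), x @ weight.t())
```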
2 changes: 2 additions & 0 deletions pipegoose/nn/tensor_parallel/parallel_mapping.py
@@ -34,6 +34,8 @@ class ParallelMapping:
Row(("mlp.dense_4h_to_h", "self_attention.dense")),
LMHead(("lm_head",)),
],
"debug_single_mlp": [Column(("debug_single_mlp",))],

}

@staticmethod
11 changes: 7 additions & 4 deletions pipegoose/nn/tensor_parallel/tensor_parallel.py
@@ -6,18 +6,16 @@
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.parallel import Parallel
from pipegoose.nn.tensor_parallel.parallelizer import (
EmbeddingParallelizer,
LayerNormParallelizer,
LinearParallelizer,
LMHeadParallelizer,
ModuleParallelizer,
)


class TensorParallel(Parallel):
"""Turn a 🤗 transformers model into a tensor parallel model."""

PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer]
# PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer]
PARALLELIZERS = [LinearParallelizer]

def __init__(self, module: nn.Module, parallel_context: ParallelContext):
self.module = module
@@ -33,6 +31,11 @@ def parallelize(self) -> nn.Module:
# multiple times. so we filter out and retain the non-repetitive modules (leaf modules)
leaf_modules = self._get_leaf_modules(module)
for module_name, leaf_module in leaf_modules:
# NOTE: just skip parallelizing query_key_value in attention
# for debugging purposes
if "query_key_value" in module_name:
continue

parallelizer = self._find_parallelizer(module_name, leaf_module)
if parallelizer is not None:
parallelizer(module_name, leaf_module, module, self.parallel_context).parallelize()
154 changes: 154 additions & 0 deletions pipegoose/utils/logger.py
@@ -0,0 +1,154 @@
import datetime
import inspect
import sys
import os
import wandb
import glob
import re
import os

class Logger:
# https://github.com/Cadene/bootstrap.pytorch/blob/master/bootstrap/lib/logger.py
""" The Logger class is a singleton. It contains all the utilities
for logging variables in a key-value dictionary.
It can also be considered as a replacement for the print function.

.. code-block:: python

Logger(dir_logs='logs/mnist')
Logger().flush() # write the logs.json
Logger()("Launching training procedures") # written to logs.txt
> [I 2018-07-23 18:58:31] ...trap/engines/engine.py.80: Launching training procedures
"""

DEBUG = -1
INFO = 0
SUMMARY = 1
WARNING = 2
ERROR = 3
SYSTEM = 4
_instance = None
indicator = {DEBUG: 'D', INFO: 'I', SUMMARY: 'S', WARNING: 'W', ERROR: 'E', SYSTEM: 'S'}

class Colors:
END = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
GREY = 30
RED = 31
GREEN = 32
YELLOW = 33
BLUE = 34
PURPLE = 35
SKY = 36
WHITE = 37
BACKGROUND = 10
LIGHT = 60

@staticmethod
def code(value):
return '\033[{}m'.format(value)

colorcode = {
DEBUG: Colors.code(Colors.GREEN),
INFO: Colors.code(Colors.GREY + Colors.LIGHT),
SUMMARY: Colors.code(Colors.BLUE + Colors.LIGHT),
WARNING: Colors.code(Colors.YELLOW + Colors.LIGHT),
ERROR: Colors.code(Colors.RED + Colors.LIGHT),
SYSTEM: Colors.code(Colors.WHITE + Colors.LIGHT)
}

compactjson = True
log_level = None # log level
dir_logs = None
path_json = None
path_txt = None
file_txt = None
name = None
max_lineno_width = 3

def __new__(cls, dir_logs=None, name='logs'):
if Logger._instance is None:
Logger._instance = object.__new__(Logger)

if dir_logs:
Logger._instance.name = name
Logger._instance.dir_logs = dir_logs
Logger._instance.path_txt = os.path.join(dir_logs, '{}.txt'.format(name))
Logger._instance.file_txt = open(os.path.join(dir_logs, '{}.txt'.format(name)), 'a+')
# NOTE: Support json or CSV ?
# Logger._instance.path_json = os.path.join(dir_logs, '{}.json'.format(name))
# Logger._instance.reload_json()
else:
Logger._instance.log_message('No logs files will be created (dir_logs attribute is empty)',
log_level=Logger.WARNING)

return Logger._instance

def __call__(self, *args, **kwargs):
return self.log_message(*args, **kwargs, stack_displacement=2)

def log_message(self, *message, log_level=INFO, break_line=True, print_header=True, stack_displacement=1,
raise_error=True, adaptive_width=True):

if self.dir_logs and not self.file_txt:
raise Exception('Critical: Log file not defined. Do you have write permissions for {}?'.format(self.dir_logs))

caller_info = inspect.getframeinfo(inspect.stack()[stack_displacement][0])
message = ' '.join([str(m) for m in list(message)])

if print_header:
message_header = '[{} {:%Y-%m-%d %H:%M:%S}]'.format(self.indicator[log_level],
datetime.datetime.now())
filename = caller_info.filename
if adaptive_width:
# allows the lineno_width to grow when necessary
lineno_width = len(str(caller_info.lineno))
self.max_lineno_width = max(lineno_width, self.max_lineno_width)
else:
# manually fix it to 3 numbers
lineno_width = 3

if len(filename) > 28 - self.max_lineno_width:
filename = '...{}'.format(filename[-22 - (self.max_lineno_width - lineno_width):])

message_locate = '{}.{}:'.format(filename, caller_info.lineno)
message_logger = '{} {} {}'.format(message_header, message_locate, message)
message_screen = '{}{}{}{} {} {}'.format(self.Colors.BOLD,
self.colorcode[log_level],
message_header,
self.Colors.END,
message_locate,
message)
else:
message_logger = message
message_screen = message

if break_line:
print(message_screen)
if self.dir_logs:
self.file_txt.write('%s\n' % message_logger)
else:
print(message_screen, end='')
sys.stdout.flush()
if self.dir_logs:
self.file_txt.write(message_logger)

if self.dir_logs:
self.file_txt.flush()
if log_level == self.ERROR and raise_error:
raise Exception(message)

def update_log_file(self, path_src, path_dst):
"""
Append content of file at path_src to file at path_dst
"""

with open(path_src, 'r') as f:
lines_src = f.readlines()

with open(path_dst, 'r') as f:
lines_dst = f.readlines()

with open(path_dst, 'w') as f:
f.writelines(lines_src + ["\n"] + lines_dst)
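Going by the docstring at the top of the new file, a minimal usage sketch of the singleton (the `logs/debug` directory is an illustrative choice and must exist before logging):

```python
import os

from pipegoose.utils.logger import Logger

os.makedirs("logs/debug", exist_ok=True)

Logger(dir_logs="logs/debug")                        # first call configures the singleton
Logger()("Launching training procedures")            # appended to logs/debug/logs.txt
Logger()("loss is NaN", log_level=Logger.WARNING)    # levels map to colored console output
```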
Binary file added tests/convergence/debug_batch.pt
Binary file not shown.
Binary file added tests/convergence/debug_target.pt
Binary file not shown.
Binary file added tests/convergence/model.pt
Binary file not shown.