
[diff] Llm custom merged #6


Draft: wants to merge 14 commits into main
51 changes: 29 additions & 22 deletions megatron/__init__.py
@@ -12,35 +12,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import torch
import os

logger = logging.getLogger(__name__)

from .package_info import (
__description__,
__contact_names__,
__url__,
__download_url__,
__keywords__,
__license__,
__package_name__,
__version__,
)

from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron
if "MEGATRON_SETUP" not in os.environ:
from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron

def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
logger.info(str(message))

def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)

def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
logger.info(str(message))
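
With this change, print_rank_0 and print_rank_last keep the rank-guarded print to stdout but also hand the message to the module logger, and several rank-0 print calls in arguments.py below become plain logger.info / logger.warning calls. Where those records end up is left to whatever logging configuration the training script installs; the diff itself does not add one. A minimal sketch of one possible setup, assuming a RANK environment variable and a local log directory (the helper name, file layout, and format string are illustrative, not part of this PR):

import logging
import os


def configure_megatron_logging(log_dir="logs"):
    """Illustrative only: attach a per-rank file handler so the logger.info()
    and logger.warning() calls introduced in this diff are captured, while
    stdout keeps the rank-0-only print() behaviour."""
    rank = int(os.environ.get("RANK", "0"))          # assumed launcher env var
    os.makedirs(log_dir, exist_ok=True)

    handler = logging.FileHandler(os.path.join(log_dir, f"rank_{rank}.log"))
    handler.setFormatter(logging.Formatter(
        f"%(asctime)s rank={rank} %(name)s %(levelname)s %(message)s"))

    megatron_logger = logging.getLogger("megatron")  # parent of megatron.arguments, ...
    megatron_logger.setLevel(logging.INFO)
    megatron_logger.addHandler(handler)

Attaching the handler to the parent "megatron" logger also captures child loggers such as megatron.arguments, so the argument dump from _print_args lands in each rank's log file as well.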
63 changes: 31 additions & 32 deletions megatron/arguments.py
@@ -16,10 +16,13 @@
"""Megatron arguments."""

import argparse
import logging
import os

import torch

logger = logging.getLogger(__name__)

def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -74,13 +77,12 @@ def parse_args(extra_args_provider=None, defaults={},
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // model_parallel_size
if args.rank == 0:
print('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
logger.info('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size))
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
@@ -112,11 +114,9 @@ def parse_args(extra_args_provider=None, defaults={},
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
logger.warning('Overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)))
else:
setattr(args, key, defaults[key])

@@ -125,9 +125,8 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
logger.info('setting global batch size to {}'.format(
args.global_batch_size))
assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, \
@@ -154,13 +153,10 @@ def parse_args(extra_args_provider=None, defaults={},
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.', flush=True)
logger.info('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.')

if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
logger.info('using {} for parameters ...'.format(args.params_dtype))

# If we do accumulation and all-reduces in fp32, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is not off.
@@ -275,17 +271,14 @@ def parse_args(extra_args_provider=None, defaults={},

def _print_args(args):
"""Print arguments."""
if args.rank == 0:
print('------------------------ arguments ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print('-------------------- end of arguments ---------------------',
flush=True)
logger.info('------------------------ arguments ------------------------')
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
logger.info(arg)
logger.info('-------------------- end of arguments ---------------------')


def _check_arg_is_not_none(args, arg):
@@ -350,8 +343,12 @@ def _add_network_size_args(parser):
def _add_logging_args(parser):
group = parser.add_argument_group(title='logging')

group.add_argument('--name', type=str, default=None,
help='A name for the experiment.')
group.add_argument('--log-params-norm', action='store_true',
help='If set, calculate and log parameters norm.')
group.add_argument('--log-scales', action='store_true',
help='Log the scales of parameters, gradients and activations.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
@@ -708,6 +705,8 @@ def _add_data_args(parser):
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--indexmap-path', type=str, default=None,
help='Path for intermediate data files')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
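
The new --indexmap-path option added here is consumed by the dataset builders further down, where the same redirection block is repeated in biencoder_dataset_utils.py, dataset_utils.py, and gpt_dataset.py: resolve the directory, create it if needed, and keep only the original file name. A minimal sketch of that pattern as a standalone helper (the function name, signature, and example paths are illustrative, not code from this PR):

from pathlib import Path
from typing import Optional


def redirect_indexmap(indexmap_filename, indexmap_path: Optional[str]) -> Path:
    """Illustrative mirror of the block added in the dataset builders: when
    --indexmap-path is given, create that directory and keep only the original
    file name; otherwise leave the path untouched."""
    if indexmap_path is None:
        return Path(indexmap_filename)
    target_dir = Path(indexmap_path).resolve()
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / Path(indexmap_filename).name


# Hypothetical example: the mapping built next to the data prefix is moved
# into a dedicated directory, keeping its original file name.
print(redirect_indexmap("/data/gpt2_text_document_train_indexmap.npy",
                        "/scratch/indexmaps"))

Pointing --indexmap-path at a shared, writable directory keeps the generated *_indexmap.npy and *_idx.npy files together instead of placing them alongside the data prefix.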
2 changes: 1 addition & 1 deletion megatron/data/bert_dataset.py
@@ -118,7 +118,7 @@ def build_training_sample(sample,
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
assert target_seq_length <= max_seq_length-2

# Divide sample into two segments (A and B).
if binary_head:
24 changes: 17 additions & 7 deletions megatron/data/biencoder_dataset_utils.py
@@ -1,8 +1,10 @@
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.distributed

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
@@ -146,6 +148,12 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'

args=get_args()
if args.indexmap_path is not None:
indexmap_path=Path(get_args().indexmap_path).resolve()
indexmap_path.mkdir(parents=True, exist_ok=True)
indexmap_filename = indexmap_path/Path(indexmap_filename).name

# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
@@ -184,13 +192,15 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
'(seconds): {:4f}'.format(
time.time() - start_time))

# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Wait until rank 0 has generated the index file.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])], group=mpu.get_data_parallel_group())
# It can take some time for the file to be visible on other nodes.
for i in range(120):
if indexmap_filename.is_file():
break
if i%10==0:
print_rank_0(" Waiting for index file...")
time.sleep(1.0)

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
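
The previous code only approximated a barrier with an all_reduce because, as the removed comment notes, the NCCL barrier assumes device_index == rank, which does not hold under model parallelism; the replacement passes device_ids=[LOCAL_RANK] explicitly and then polls for the file, since a shared filesystem can take some time to make rank 0's freshly written index map visible on other nodes. The same barrier-plus-poll pattern reappears in dataset_utils.py and gpt_dataset.py below. A minimal sketch of the polling step as a reusable helper (the helper name, parameters, and return value are illustrative, not part of the PR):

import time
from pathlib import Path
from typing import Iterable


def wait_for_index_files(paths: Iterable, timeout_s: int = 120) -> bool:
    """Poll the filesystem until every path is visible or the timeout expires;
    mirrors the loop added in the diff, which waits up to ~120 seconds after
    the barrier and reports progress every ten seconds."""
    paths = [Path(p) for p in paths]
    for i in range(timeout_s):
        if all(p.is_file() for p in paths):
            return True
        if i % 10 == 0:
            print(" Waiting for index file(s)...", flush=True)
        time.sleep(1.0)
    return all(p.is_file() for p in paths)

In the diff, the equivalent loop runs only after torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])], group=mpu.get_data_parallel_group()) returns, so non-zero ranks start polling once rank 0 has finished writing; the loop exits silently after roughly two minutes, so a still-missing file surfaces as an error when the mapping is loaded.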
30 changes: 20 additions & 10 deletions megatron/data/dataset_utils.py
@@ -22,9 +22,11 @@
import os
import time
import collections
from pathlib import Path

import numpy as np
import torch
import torch.distributed

from megatron import (
get_args,
@@ -446,7 +448,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
seed, skip_warmup, binary_head, max_seq_length_dec, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
@@ -661,6 +663,12 @@ def get_samples_mapping(indexed_dataset,
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

args=get_args()
if args.indexmap_path is not None:
indexmap_path=Path(args.indexmap_path).resolve()
indexmap_path.mkdir(parents=True, exist_ok=True)
indexmap_filename = indexmap_path/Path(indexmap_filename).name

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
@@ -696,15 +704,17 @@ def get_samples_mapping(indexed_dataset,
print_rank_0(' > elasped time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

# Wait until rank 0 has generated the index file.
print_rank_0(f"Barrier device {int(os.environ['LOCAL_RANK'])}")
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])], group=mpu.get_data_parallel_group())
# It can take some time for the file to be visible on other nodes.
for i in range(120):
if indexmap_filename.is_file():
break
if i%10==0:
print_rank_0(" Waiting for index file...")
time.sleep(1.0)

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
31 changes: 21 additions & 10 deletions megatron/data/gpt_dataset.py
@@ -17,11 +17,13 @@

import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.distributed

from megatron import mpu, print_rank_0
from megatron import mpu, print_rank_0, get_args
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
@@ -211,6 +213,14 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'

args=get_args()
if args.indexmap_path is not None:
indexmap_path=Path(args.indexmap_path).resolve()
indexmap_path.mkdir(parents=True, exist_ok=True)
doc_idx_filename = indexmap_path/Path(doc_idx_filename).name
sample_idx_filename = indexmap_path/Path(sample_idx_filename).name
shuffle_idx_filename = indexmap_path/Path(shuffle_idx_filename).name

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if (not os.path.isfile(doc_idx_filename)) or \
@@ -293,15 +303,16 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
print_rank_0(' > elasped time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))

# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Wait until rank 0 has generated the index files.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])], group=mpu.get_data_parallel_group())

# It can take some time for the file to be visible on other nodes.
for i in range(120):
if doc_idx_filename.is_file() and sample_idx_filename.is_file() and shuffle_idx_filename.is_file():
break
if i%10==0:
print_rank_0(" Waiting for index files...")
time.sleep(1.0)

# Load mappings.
start_time = time.time()