120 line length, add pyproject.toml
achalddave committed Nov 8, 2023
1 parent 6893dc4 commit 482b91a
Showing 30 changed files with 145 additions and 448 deletions.
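Note: the pyproject.toml added by this commit is not included in the portion of the diff rendered below. Judging from the commit title and the rewrapped lines (which move from Black's default 88-character limit to 120 characters), it presumably configures the formatter's line length. The following standalone Python sketch shows the equivalent setting applied through Black's API; the sample snippet, and the assumption that Black is the formatter, are illustrative rather than taken from the repository.

import black

# Reformat a small sample at the 120-character limit this commit adopts.
# black.Mode(line_length=...) mirrors what a line-length entry under [tool.black]
# in pyproject.toml would configure.
sample = (
    "def consumer(\n"
    "    my_id, output_dir, threads, buffer, buffer_lock, num_consumers, upload_to_s3=False\n"
    "):\n"
    "    pass\n"
)
formatted = black.format_str(sample, mode=black.Mode(line_length=120))
print(formatted)  # the signature now fits on a single line, as in the hunks below
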
16 changes: 4 additions & 12 deletions datapreprocess/make_2048.py
@@ -124,14 +124,10 @@ def dump_queue_to_buffer():
buffer.append(chunk)


def consumer(
my_id, output_dir, threads, buffer, buffer_lock, num_consumers, upload_to_s3=False
):
def consumer(my_id, output_dir, threads, buffer, buffer_lock, num_consumers, upload_to_s3=False):
output_directory = f"{output_dir}/{CHUNK_SIZE - 1}-v1/{my_id}"
os.makedirs(output_directory, exist_ok=True)
shard_writer = ShardWriter(
os.path.join(output_directory, "shard-%07d.tar"), maxcount=SHARD_SIZE
)
shard_writer = ShardWriter(os.path.join(output_directory, "shard-%07d.tar"), maxcount=SHARD_SIZE)

chunks = []

@@ -155,14 +151,10 @@ def consumer(
chunks = []
time_for_shard = time.time() - start_time
print("shards / s", num_consumers / time_for_shard)
print(
"tokens / s", num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard
)
print("tokens / s", num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard)
print(
"hours req for 1.2T tokens",
1_200_000_000_000
/ (num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard)
/ 3600,
1_200_000_000_000 / (num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard) / 3600,
)

start_time = time.time()
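As a reference for the throughput prints in the hunk above, the arithmetic works out as follows. This is a standalone sketch: the values of SHARD_SIZE, CHUNK_SIZE, num_consumers, and time_for_shard are illustrative, not constants taken from make_2048.py.

# Illustrative values (assumed, not from the script).
SHARD_SIZE = 8192       # sequences per shard
CHUNK_SIZE = 2049       # tokens per sequence, assuming CHUNK_SIZE - 1 == 2048 as the script name suggests
num_consumers = 32
time_for_shard = 60.0   # seconds each consumer spent on its last shard

tokens_per_s = num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard   # ~8.95M tokens/s here
hours_for_1_2t = 1_200_000_000_000 / tokens_per_s / 3600                  # ~37 hours with these numbers

print("shards / s", num_consumers / time_for_shard)
print("tokens / s", tokens_per_s)
print("hours req for 1.2T tokens", hours_for_1_2t)
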
16 changes: 4 additions & 12 deletions datapreprocess/make_assistant_data.py
@@ -99,14 +99,10 @@ def dump_queue_to_buffer():
buffer.append(chunk)


def consumer(
my_id, output_dir, threads, buffer, buffer_lock, num_consumers, upload_to_s3=False
):
def consumer(my_id, output_dir, threads, buffer, buffer_lock, num_consumers, upload_to_s3=False):
output_directory = f"{output_dir}/{CHUNK_SIZE - 1}-v1/{my_id}"
os.makedirs(output_directory, exist_ok=True)
shard_writer = ShardWriter(
os.path.join(output_directory, "shard-%07d.tar"), maxcount=SHARD_SIZE
)
shard_writer = ShardWriter(os.path.join(output_directory, "shard-%07d.tar"), maxcount=SHARD_SIZE)

chunks = []

@@ -130,14 +126,10 @@ def consumer(
chunks = []
time_for_shard = time.time() - start_time
print("shards / s", num_consumers / time_for_shard)
print(
"tokens / s", num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard
)
print("tokens / s", num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard)
print(
"hours req for 1.2T tokens",
1_200_000_000_000
/ (num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard)
/ 3600,
1_200_000_000_000 / (num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard) / 3600,
)

start_time = time.time()
36 changes: 8 additions & 28 deletions datapreprocess/ray/tokenize_shuffle.py
@@ -131,19 +131,13 @@ def dist_tokenize(data, tokenizer, content_key):
def cut_to_context(jsonl_batch, seqlen=1024, pad_type=PadType.CIRCULAR):
tokens_list = jsonl_batch["tokens"]
flat_token_list = [item for sublist in tokens_list for item in sublist]
repartioned_lists = [
flat_token_list[i : i + seqlen] for i in range(0, len(flat_token_list), seqlen)
]
repartioned_lists = [flat_token_list[i : i + seqlen] for i in range(0, len(flat_token_list), seqlen)]
end_len = len(repartioned_lists[-1])
if len(repartioned_lists[-1]) < seqlen:
if pad_type == PadType.CIRCULAR:
repartioned_lists[-1] = (
repartioned_lists[-1] + repartioned_lists[0][: (seqlen - end_len)]
)
repartioned_lists[-1] = repartioned_lists[-1] + repartioned_lists[0][: (seqlen - end_len)]
else:
repartioned_lists[-1] = repartioned_lists[-1] + [
SpecialTokens.PAD.value
] * (seqlen - end_len)
repartioned_lists[-1] = repartioned_lists[-1] + [SpecialTokens.PAD.value] * (seqlen - end_len)
return {"tokens": repartioned_lists}


@@ -237,11 +231,7 @@ def glob_files(path, suffix=".jsonl"):
# List the objects in the bucket with the given prefix
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
all_files = [
f"s3://{bucket_name}/{obj['Key']}"
for objects in pages
for obj in objects.get("Contents", [])
]
all_files = [f"s3://{bucket_name}/{obj['Key']}" for objects in pages for obj in objects.get("Contents", [])]

# Filter out the files based on the suffix
matching_files = [f for f in all_files if f.endswith(suffix)]
@@ -264,9 +254,7 @@ def get_filesystem(environment):
# Extract the AWS credentials from the environment dictionary
access_key = environment.get("AWS_ACCESS_KEY_ID")
secret_key = environment.get("AWS_SECRET_ACCESS_KEY")
session_token = environment.get(
"AWS_SESSION_TOKEN", None
) # Session token might be optional
session_token = environment.get("AWS_SESSION_TOKEN", None) # Session token might be optional

# Create and return the S3FileSystem
return fs.S3FileSystem(
@@ -307,11 +295,7 @@ def human_to_bytes(s):
"""

symbols = ("B", "K", "M", "G", "T", "P")
letter = (
s[-2:].strip().upper()
if s[-2:].strip().upper()[:-1] in symbols
else s[-1:].upper()
)
letter = s[-2:].strip().upper() if s[-2:].strip().upper()[:-1] in symbols else s[-1:].upper()
number = float(s[: -len(letter)].strip())

if letter == "B":
@@ -341,9 +325,7 @@ def human_to_bytes(s):
# e.g s3://dcnlp-data/rpj_tokenized_upsampled_eleutherai_deduplicated/
)
parser.add_argument("--content_key", type=str, default="text")
parser.add_argument(
"--no_shuffle", help="do not dedup + random shuffle", action="store_true"
)
parser.add_argument("--no_shuffle", help="do not dedup + random shuffle", action="store_true")
parser.add_argument("--seqlen", type=int, default=2048)
parser.add_argument("--pad_type", type=str, default="circular")
parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
@@ -355,9 +337,7 @@ def human_to_bytes(s):
parser.add_argument("--materialize", action="store_true")
parser.add_argument("--ray_address", type=str, default=None)
parser.add_argument("--block_size", type=str, default="10MB")
parser.add_argument(
"--ray_spill_location", type=str, default="s3://dcnlp-hub/ray_spill"
)
parser.add_argument("--ray_spill_location", type=str, default="s3://dcnlp-hub/ray_spill")

args = parser.parse_args()
# configure remote spilling
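The single-line rewrite of the suffix parsing in human_to_bytes above packs a lot into one expression; the standalone sketch below reproduces just that expression with a couple of sample inputs (the sample strings are illustrative, and the rest of the unit conversion is omitted).

symbols = ("B", "K", "M", "G", "T", "P")

def parse_size(s):
    # Keep a two-character suffix such as "MB" when its first character is a known
    # symbol; otherwise fall back to the final character alone ("K", "G", ...).
    letter = s[-2:].strip().upper() if s[-2:].strip().upper()[:-1] in symbols else s[-1:].upper()
    number = float(s[: -len(letter)].strip())
    return number, letter

print(parse_size("10MB"))  # (10.0, 'MB') -- "M" is in symbols, so the two-character suffix is kept
print(parse_size("512K"))  # (512.0, 'K') -- "2" is not a symbol, so only the last character is used
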
11 changes: 3 additions & 8 deletions eval/eval_openlm_ckpt.py
@@ -45,15 +45,12 @@ def evaluate(model, tokenizer, cfg):

in_memory_logger = InMemoryLogger() # track metrics in the in_memory_logger
loggers: List[LoggerDestination] = [
build_logger(name, logger_cfg)
for name, logger_cfg in (cfg.get("loggers") or {}).items()
build_logger(name, logger_cfg) for name, logger_cfg in (cfg.get("loggers") or {}).items()
]
loggers.append(in_memory_logger)

fsdp_config = cfg.get("fsdp_config", None)
fsdp_config = (
om.to_container(fsdp_config, resolve=True) if fsdp_config is not None else None
)
fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config is not None else None

load_path = cfg.get("load_path", None)

@@ -98,9 +95,7 @@ def main():
"""
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint")
parser.add_argument(
"--model", type=str, default="m1b_neox", help="Name of the model to use."
)
parser.add_argument("--model", type=str, default="m1b_neox", help="Name of the model to use.")
parser.add_argument("--eval-yaml")
parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
add_model_args(parser)
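The fsdp_config handling above converts an OmegaConf node into plain Python containers before it reaches the trainer. A minimal standalone sketch of that pattern, assuming om is omegaconf.OmegaConf (the usual import alias) and using an illustrative config:

from omegaconf import OmegaConf as om

cfg = om.create({"fsdp_config": {"sharding_strategy": "FULL_SHARD", "mixed_precision": "PURE"}})

fsdp_config = cfg.get("fsdp_config", None)
# resolve=True expands any ${...} interpolations; the result is a plain dict (or None).
fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config is not None else None

print(type(fsdp_config), fsdp_config)  # <class 'dict'> {'sharding_strategy': 'FULL_SHARD', ...}
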
58 changes: 14 additions & 44 deletions open_lm/data.py
@@ -144,9 +144,7 @@ def log_and_continue(exn):
return True


def group_by_keys_nothrow(
data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
):
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
@@ -164,11 +162,7 @@ def group_by_keys_nothrow(
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset
if (
current_sample is None
or prefix != current_sample["__key__"]
or suffix in current_sample
):
if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
@@ -290,9 +284,7 @@ def __iter__(self):
if self.weights is None:
yield dict(url=self.rng.choice(self.urls))
else:
yield dict(
url=self.rng.choices(self.urls, weights=self.weights, k=1)[0]
)
yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])


def filter_lt_seqlen(seq_len, x):
@@ -330,9 +322,7 @@ def get_wds_dataset(
num_samples = force_num_samples[ii]
else:
if args.train_data_mix_weights is not None:
num_samples = int(
args.train_num_samples * args.train_data_mix_weights[ii]
)
num_samples = int(args.train_num_samples * args.train_data_mix_weights[ii])
else:
num_samples = args.train_num_samples // len(input_shards_)
else:
@@ -346,9 +336,7 @@ def get_wds_dataset(
# Eval will just exhaust the iterator if the size is not specified.
num_samples = args.val_num_samples or 0

shared_epoch = SharedEpoch(
epoch=epoch
) # create a shared epoch store to sync epoch to dataloader worker proc
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc

if resampled:
pipeline = [
@@ -372,9 +360,7 @@ def get_wds_dataset(
[
detshuffle2(
bufsize=0 if args.disable_buffer else _SHARD_SHUFFLE_SIZE,
initial=0
if args.disable_buffer
else _SHARD_SHUFFLE_INITIAL,
initial=0 if args.disable_buffer else _SHARD_SHUFFLE_INITIAL,
seed=args.seed,
epoch=shared_epoch,
),
@@ -389,9 +375,7 @@ def get_wds_dataset(
wds.shuffle(
bufsize=0 if args.disable_buffer else _SAMPLE_SHUFFLE_SIZE,
initial=0 if args.disable_buffer else _SHARD_SHUFFLE_INITIAL,
rng=random.Random(args.seed + shared_epoch.get_value())
if args.seed is not None
else None,
rng=random.Random(args.seed + shared_epoch.get_value()) if args.seed is not None else None,
),
]
)
@@ -407,9 +391,7 @@ def get_wds_dataset(
if data_key == "json":
pipeline.extend(
[
wds.map_dict(
json=partial(preprocess_json, vocab_size=args.vocab_size)
),
wds.map_dict(json=partial(preprocess_json, vocab_size=args.vocab_size)),
wds.to_tuple("json"),
wds.select(partial(filter_lt_seqlen, args.seq_len)),
wds.batched(args.batch_size, partial=not is_train),
@@ -418,9 +400,7 @@ def get_wds_dataset(
else:
pipeline.extend(
[
wds.map_dict(
txt=partial(preprocess_txt, vocab_size=args.vocab_size)
),
wds.map_dict(txt=partial(preprocess_txt, vocab_size=args.vocab_size)),
wds.to_tuple("txt"),
wds.select(partial(filter_lt_seqlen, args.seq_len)),
wds.batched(args.batch_size, partial=not is_train),
@@ -441,15 +421,9 @@ def get_wds_dataset(
if not resampled:
num_shards = num_shards or len(expand_urls(input_shards)[0])
if num_shards < args.workers * args.world_size:
print(
"Please increase --train-num-samples or decrease workers or world size"
)
print(
f"num_shards: {num_shards}, workers: {args.workers}, world_size: {args.world_size}"
)
assert (
num_shards >= args.workers * args.world_size
), "number of shards must be >= total workers"
print("Please increase --train-num-samples or decrease workers or world size")
print(f"num_shards: {num_shards}, workers: {args.workers}, world_size: {args.world_size}")
assert num_shards >= args.workers * args.world_size, "number of shards must be >= total workers"
# roll over and repeat a few samples to get same number of full batches on each node
round_fn = math.floor if floor else math.ceil
global_batch_size = args.batch_size * args.world_size
@@ -458,15 +432,11 @@ def get_wds_dataset(
for ii in range(len(datasets)):
num_batches = round_fn(all_num_samples[ii] / global_batch_size)
num_workers = max(1, args.workers)
num_worker_batches = round_fn(
num_batches / num_workers
) # per dataloader worker
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
num_batches = num_worker_batches * num_workers
num_samples = num_batches * global_batch_size
# TODO: what is the effect of setting this?
datasets[ii] = datasets[ii].with_epoch(
num_worker_batches
) # each worker is iterating over this
datasets[ii] = datasets[ii].with_epoch(num_worker_batches) # each worker is iterating over this

total_num_batches += num_batches
total_num_samples += num_samples
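The per-dataset accounting reformatted above rounds sample counts so that every dataloader worker on every rank sees the same number of full batches. A standalone sketch of that arithmetic with made-up numbers:

import math

# Illustrative values, not taken from any open_lm configuration.
batch_size, world_size, workers = 8, 4, 6
num_samples = 1_000_000
floor = False
round_fn = math.floor if floor else math.ceil

global_batch_size = batch_size * world_size              # 32 samples per global step
num_batches = round_fn(num_samples / global_batch_size)  # 31250
num_workers = max(1, workers)
num_worker_batches = round_fn(num_batches / num_workers) # 5209 per dataloader worker
num_batches = num_worker_batches * num_workers           # 31254, rolled up to a multiple of the worker count
num_samples = num_batches * global_batch_size            # 1000128 samples actually drawn
print(num_worker_batches, num_batches, num_samples)
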
4 changes: 1 addition & 3 deletions open_lm/distributed.py
@@ -74,9 +74,7 @@ def init_distributed_device(args):
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url
)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
27 changes: 8 additions & 19 deletions open_lm/file_utils.py
@@ -21,9 +21,7 @@ def remote_sync_s3(local_dir, remote_dir):
stderr=subprocess.PIPE,
)
if result.returncode != 0:
logging.error(
f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}"
)
logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
return False

logging.info(f"Successfully synced with S3 bucket")
@@ -89,9 +87,7 @@ def pt_save(pt_obj, file_path):

def _pt_load_s3_cp(file_path, map_location=None):
cmd = f"aws s3 cp {file_path} -"
proc = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()
if proc.returncode != 0:
raise Exception(f"Failed to fetch model from s3. stderr: {stderr.decode()}")
@@ -179,22 +175,17 @@ def source_exhausted(paths, shard_list_per_source):
return False


def get_string_for_epoch(
num_samples, starting_chunk, paths, weights, min_shards_needed
):
def get_string_for_epoch(num_samples, starting_chunk, paths, weights, min_shards_needed):
if weights is None:
weights = [1.0 / len(paths) for _ in range(len(paths))]
needed_samples_per_source = [
int(np.ceil(weights[i] * num_samples / sum(weights)))
for i in range(len(weights))
]
needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(len(weights))]
shard_strings_per_source = []
next_chunk = starting_chunk
shard_list_per_source = [[] for _ in range(len(paths))]
num_samples_per_source = [0 for _ in range(len(paths))]
while not enough_shards(
shard_list_per_source, min_shards_needed
) or not enough_samples(num_samples_per_source, needed_samples_per_source):
while not enough_shards(shard_list_per_source, min_shards_needed) or not enough_samples(
num_samples_per_source, needed_samples_per_source
):
for i, source_path in enumerate(paths):
shard_list_source, num_samples_source = get_shards_for_chunk(
needed_samples_per_source[i], next_chunk, source_path
@@ -213,9 +204,7 @@ def get_string_for_epoch(
shard_list_source = shard_list_per_source[i]
num_samples_source = num_samples_per_source[i]
shard_root_source = "/".join(source_path.split("/")[:-1]) + "/"
shard_string_source = (
shard_root_source + "{" + ",".join(shard_list_source) + "}.tar"
)
shard_string_source = shard_root_source + "{" + ",".join(shard_list_source) + "}.tar"
if source_path.startswith("s3"):
shard_string_source = f"pipe:aws s3 cp {shard_string_source} -"
shard_strings_per_source.append(shard_string_source)
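The shard string built above is a brace pattern that webdataset expands into individual tar shards, optionally wrapped in a pipe: command so S3 objects are streamed through aws s3 cp. A small standalone sketch with made-up bucket and shard names:

# Made-up inputs for illustration only.
source_path = "s3://my-bucket/rpj_tokenized/manifest.jsonl"
shard_list_source = ["shard-0000000", "shard-0000001", "shard-0000002"]

shard_root_source = "/".join(source_path.split("/")[:-1]) + "/"
shard_string_source = shard_root_source + "{" + ",".join(shard_list_source) + "}.tar"
# -> s3://my-bucket/rpj_tokenized/{shard-0000000,shard-0000001,shard-0000002}.tar

if source_path.startswith("s3"):
    # webdataset treats "pipe:" URLs as shell commands whose stdout is read as the tar stream.
    shard_string_source = f"pipe:aws s3 cp {shard_string_source} -"

print(shard_string_source)
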
4 changes: 1 addition & 3 deletions open_lm/logger.py
@@ -11,9 +11,7 @@ def setup_logging(log_file, level, include_host=False):
datefmt="%Y-%m-%d,%H:%M:%S",
)
else:
formatter = logging.Formatter(
"%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S"
)
formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S")

logging.root.setLevel(level)
loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
(diffs for the remaining changed files are not shown)
