Save intermediate checkpoints when sampling without replacement (take 2) #535

Open · wants to merge 7 commits into base: main
6 changes: 5 additions & 1 deletion README.md
@@ -183,7 +183,11 @@ numerical results as the naïve method.

#### Epochs

For larger datasets (eg Laion2B), we recommend setting --train-num-samples to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with --dataset-resampled to do sampling with replacement. This allows having frequent checkpoints to evaluate more often.
For larger datasets (e.g. Laion2B), we recommend setting `--train-num-samples` to a lower value than the full epoch, for example `--train-num-samples 135646078`, which is 1/16 of an epoch, in conjunction with `--dataset-resampled` to do sampling with replacement. This allows saving checkpoints frequently so you can evaluate more often.

Alternatively, you can use `--num-subepochs-per-epoch` to save checkpoints more frequently without `--dataset-resampled`.
When `--num-subepochs-per-epoch` is used, checkpointing behaves as if there were that many times more epochs.
For example, if `args.num_subepochs_per_epoch` is set to 2, `args.epochs` is set to 1, and `args.save_frequency` is also set to 1, the code will save two checkpoints, `epoch_1.pt` and `epoch_2.pt`.

#### Patch Dropout

67 changes: 39 additions & 28 deletions src/training/data.py
@@ -233,42 +233,55 @@ def pytorch_worker_seed(increment=0):
return wds.utils.pytorch_worker_seed()


_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000


class detshuffle2(wds.PipelineStage):
def __init__(
self,
bufsize=1000,
initial=100,
seed=0,
epoch=-1,
):
self.bufsize = bufsize
self.initial = initial
class SimpleShardList2(IterableDataset):
"""An iterable dataset yielding a list of urls."""

def __init__(self, urls, epoch=-1, seed=0, num_sub_epochs=None):
"""Iterate through the list of shards."""
super().__init__()
urls, _ = expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.seed = seed
self.num_sub_epochs = num_sub_epochs
self.epoch = epoch

def run(self, src):
def __len__(self):
return len(self.urls)

def __iter__(self):
"""Return an iterator over the shards."""
urls = self.urls.copy()

# Set epoch
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
rng = random.Random()
if self.seed < 0:
# If seed is negative, we use the worker's seed, this will be different across all nodes/workers
seed = pytorch_worker_seed(epoch)
else:
# This seed to be deterministic AND the same across all nodes/workers in each epoch

# Shuffle with the same seed across all nodes/workers in each epoch or super epoch
if self.num_sub_epochs is None:
seed = self.seed + epoch
rng.seed(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
else:
# Keep shuffling consistent across the super epochs
seed = self.seed + (epoch // self.num_sub_epochs)
random.Random(seed).shuffle(urls)

# Restrict to shards in the sub epoch if needed
if self.num_sub_epochs is not None:
urls = urls[epoch % self.num_sub_epochs::self.num_sub_epochs]

# Yield shards
for url in urls:
yield dict(url=url)
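
The two steps above (shuffling with a seed derived from `epoch // num_sub_epochs`, then slicing `urls[epoch % num_sub_epochs::num_sub_epochs]`) give every sub-epoch of one full pass the same shard order, so the sub-epochs cover disjoint shard subsets that together exhaust the dataset. A minimal standalone sketch of that arithmetic (shard names are made up; this is not part of the diff):

import random

def subepoch_shards(urls, epoch, seed=0, num_sub_epochs=2):
    urls = urls.copy()
    # Same shuffle order for every sub-epoch within one full pass over the data
    random.Random(seed + epoch // num_sub_epochs).shuffle(urls)
    # Each sub-epoch takes every num_sub_epochs-th shard, starting at a different offset
    return urls[epoch % num_sub_epochs::num_sub_epochs]

shards = [f"shard_{i:03d}.tar" for i in range(6)]
first = subepoch_shards(shards, epoch=0)
second = subepoch_shards(shards, epoch=1)
assert not set(first) & set(second)              # disjoint
assert sorted(first + second) == sorted(shards)  # together they cover every shard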



class ResampledShards2(IterableDataset):
@@ -344,6 +357,10 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
# Eval will just exhaust the iterator if the size is not specified.
num_samples = args.val_num_samples or 0

# Adjust num_samples if saving multiple times per epoch when sampling without replacement
if not resampled and args.num_subepochs_per_epoch is not None:
num_samples = int(num_samples / args.num_subepochs_per_epoch)

shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc

if resampled:
@@ -356,18 +373,12 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
else:
assert args.train_data_upsampling_factors is None,\
"--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
pipeline = [wds.SimpleShardList(input_shards)]
pipeline = [SimpleShardList2(input_shards, epoch=shared_epoch, num_sub_epochs=args.num_subepochs_per_epoch)]

# at this point we have an iterator over all the shards
if is_train:
if not resampled:
pipeline.extend([
detshuffle2(
bufsize=_SHARD_SHUFFLE_SIZE,
initial=_SHARD_SHUFFLE_INITIAL,
seed=args.seed,
epoch=shared_epoch,
),
wds.split_by_node,
wds.split_by_worker,
])
9 changes: 7 additions & 2 deletions src/training/main.py
@@ -6,6 +6,7 @@
import sys
import random
from datetime import datetime
import math

import numpy as np
import torch
@@ -348,6 +349,10 @@ def main(args):
data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
assert len(data), 'At least one train or eval dataset must be specified.'

# when sampling without replacement and saving subepochs, we need to adjust args.epochs
if not args.dataset_resampled and args.num_subepochs_per_epoch is not None:
args.epochs *= args.num_subepochs_per_epoch
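
Together with the `num_samples` adjustment in `src/training/data.py`, this multiplication keeps the total amount of training work unchanged: each of the now more numerous "epochs" draws proportionally fewer samples. A rough sanity check of the bookkeeping, with illustrative numbers only (not part of the diff):

# Hypothetical settings: 1 real epoch over 1,000,000 samples, split into 4 sub-epochs
train_num_samples = 1_000_000
epochs = 1
num_subepochs_per_epoch = 4

samples_per_loop = int(train_num_samples / num_subepochs_per_epoch)  # mirrors data.py: 250,000
epochs *= num_subepochs_per_epoch                                    # mirrors the line above: 4

assert samples_per_loop * epochs == train_num_samples  # total samples seen is unchanged
# With --save-frequency 1, this yields checkpoints epoch_1.pt through epoch_4.pt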

# create scheduler if train
scheduler = None
if 'train' in data and optimizer is not None:
@@ -411,7 +416,7 @@ def main(args):

loss = create_loss(args)

for epoch in range(start_epoch, args.epochs):
for epoch in range(start_epoch, math.ceil(args.epochs)):
if is_master(args):
logging.info(f'Start epoch {epoch}')

@@ -432,7 +437,7 @@ def main(args):
if scaler is not None:
checkpoint_dict["scaler"] = scaler.state_dict()

if completed_epoch == args.epochs or (
if completed_epoch >= args.epochs or (
args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
):
torch.save(
12 changes: 11 additions & 1 deletion src/training/params.py
@@ -126,7 +126,7 @@ def parse_args(args):
"--batch-size", type=int, default=64, help="Batch size per GPU."
)
parser.add_argument(
"--epochs", type=int, default=32, help="Number of epochs to train for."
"--epochs", type=float, default=32, help="Number of epochs to train for."
)
parser.add_argument(
"--epochs-cooldown", type=int, default=None,
@@ -174,6 +174,16 @@ def parse_args(args):
default=False,
help="Always save the most recent model trained to epoch_latest.pt.",
)
parser.add_argument(
"--num-subepochs-per-epoch",
type=int,
default=None,
help=(
"Number of subepochs per epoch. This can be used to save checkpoints more frequently when --dataset-resampled is False. "
"When this flag is used, checkpointing will act as if there are `args.num_subepochs_per_epoch` times more epochs. "
"E.g. if `args.num_subepochs_per_epoch` is 2, `args.epochs` is 1 and `args.save_frequency` is 1, it'll save checkpoints epoch_1.pt and epoch_2.pt."
)
)
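
A quick sketch of how the new flag combines with existing options, mirroring the pattern used in `tests/test_wds.py` (the import path is assumed and all values are placeholders):

from training.params import parse_args

args = parse_args([
    "--train-data", "/data/shards/{000..015}.tar",  # placeholder path
    "--train-num-samples", "1000000",
    "--epochs", "1",
    "--save-frequency", "1",
    "--num-subepochs-per-epoch", "4",
])
assert args.num_subepochs_per_epoch == 4
assert args.epochs == 1.0  # parsed as a float, per the change above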
parser.add_argument(
"--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
)
17 changes: 12 additions & 5 deletions src/training/train.py
@@ -179,12 +179,14 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
batch_time_m.update(time.time() - end)
end = time.time()
batch_count = i_accum + 1
if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
batch_size = len(images)
num_samples = batch_count * batch_size * args.accum_freq * args.world_size
samples_per_epoch = dataloader.num_samples
percent_complete = 100.0 * batch_count / num_batches_per_epoch

batch_size = len(images)
num_samples = batch_count * batch_size * args.accum_freq * args.world_size
samples_per_epoch = dataloader.num_samples
percent_complete = 100.0 * batch_count / num_batches_per_epoch

# Log training progress
if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
# NOTE loss is coarsely sampled, just master node and per log update
for key, val in losses.items():
if key not in losses_m:
@@ -230,6 +232,11 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()

# Exit early if we've hit our fractional epoch limit (percent_complete is on a 0-100 scale)
if args.epochs % 1 > 0 and epoch + 1 == math.ceil(args.epochs) and percent_complete >= (args.epochs % 1) * 100:
return

# end for
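
For fractional `--epochs` values, the interaction between the `math.ceil` loop bound in `main.py` and the early exit above can be checked with a small piece of arithmetic. This is illustrative only (not part of the diff) and assumes `percent_complete` runs from 0 to 100, as computed earlier in this function:

import math

def batches_run(epochs_arg, num_batches_per_epoch):
    """Count batches executed for a possibly fractional --epochs value."""
    total = 0
    for epoch in range(math.ceil(epochs_arg)):
        for batch_count in range(1, num_batches_per_epoch + 1):
            total += 1
            percent_complete = 100.0 * batch_count / num_batches_per_epoch
            # Mirrors the early-exit condition in train_one_epoch
            if epochs_arg % 1 > 0 and epoch + 1 == math.ceil(epochs_arg) \
                    and percent_complete >= (epochs_arg % 1) * 100:
                break
    return total

assert batches_run(2.0, 100) == 200   # whole-number epochs are unaffected
assert batches_run(2.25, 100) == 225  # the final epoch stops after 25% of its batches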


36 changes: 33 additions & 3 deletions tests/test_wds.py
@@ -53,7 +53,7 @@ def save_tar(idx, num_samples):
return input_dir


def build_params(input_shards, seed=0):
def build_params(input_shards, seed=0, **kwargs):
args = parse_args([])
args.train_data = input_shards
args.train_num_samples = TRAIN_NUM_SAMPLES
@@ -67,16 +67,46 @@ def build_params(input_shards, seed=0):
preprocess_img = lambda x: x
tokenizer = lambda x: [x.strip()]

for key, value in kwargs.items():
setattr(args, key, value)

return args, preprocess_img, tokenizer


def get_dataloader(input_shards):
args, preprocess_img, tokenizer = build_params(input_shards)
def get_dataloader(input_shards, return_dataset=False, **kwargs):
args, preprocess_img, tokenizer = build_params(input_shards, **kwargs)
dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
if return_dataset:
return dataset
dataloader = dataset.dataloader
return dataloader


def test_sampling_without_replacement():
"""Test webdataset when sampling without replacement."""
input_dir = build_inputs('single_source')
input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar')
dataset = get_dataloader(input_shards, dataset_resampled=False, num_subepochs_per_epoch=2, return_dataset=True)

for epoch in [0, 1]:
dataset.set_epoch(epoch)

dataloader = dataset.dataloader

counts = collections.defaultdict(int)
for sample in dataloader:
txts = sample[1]
for txt in txts:
counts[txt] += 1

sample_key = list(counts.keys())[0]
prefix = sample_key.split('_')[0]
expected_count = TRAIN_NUM_SAMPLES / 20 if prefix == '000' else TRAIN_NUM_SAMPLES / 10
for key, count in counts.items():
assert key.startswith(prefix)
assert count == pytest.approx(expected_count, RTOL)


def test_single_source():
"""Test webdataset with a single tar file."""
input_dir = build_inputs('single_source')