Add preemption
artemisp committed Mar 24, 2024
1 parent 3bc98a7 commit 1eb8553
Showing 6 changed files with 22 additions and 24 deletions.
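
This commit threads sample indices through the training path so a preempted run can be resumed without replaying data it has already seen: the train split returns `(index, sample)` pairs, the data module tracks `current_epoch` and `global_step`, and the train sampler is re-seeded each epoch. A minimal, standalone sketch of that pattern (class and variable names here are illustrative, not the repository's API):

```
# Sketch only: a dataset that hands back the index with each sample, so the
# training loop can record which examples the interrupted epoch consumed.
import torch
from torch.utils.data import DataLoader, Dataset

class IndexedDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Mirrors the train-split behaviour added in pl_dataloaders.py below.
        return index, torch.tensor(self.samples[index])

if __name__ == "__main__":
    loader = DataLoader(IndexedDataset(list(range(100))), batch_size=8, shuffle=True)
    seen = set()  # would be checkpointed next to the model state
    for indices, batch in loader:
        seen.update(indices.tolist())  # progress marker for a resumed run
```
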
4 changes: 0 additions & 4 deletions .gitignore
@@ -8,10 +8,6 @@
# Ignore .pyc files
*.pyc

# Ignore apex
**apex**


# Ignore cache
**.cache**

5 changes: 3 additions & 2 deletions README.md
@@ -48,7 +48,7 @@ In installing `PyTorch` we assume `CUDA` version ~~12.0~~ 12.1 are compatible wi


```
>> conda create -n test_me python=3.10
>> conda create -n test_me python=3.8
>> conda activate test_me
>> conda activate test_me
>> conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
@@ -58,9 +58,10 @@ In installing `PyTorch` we assume `CUDA` version ~~12.0~~ 12.1 are compatible wi
```
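
A quick sanity check for the environment above (a generic snippet, not part of the repository): it prints the installed PyTorch build, the CUDA toolkit it was compiled against, and whether a GPU is visible on the node. The same check applies to the deprecated environment below.

```
# Confirm the PyTorch/CUDA pairing from the README works on this node.
import torch

print(torch.__version__)          # installed PyTorch build
print(torch.version.cuda)         # CUDA toolkit the wheel was built with, e.g. 12.1 or 11.8
print(torch.cuda.is_available())  # True if the driver on the node can serve it
```
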



📛 `CUDA` version 12.0: deprecated as of March 14, 2024, since the NVIDIA drivers on nlpgpu got an update.
```
>> conda create -n test_me python=3.10
>> conda create -n test_me python=3.8
>> conda activate test_me
>> conda activate test_me
>> conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
26 changes: 16 additions & 10 deletions parallelm/data/pl_dataloaders.py
@@ -17,7 +17,6 @@
# Get the data root dir
cache_dir = os.getenv('CACHE_DIR', "./.cache")

sys.path.append(os.getcwd())
from parallelm.data.preprocessing import get_inputs_and_targets, tokenize_inputs_and_targets, batch_tokenize_inputs_and_targets


@@ -44,8 +43,7 @@ def __init__(self, dataset: datasets.Dataset, split, tokenizer = None, **kwargs)
self.kwargs = kwargs
self.tokenizer = tokenizer
self.split = split



def __len__(self):
"""
Returns the length of the dataset.
@@ -75,7 +73,10 @@ def __getitem__(self, index):
# encoder=self.encoder if self.kwargs.get('preprocessing_kwargs', {}).get("context_aware_prefix", False) else None,
**self.kwargs.get('tokenization_kwargs', {})
)
return sample
if self.split != 'train':
return sample
# return indices to track in training
return index, sample

class CustomDataModule(pl.LightningDataModule):
def __init__(self, tokenizer=None, batch_size=None, **kwargs):
@@ -130,6 +131,9 @@ def __init__(self, tokenizer=None, batch_size=None, **kwargs):
self.sampler_state = None

self.kwargs = kwargs
self.current_epoch = 0
self.global_step = 0
self.effective_batch_size = self.batch_size * max(1, torch.cuda.device_count()) if 'dp' in self.strategy else self.batch_size


def preprocess_datasets(self):
@@ -196,16 +200,16 @@ def configure_data_splits(self, **kwargs):
return

## dev data
if isinstance(kwargs.get('dev_size', .1), float) and kwargs.get('dev_size', .1) <= 1.:
dev_size = int(len(self.dataset['train'])*kwargs.get('dev_size', .1))
if isinstance(kwargs.get('dev_size', -1), float) and kwargs.get('dev_size', .1) <= 1.:
dev_size = int(len(self.dataset['train'])*kwargs.get('dev_size'))
else:
dev_size = kwargs.get('dev_size', .1)
dev_size = kwargs.get('dev_size', -1)
if "dev" in self.splits:
self.dataset['dev'] = self.dataset['dev'].select(range(dev_size))
if kwargs.get('dev_from_train', -1) > 0:
self.dataset = self.dataset['train'].train_test_split(test_size=dev_size,shuffle=False)
self.dataset['dev'] = self.dataset['test']
else:
if "dev" in self.splits:
self.dataset['dev'].shuffle().select(range(dev_size))


## training size
if kwargs.get('tiny', False):
@@ -307,6 +311,8 @@ def setup(self, stage=None):
print(f"Loaded {len(getattr(self, f'{split}_dataset'))} datasaet samples for {split} split")

def train_dataloader(self):
if 'dp' in self.strategy and torch.cuda.device_count() > 1:
self.train_sampler.set_epoch(self.current_epoch)
return DataLoader(self.train_dataset, sampler=self.train_sampler, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=False)

def val_dataloader(self):
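
The new `train_dataloader` hook re-seeds the sampler with the current epoch before building the loader. Assuming the sampler is a `torch.utils.data.DistributedSampler` (its construction is not shown in this hunk), the shuffle is a deterministic function of seed and epoch, which is what lets a restarted job replay the ordering of the epoch it was preempted in. A standalone illustration of that property (plain PyTorch, toy data, no process group required):

```
# Sketch only: DistributedSampler permutes indices deterministically from
# (seed, epoch), so set_epoch(n) after a restart reproduces epoch n's order.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(32))
# num_replicas/rank passed explicitly so this runs without init_process_group
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

def epoch_order(epoch):
    sampler.set_epoch(epoch)  # re-seed the shuffle for this epoch
    return [int(x) for (x,) in DataLoader(dataset, sampler=sampler, batch_size=1)]

print(epoch_order(3) == epoch_order(3))  # True: same epoch, same permutation
print(epoch_order(3) == epoch_order(4))  # False: a different epoch reshuffles
```
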
2 changes: 1 addition & 1 deletion parallelm/models/pl_modules.py
@@ -125,7 +125,6 @@ def __init__(self, learning_rate=None, tokenizer=None, predict=False, **kwargs):
self.lora = self.kwargs.get('lora', False)
self.predict = predict
self.soft_prompt = self.kwargs.get('prefix_tuning', False)


def setup(self, stage):
super().setup(stage)
@@ -226,6 +225,7 @@ def forward(self, input_ids, attention_mask, labels=None, output_hidden_states=F
return output

def training_step(self, batch, batch_idx):
indices, batch = batch
if type(batch["input_ids"]) == list:
batch["input_ids"] = torch.stack(batch["input_ids"]).t()
batch["attention_mask"] = torch.stack(batch["attention_mask"]).t()
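
`training_step` now unpacks `(indices, batch)`, relying on the default collate function to stack the per-sample indices into a tensor next to the usual feature dict. A small generic check of that collation behaviour (toy dataset, not the module's code):

```
# Sketch only: default collation turns a batch of (index, dict) items into
# (LongTensor of indices, dict of stacked tensors).
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, index):
        return index, {"input_ids": torch.tensor([index, index + 1])}

indices, batch = next(iter(DataLoader(ToyDataset(), batch_size=4)))
print(indices)                   # tensor([0, 1, 2, 3])
print(batch["input_ids"].shape)  # torch.Size([4, 2])
```
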
7 changes: 1 addition & 6 deletions pl_predict.py
@@ -1,11 +1,6 @@
if __name__ == "__main__":

import sys

import os
sys.path.append(os.getcwd())

import datasets
datasets.disable_caching()

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,4 +12,4 @@ peft
accelerate
bitsandbytes
wandb
vissl
numpy
