Commit 5e44edf (1 parent: 4a4f61a)
Showing 13 changed files with 4,593 additions and 1 deletion.
@@ -1 +1 @@
-# WActiGrad
+# This is the implementation of our paper Weight Activation and Gradient (WActiGrad)
@@ -0,0 +1,107 @@
import torch
import random
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaTokenizer


def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)


def get_tokenizer(model):
    if "llama" in model.lower():
        tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False)
        # Repair special-token ids that are wrong in some LLaMA tokenizer configs.
        if tokenizer.bos_token_id != 1 or tokenizer.eos_token_id != 2:
            try:
                tokenizer.bos_token_id = 1
                tokenizer.eos_token_id = 2
            except AttributeError:
                pass
    else:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    return tokenizer


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    # Draw nsamples random windows of length seqlen as calibration samples.
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100  # mask all but the last position in the loss targets
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_ptb(nsamples, seed, seqlen, tokenizer):
    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')

    trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
    testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_c4(nsamples, seed, seqlen, tokenizer):
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Resample documents until one longer than seqlen tokens is found.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]

    # Wrap the raw tensor so callers can access .input_ids like a tokenizer output.
    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids
    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc


def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if 'ptb' in name:
        return get_ptb(nsamples, seed, seqlen, tokenizer)
    if 'c4' in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
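For orientation, here is a minimal usage sketch of these loaders (not part of the commit). The module name `data` matches the `from data import get_loaders` import in the evaluation file below; the checkpoint name is a placeholder assumption.

# Hedged usage sketch; the checkpoint name is an illustrative placeholder.
from data import get_tokenizer, get_loaders

tokenizer = get_tokenizer("meta-llama/Llama-2-7b-hf")  # hypothetical LLaMA checkpoint
trainloader, testenc = get_loaders("wikitext2", nsamples=128, seed=0,
                                   seqlen=2048, tokenizer=tokenizer)

inp, tar = trainloader[0]
print(inp.shape)                 # torch.Size([1, 2048]): one calibration window
print((tar != -100).sum())       # tensor(1): only the last position keeps its label
print(testenc.input_ids.shape)   # full tokenized test split used for perplexity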
@@ -0,0 +1,81 @@
import time
import torch
import torch.nn as nn

from data import get_loaders


def eval_ppl(model, tokenizer, dataset, device=torch.device("cuda:0")):
    # Compute train and test perplexity of the model on the given dataset.
    ppl_dict = {}
    trainloader, testloader = get_loaders(dataset, seed=0, nsamples=128, seqlen=model.seqlen, tokenizer=tokenizer)
    with torch.no_grad():
        print("Evaluating on", dataset)
        ppl_dict[dataset + "_ppl_test"] = eval_ppl_test(model, testloader, 1, device)
        ppl_dict[dataset + "_ppl_train"] = eval_ppl_train(model, trainloader, 1, device)
    return ppl_dict


def eval_ppl_train(model, trainloader, bs=1, device=None):
    nsamples = len(trainloader)

    nlls = []
    print(f"nsamples {nsamples}")

    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        j = min(i + bs, nsamples)

        inputs = trainloader[i][0].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        lm_logits = model(inputs).logits

        # Shift so that tokens < n predict token n.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        # Each batch contributes mean loss * number of tokens to the total NLL.
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)

        nlls.append(neg_log_likelihood)

    # Perplexity is the exponential of the mean per-token negative log-likelihood.
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

    torch.cuda.empty_cache()

    return ppl.item()


def eval_ppl_test(model, testenc, bs=1, device=None):
    testenc = testenc.input_ids

    # Number of non-overlapping segments of length seqlen in the test set.
    nsamples = testenc.numel() // model.seqlen

    nlls = []
    print(f"nsamples {nsamples}")

    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        j = min(i + bs, nsamples)

        inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        lm_logits = model(inputs).logits

        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        neg_log_likelihood = loss.float() * model.seqlen * (j - i)

        nlls.append(neg_log_likelihood)

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    torch.cuda.empty_cache()
    return ppl.item()
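To make the aggregation above concrete, a small numeric sketch with made-up losses: each batch adds `loss * seqlen * batch_size` to the running negative log-likelihood, and the final perplexity is the exponential of the mean per-token NLL.

# Toy illustration only: two batches of size 1 with mean cross-entropy 2.0 and 2.5,
# and a toy seqlen of 4 (the real code uses model.seqlen, e.g. 2048).
import torch

seqlen, bs = 4, 1
losses = [2.0, 2.5]
nlls = [torch.tensor(loss) * seqlen * bs for loss in losses]
ppl = torch.exp(torch.stack(nlls).sum() / (len(losses) * seqlen))
print(ppl)  # tensor(9.4877), i.e. exp(2.25), the mean per-token NLL exponentiated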
@@ -0,0 +1,35 @@
import sys, os
import json
import torch
import numpy as np


def print_requires_grad_params(model):
    # List the parameters that are trainable (requires_grad=True).
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)


def get_LLM(model_name):
    # Load the causal LM via the sparse modeling classes, set up its tokenizer,
    # and attach the sequence length used for calibration and evaluation.
    if "llama" in model_name:
        from modeling_sparse_llama import LlamaForCausalLM
        from transformers import LlamaTokenizer

        model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

    elif "Mistral" in model_name:
        from modeling_sparse_mistral import MistralForCausalLM
        from transformers import AutoTokenizer

        model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'left'

    model.seqlen = 2048
    return model, tokenizer
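Finally, a hedged end-to-end sketch of how this loader could be combined with the perplexity evaluation above. The module names (`utils` for this file, `eval` for the evaluation file) and the checkpoint path are assumptions for illustration, and `modeling_sparse_llama` from the repository must be importable.

# Hedged usage sketch, not part of the commit; module names and checkpoint are assumed.
from utils import get_LLM   # assumed module name for the helpers above
from eval import eval_ppl   # assumed module name for the evaluation functions above

model, tokenizer = get_LLM("meta-llama/Llama-2-7b-hf")  # hypothetical checkpoint
model.eval()
results = eval_ppl(model, tokenizer, "wikitext2")
print(results)  # e.g. {'wikitext2_ppl_test': ..., 'wikitext2_ppl_train': ...}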