Changes from all commits
134 commits
3bee7af
Fix C4 dataset loading by removing specific data file references
amitsrivastava78 Jul 4, 2025
a13c6c3
Update CUDA extension for latest PyTorch compatibility
amitsrivastava78 Jul 4, 2025
e12ab1b
Removed the c4 dataset
amitsrivastava78 Jul 4, 2025
c0a5f76
Removed the exit from load cmd
amitsrivastava78 Jul 5, 2025
8980e47
added back exit
amitsrivastava78 Jul 5, 2025
a873340
Added simple Quant and GPTQ separately via cmdline
amitsrivastava78 Jul 7, 2025
376c402
Added simple Quant and GPTQ separately via cmdline part 1
amitsrivastava78 Jul 7, 2025
0826e61
Added simple Quant and GPTQ separately via cmdline part 2
amitsrivastava78 Jul 7, 2025
138a8c7
Ported OPT quantization and evaluation to tf-keras, added calibration…
amitsrivastava78 Jul 8, 2025
df79797
added debug prints
amitsrivastava78 Jul 9, 2025
1b71dee
Ported files to tf directly
amitsrivastava78 Jul 9, 2025
7ad31a9
Fixed TF issue
amitsrivastava78 Jul 9, 2025
74f7c6d
Debug statements for perplexity score of 65
amitsrivastava78 Jul 9, 2025
5d796f7
Fixed perplexity issue
amitsrivastava78 Jul 9, 2025
5bda33d
Fixed perplexity issue part 1
amitsrivastava78 Jul 9, 2025
8e25846
Added original TF model opt
amitsrivastava78 Jul 9, 2025
e99eb52
Added same log in pytorch and tf impl
amitsrivastava78 Jul 9, 2025
273bca3
Fix bug while logging in opt.py
amitsrivastava78 Jul 9, 2025
d5cb7a5
Fix bug while logging in opt.py part 2
amitsrivastava78 Jul 9, 2025
fa969da
Fix error in TF as model is different
amitsrivastava78 Jul 9, 2025
b02323c
Fix error in TF as model is different Part 2
amitsrivastava78 Jul 9, 2025
47b4f01
Fix error in identifying the Dense Layer
amitsrivastava78 Jul 9, 2025
5ef6e61
Fix error in identifying the Dense Layer Part 2
amitsrivastava78 Jul 9, 2025
f6068e2
Fix error in identifying the Dense Layer Part 3
amitsrivastava78 Jul 9, 2025
0b141ae
Fix error in identifying the Dense Layer Part 4
amitsrivastava78 Jul 9, 2025
25df8c6
Fix error in identifying the Dense Layer Part 5
amitsrivastava78 Jul 9, 2025
06cf6e4
Fix error in identifying the Dense Layer Part 6
amitsrivastava78 Jul 9, 2025
ceea748
Fix error in identifying the Dense Layer Part 7
amitsrivastava78 Jul 9, 2025
958c04d
Fix error in identifying the Dense Layer Part 8
amitsrivastava78 Jul 9, 2025
4553d33
Fix input collection
amitsrivastava78 Jul 9, 2025
06dfb72
Fix input collection part 1
amitsrivastava78 Jul 9, 2025
634658a
Fix input collection part 2
amitsrivastava78 Jul 9, 2025
1a48dc9
Fix input collection part 3
amitsrivastava78 Jul 9, 2025
61e63ea
Fix no quantization weights
amitsrivastava78 Jul 9, 2025
e85ba1f
Fix no quantization weights Part 2
amitsrivastava78 Jul 9, 2025
ac9c36d
Fix no quantization weights Part 3
amitsrivastava78 Jul 9, 2025
1da9127
Fix no quantization weights Part 4
amitsrivastava78 Jul 9, 2025
cfce07c
Fix no quantization weights Part 5
amitsrivastava78 Jul 9, 2025
3f6d0b8
Fix only fc1 and fc2
amitsrivastava78 Jul 9, 2025
f7c0293
Fix only fc1 and fc2 Part 2
amitsrivastava78 Jul 9, 2025
c5379fe
Fix only fc1 and fc2 Part 3
amitsrivastava78 Jul 9, 2025
6ef8646
Fix only fc1 and fc2 Part 4
amitsrivastava78 Jul 9, 2025
436da1b
Fix only fc1 and fc2 Part 5
amitsrivastava78 Jul 9, 2025
6b70cde
Fix gptqkeras logic
amitsrivastava78 Jul 9, 2025
dfb314a
Fix gptqkeras logic Part 2
amitsrivastava78 Jul 9, 2025
922b22a
Fix gptqkeras logic Part 3
amitsrivastava78 Jul 9, 2025
6cdb8b1
Fix gptqkeras logic Part 4
amitsrivastava78 Jul 9, 2025
332d068
Fix gptqkeras logic Part 5
amitsrivastava78 Jul 9, 2025
3751dcb
Fix gptqkeras logic Part 6
amitsrivastava78 Jul 9, 2025
a3ee146
Fix gptqkeras logic Part 7
amitsrivastava78 Jul 9, 2025
dcdf904
Fix gptqkeras logic Part 8
amitsrivastava78 Jul 9, 2025
ceaff41
Fix gptqkeras logic Part 9
amitsrivastava78 Jul 9, 2025
53bbe2a
Fix gptqkeras logic Part 10
amitsrivastava78 Jul 9, 2025
93e35bb
Fix gptqkeras logic Part 11
amitsrivastava78 Jul 9, 2025
2168a18
Fix Quantization update error
amitsrivastava78 Jul 9, 2025
3b5d557
Fix Quantization update error Part 2
amitsrivastava78 Jul 9, 2025
ba9408c
Fix Quantization update error Part 3
amitsrivastava78 Jul 9, 2025
f3ebbb5
Fix Quantization update error Part 4
amitsrivastava78 Jul 9, 2025
8d6704b
Fix Quantization update error Part 5
amitsrivastava78 Jul 9, 2025
f409520
Fix Quantization update error Part 5
amitsrivastava78 Jul 9, 2025
2094b44
Fix Quantization update error Part 6
amitsrivastava78 Jul 9, 2025
5b458e5
Fix Quantization update error Part 7
amitsrivastava78 Jul 9, 2025
06c594f
Fix Quantization update error Part 8
amitsrivastava78 Jul 9, 2025
4d77675
Fix Quantization update error Part 9
amitsrivastava78 Jul 9, 2025
56c8e80
Fix matrix shape warning
amitsrivastava78 Jul 9, 2025
569382e
Fix matrix shape warning Part 1
amitsrivastava78 Jul 9, 2025
9e79dd3
Fix matrix shape warning Part 2
amitsrivastava78 Jul 9, 2025
ace94b6
Fix matrix shape warning Part 3
amitsrivastava78 Jul 9, 2025
6e3d160
Fix matrix shape warning Part 4
amitsrivastava78 Jul 9, 2025
7453efb
Fix matrix shape warning Part 5
amitsrivastava78 Jul 9, 2025
054bf91
Quantize all Dense layers
amitsrivastava78 Jul 9, 2025
bd3364b
Quantize all Dense layers Part 1
amitsrivastava78 Jul 9, 2025
b1c7023
Quantize all Dense layers Part 2
amitsrivastava78 Jul 9, 2025
077db0b
Quantize all Dense layers Part 3
amitsrivastava78 Jul 9, 2025
9e278ca
Trying to fix the shape issue
amitsrivastava78 Jul 9, 2025
a93a074
Trying to fix the shape issue Part 1
amitsrivastava78 Jul 9, 2025
c83597e
Trying to fix the shape issue Part 2
amitsrivastava78 Jul 9, 2025
c3691c4
Trying to fix the shape issue Part 3
amitsrivastava78 Jul 9, 2025
d9ba0f6
Trying to fix the shape issue Part 4
amitsrivastava78 Jul 9, 2025
0b76f46
Trying to fix the shape issue Part 5
amitsrivastava78 Jul 9, 2025
5006e3b
Trying to fix the shape issue Part 6
amitsrivastava78 Jul 9, 2025
f97f6fc
Trying to fix the shape issue Part 7
amitsrivastava78 Jul 9, 2025
1ccf972
Trying to fix the shape issue Part 8
amitsrivastava78 Jul 9, 2025
8f94e9b
Trying to fix the shape issue Part 9
amitsrivastava78 Jul 9, 2025
a376ad0
Trying to fix the shape issue Part 10
amitsrivastava78 Jul 9, 2025
5041aca
Trying to fix the shape issue Part 11
amitsrivastava78 Jul 9, 2025
d009335
Fix No calibration data issue
amitsrivastava78 Jul 9, 2025
405b2ef
Fix No calibration data issue Part 1
amitsrivastava78 Jul 9, 2025
3cfc69b
Fix No calibration data issue Part 2
amitsrivastava78 Jul 9, 2025
4f01d92
Added new impl for TF model load and dataloader
amitsrivastava78 Jul 10, 2025
4c0bfed
Trying to fix tf add_batch in gptqkeras.py
amitsrivastava78 Jul 10, 2025
2665637
Trying to fix tf add_batch in gptqkeras.py part 2
amitsrivastava78 Jul 10, 2025
004e9e1
Trying to fix tf add_batch in gptqkeras.py part 3
amitsrivastava78 Jul 10, 2025
7bc5fcf
Trying to fix tf add_batch in gptqkeras.py part 4
amitsrivastava78 Jul 10, 2025
1fb2325
Trying to fix tf add_batch in gptqkeras.py part 5
amitsrivastava78 Jul 10, 2025
6a23d50
Trying to fix tf add_batch in gptqkeras.py part 6
amitsrivastava78 Jul 10, 2025
59c540f
Trying to fix tf add_batch in gptqkeras.py part 7
amitsrivastava78 Jul 10, 2025
3757377
Trying to fix tf add_batch in gptqkeras.py part 8
amitsrivastava78 Jul 10, 2025
e85620d
Trying to fix tf add_batch in gptqkeras.py part 9
amitsrivastava78 Jul 10, 2025
bee8baa
Fixing no sample issue
amitsrivastava78 Jul 10, 2025
c2bf289
Hessian matrix shape print
amitsrivastava78 Jul 10, 2025
01f99e1
Hessian matrix shape print part 1
amitsrivastava78 Jul 10, 2025
04b1e68
Hessian matrix shape print part 2
amitsrivastava78 Jul 10, 2025
79bebdc
Fixed Hessian matrix
amitsrivastava78 Jul 10, 2025
3a8bc61
Fixed Hessian matrix
amitsrivastava78 Jul 10, 2025
d73c7ca
No Quant error
amitsrivastava78 Jul 10, 2025
4b27722
No Quant error part 1
amitsrivastava78 Jul 10, 2025
11a9171
Refactor the code
amitsrivastava78 Jul 10, 2025
63723dd
Added Entry and Exit prints
amitsrivastava78 Jul 11, 2025
4f6f36e
Fix No Quant weights found issue
amitsrivastava78 Jul 11, 2025
b9c1522
Fix No Quant weights found issue Part 1
amitsrivastava78 Jul 11, 2025
27778cb
Fix No Quant weights found issue Part 2
amitsrivastava78 Jul 11, 2025
b4c8ce5
Added exit after All Quant
amitsrivastava78 Jul 11, 2025
6f5a6bc
Added exit after All Quant Part 1
amitsrivastava78 Jul 11, 2025
4debaa0
Added exit after All Quant Part 2
amitsrivastava78 Jul 11, 2025
699ca69
Align with pytorch prints
amitsrivastava78 Jul 11, 2025
d8bcdc6
Align with pytorch prints Part 1
amitsrivastava78 Jul 11, 2025
a7cb137
Align with pytorch prints Part 2
amitsrivastava78 Jul 11, 2025
0dd2f90
Align with pytorch prints Part 3
amitsrivastava78 Jul 11, 2025
712c9cd
Align Quantizer count
amitsrivastava78 Jul 11, 2025
ef5c9de
Continue flow to final model perplexity score
amitsrivastava78 Jul 11, 2025
dda4e58
Fix last tester code
amitsrivastava78 Jul 11, 2025
bfd5656
Fix last tester code Part 1
amitsrivastava78 Jul 11, 2025
fdd618c
Fix last tester code Part 2
amitsrivastava78 Jul 11, 2025
85a96e0
Fix last tester code Part 3
amitsrivastava78 Jul 11, 2025
be5de00
Fix last tester code Part 4
amitsrivastava78 Jul 11, 2025
a18e609
Fix last tester code Part 5
amitsrivastava78 Jul 11, 2025
97a9496
Fix last tester code Part 6
amitsrivastava78 Jul 11, 2025
6aec53f
Fix last tester code Part 7
amitsrivastava78 Jul 11, 2025
2f6557b
Fix last tester code Part 8
amitsrivastava78 Jul 11, 2025
ab5464c
All working ppl score high
amitsrivastava78 Jul 11, 2025
5085a94
Fix error issue
amitsrivastava78 Jul 11, 2025
1a7c0d2
reverting gptq fix done by mistake
amitsrivastava78 Jul 11, 2025
003f746
fix datautils.py
amitsrivastava78 Jul 11, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ opt175b
 *.txt
 *.pt
 *egg-info*
+.DS_Store
126 changes: 126 additions & 0 deletions Amitopt.py
@@ -0,0 +1,126 @@
# main.py
import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer, TFOPTForCausalLM

def get_wikitext2(tokenizer, sequence_length=128, batch_size=8):
    """
    Loads and processes the wikitext-2-raw-v1 dataset.

    Args:
        tokenizer: The tokenizer to use for encoding the text.
        sequence_length (int): The fixed length of sequences.
        batch_size (int): The batch size for the DataLoader.

    Returns:
        A tf.data.Dataset object ready for training.
    """
    print("Loading wikitext-2 dataset...")
    # Load the training split
    train_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

    # Filter out empty lines
    train_dataset = train_dataset.filter(lambda example: example['text'] != '')
    print(f"Number of examples after filtering: {len(train_dataset)}")

    # Tokenize the dataset
    def tokenize_function(examples):
        return tokenizer(examples["text"], return_tensors="tf", padding='max_length', truncation=True, max_length=sequence_length)

    tokenized_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Convert to a TensorFlow DataLoader (tf.data.Dataset)
    # For language modeling, the input_ids are used as both input and label.
    tf_dataset = tokenized_dataset.to_tf_dataset(
        columns=['input_ids', 'attention_mask'],
        label_cols=['input_ids'],  # Use input_ids as the label
        shuffle=True,
        batch_size=batch_size,
        collate_fn=None  # Use default collation
    )

    print("Wikitext-2 dataset converted to TensorFlow DataLoader.")
    return tf_dataset

def get_ptb(tokenizer, sequence_length=128, batch_size=8):
    """
    Loads and processes the Penn Treebank (PTB) dataset directly from its source URL.

    Args:
        tokenizer: The tokenizer to use for encoding the text.
        sequence_length (int): The fixed length of sequences.
        batch_size (int): The batch size for the DataLoader.

    Returns:
        A tf.data.Dataset object ready for training.
    """
    print("\nLoading PTB dataset...")
    # We load the data directly from its source URL using the generic 'text' loader.
    data_files = {"train": "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt"}
    train_dataset = load_dataset("text", data_files=data_files, split="train")

    # Filter out empty lines (the 'text' loader creates a 'text' column)
    train_dataset = train_dataset.filter(lambda example: example['text'] != '')
    print(f"Number of examples after filtering: {len(train_dataset)}")

    # Tokenize the dataset
    def tokenize_function(examples):
        return tokenizer(examples["text"], return_tensors="tf", padding='max_length', truncation=True, max_length=sequence_length)

    tokenized_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Convert to a TensorFlow DataLoader (tf.data.Dataset)
    tf_dataset = tokenized_dataset.to_tf_dataset(
        columns=['input_ids', 'attention_mask'],
        label_cols=['input_ids'],  # Use input_ids as the label
        shuffle=True,
        batch_size=batch_size,
        collate_fn=None  # Use default collation
    )

    print("PTB dataset converted to TensorFlow DataLoader.")
    return tf_dataset

def get_opt_125m_tf():
    """
    Loads the facebook/opt-125m model and tokenizer for TensorFlow.

    Returns:
        A tuple containing the loaded model and tokenizer.
    """
    print("\nLoading facebook/opt-125m for TensorFlow...")
    model_name = "facebook/opt-125m"
    # Note the use of TFOPTForCausalLM for TensorFlow
    model = TFOPTForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Model and tokenizer loaded.")
    return model, tokenizer

if __name__ == "__main__":
    # Define a batch size
    BATCH_SIZE = 4

    # 1. Load the TensorFlow model and tokenizer
    opt_model, opt_tokenizer = get_opt_125m_tf()

    # 2. Load and process the datasets into TensorFlow DataLoaders
    wikitext_dataloader = get_wikitext2(opt_tokenizer, batch_size=BATCH_SIZE)
    ptb_dataloader = get_ptb(opt_tokenizer, batch_size=BATCH_SIZE)

    # 3. Print some information to verify
    print("\n--- Verification ---")
    print(f"Model Class: {opt_model.__class__.__name__}")
    print(f"Tokenizer Class: {opt_tokenizer.__class__.__name__}")

    # Take one batch from each dataloader to show the structure
    print("\nSample batch from Wikitext-2 DataLoader:")
    for inputs, labels in wikitext_dataloader.take(1):
        print("Inputs (input_ids) shape:", inputs['input_ids'].shape)
        print("Inputs (attention_mask) shape:", inputs['attention_mask'].shape)
        print("Labels shape:", labels.shape)

    print("\nSample batch from PTB DataLoader:")
    for inputs, labels in ptb_dataloader.take(1):
        print("Inputs (input_ids) shape:", inputs['input_ids'].shape)
        print("Inputs (attention_mask) shape:", inputs['attention_mask'].shape)
        print("Labels shape:", labels.shape)
64 changes: 47 additions & 17 deletions datautils.py
@@ -31,13 +31,30 @@ def get_wikitext2(nsamples, seed, seqlen, model):

 def get_ptb(nsamples, seed, seqlen, model):
     from datasets import load_dataset
-    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
-    valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
-
     from transformers import AutoTokenizer

+    try:
+        # Try the new way first
+        traindata = load_dataset('ptb-text-only/ptb_text_only', split='train')
+        valdata = load_dataset('ptb-text-only/ptb_text_only', split='validation')
+        text_field = 'sentence'
+    except Exception as e1:
+        try:
+            # Try alternative dataset
+            traindata = load_dataset('ptb_text_only', split='train')
+            valdata = load_dataset('ptb_text_only', split='validation')
+            text_field = 'sentence'
+        except Exception as e2:
+            print(f"PTB dataset not available. Using WikiText-2 as fallback.")
+            print(f"Original errors: {e1}, {e2}")
+            # Fallback to WikiText-2
+            traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
+            valdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+            text_field = 'text'
+
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
-    trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
-    testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')
+    trainenc = tokenizer("\n\n".join(traindata[text_field]), return_tensors='pt')
+    testenc = tokenizer("\n\n".join(valdata[text_field]), return_tensors='pt')

     import random
     random.seed(seed)
@@ -53,12 +70,8 @@ def get_ptb(nsamples, seed, seqlen, model):

 def get_c4(nsamples, seed, seqlen, model):
     from datasets import load_dataset
-    traindata = load_dataset(
-        'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
-    )
-    valdata = load_dataset(
-        'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
-    )
+    traindata = load_dataset('allenai/c4', 'en', split='train')
+    valdata = load_dataset('allenai/c4', 'en', split='validation')

     from transformers import AutoTokenizer
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
@@ -97,17 +110,34 @@ def __init__(self, input_ids):
             self.input_ids = input_ids
     valenc = TokenizerWrapper(valenc)

-    return trainloader, valenc
+    return trainloader, valenc

 def get_ptb_new(nsamples, seed, seqlen, model):
     from datasets import load_dataset
-    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
-    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
-
     from transformers import AutoTokenizer

+    try:
+        # Try the new way first
+        traindata = load_dataset('ptb-text-only/ptb_text_only', split='train')
+        testdata = load_dataset('ptb-text-only/ptb_text_only', split='test')
+        text_field = 'sentence'
+    except Exception as e1:
+        try:
+            # Try alternative dataset
+            traindata = load_dataset('ptb_text_only', split='train')
+            testdata = load_dataset('ptb_text_only', split='test')
+            text_field = 'sentence'
+        except Exception as e2:
+            print(f"PTB dataset not available. Using WikiText-2 as fallback.")
+            print(f"Original errors: {e1}, {e2}")
+            # Fallback to WikiText-2
+            traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
+            testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+            text_field = 'text'
+
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
-    trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
-    testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
+    trainenc = tokenizer(" ".join(traindata[text_field]), return_tensors='pt')
+    testenc = tokenizer(" ".join(testdata[text_field]), return_tensors='pt')

     import random
     random.seed(seed)
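For context on what these loaders feed: GPTQ-style calibration draws nsamples random seqlen-token windows from one long tokenized stream. The collapsed portion of this diff is not shown, so the following is only a sketch of the common pattern (the helper name sample_calibration is ours):

import random
import torch

def sample_calibration(trainenc, nsamples, seqlen, seed):
    # Draw random fixed-length windows from the tokenized training text.
    random.seed(seed)
    loader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        inp = trainenc.input_ids[:, i:i + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100  # mask all but the last position in the target
        loader.append((inp, tar))
    return loader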
8 changes: 6 additions & 2 deletions gptq.py
@@ -148,7 +148,9 @@ def fasterquant(
                 print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                 print(torch.sum(Losses))

-        torch.cuda.synchronize()
+        # Synchronize only if CUDA is available
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         print('time %.2f' % (time.time() - tick))
         print('error', torch.sum(Losses).item())

@@ -168,4 +170,6 @@ def free(self):
         self.H = None
         self.Losses = None
         self.Trace = None
-        torch.cuda.empty_cache()
+        # Clear cache only if CUDA is available
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
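
The same is-CUDA-available guard now appears in both hunks; a small helper would keep it in one place. A sketch under our own naming, not part of this PR:

import torch

def cuda_cleanup(synchronize=True, empty_cache=False):
    # Both calls are skipped on CPU-only hosts, mirroring the guards above.
    if not torch.cuda.is_available():
        return
    if synchronize:
        torch.cuda.synchronize()
    if empty_cache:
        torch.cuda.empty_cache()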