Tinygrad Quantization Support [WIP] #630

Draft · wants to merge 4 commits into main
17 changes: 15 additions & 2 deletions exo/inference/tinygrad/inference.py
@@ -1,7 +1,7 @@
from pathlib import Path
import json
import os
from exo.inference.tinygrad.models.llama import Transformer, TransformerShard, convert_from_huggingface, fix_bf16, sample_logits
from exo.inference.tinygrad.models.llama import Transformer, TransformerShard, convert_from_huggingface, fix_bf16, sample_logits, unpack_quantized, AffineQuantizedLinear, AffineQuantizedEmbedding
from exo.inference.shard import Shard
from exo.inference.tokenizers import resolve_tokenizer
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
@@ -41,8 +41,19 @@

def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
# build model
try:
with open(model_path/"config.json", "r") as f:
config = json.load(f)

except FileNotFoundError:
raise Exception(f"Config file not found in {model_path}")

linear = nn.Linear
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
embedding = nn.Embedding
if (quantization := config.get("quantization", None)) is not None:
linear, embedding = AffineQuantizedLinear, AffineQuantizedEmbedding

model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, embedding=embedding, max_context=8192, jit=True, shard=shard)

# load weights
if model_path.is_dir():
@@ -54,6 +65,8 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
weights = fix_bf16(weights)

if (quantization is not None): weights = unpack_quantized(weights, quantization["bits"])

with Context(BEAM=0):
# replace weights in model
load_state_dict(model, weights, strict=False, consume=False) # consume=True
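
Note: the new branch above keys quantization off a "quantization" entry in the checkpoint's config.json. A minimal standalone sketch of that detection step follows; the {"group_size": 64, "bits": 4} layout is an assumption based on mlx-community exports, not something this diff confirms.

import json
from pathlib import Path

def read_quantization(model_path: Path):
  # Returns e.g. {"group_size": 64, "bits": 4} for a 4-bit export, or None
  # for an unquantized checkpoint (assumed config schema).
  with open(model_path/"config.json", "r") as f:
    config = json.load(f)
  return config.get("quantization", None)
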
105 changes: 91 additions & 14 deletions exo/inference/tinygrad/models/llama.py
@@ -48,6 +48,56 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
# NOTE: this is different from x.repeat((1, 1, n_rep, 1))
return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)


def dequantize_asymmetric(weight: Tensor, scales: Tensor, biases: Tensor, group_size: int) -> Tensor:
  weight = weight.cast(dtype=dtypes.half).reshape(*scales.shape, group_size)  # (out_features, in_features // group_size, group_size)
  scales = scales.reshape(*scales.shape, 1).expand(*scales.shape, group_size)
  biases = biases.reshape(*biases.shape, 1).expand(*biases.shape, group_size)
  return (weight*scales + biases).reshape(weight.shape[0], -1)

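As a cross-check on dequantize_asymmetric (not part of the diff): each run of group_size quantized values shares one half-precision (scale, bias) pair, so the weight is recovered as q*scale + bias per group. The same math in NumPy:

import numpy as np

def dequantize_asymmetric_np(q, scales, biases, group_size):
  # q: (out_features, in_features) uint8; scales/biases: (out_features, in_features // group_size)
  out_features, in_features = q.shape
  q = q.astype(np.float16).reshape(out_features, in_features // group_size, group_size)
  return (q*scales[..., None] + biases[..., None]).reshape(out_features, in_features)
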

class AffineQuantizedLinear:
  def __init__(self, in_features: int, out_features: int, group_size: int = 64, bias: bool = False):
    self.group_size = group_size
    self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.uint8)
    self.scales = Tensor.ones(out_features, in_features // group_size, dtype=dtypes.half)
    self.biases = Tensor.ones(out_features, in_features // group_size, dtype=dtypes.half)

  def __call__(self, x):
    orig_shape = x.shape
    x_flat = x.reshape(-1, x.shape[-1])  # [batch*seq_len, in_features]
    total_batch = x_flat.shape[0]

    # Grouping parameters
    in_features = x_flat.shape[-1]
    group_size = self.group_size
    n_groups = in_features // group_size
    out_features = self.weight.shape[0]

    # Reshape weights and input for batched operations
    w_group = self.weight.reshape(*self.scales.shape, group_size).permute(1, 0, 2).cast(dtype=dtypes.float)
    x_group = x_flat.reshape(total_batch, n_groups, group_size).permute(1, 0, 2)  # [n_groups, total_batch, group_size]

    term1 = (x_group.dot(w_group.permute(0, 2, 1)))*self.scales.T.reshape(n_groups, 1, out_features)  # [n_groups, total_batch, out_features]
    term2 = x_group.sum(axis=-1, keepdim=True)*self.biases.T.reshape(n_groups, 1, out_features)  # [n_groups, total_batch, out_features]
    return (term1 + term2).sum(axis=0).reshape(*orig_shape[:-1], out_features).cast(dtype=dtypes.half)

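Why the two-term form in __call__ works (illustration, not the diff): within a group, W = q*scale + bias, so x @ W.T splits into (x_g @ q_g.T)*scale_g plus sum(x_g)*bias_g, summed over groups, and the full fp16 weight never has to be materialized. A small NumPy check of that identity with random data:

import numpy as np

rng = np.random.default_rng(0)
out_f, in_f, gs, batch = 8, 16, 4, 3
q = rng.integers(0, 16, size=(out_f, in_f)).astype(np.float32)  # stand-in for the uint8 weight
scales = rng.standard_normal((out_f, in_f // gs)).astype(np.float32)
biases = rng.standard_normal((out_f, in_f // gs)).astype(np.float32)
x = rng.standard_normal((batch, in_f)).astype(np.float32)

# reference path: dequantize the whole weight, then matmul
w = (q.reshape(out_f, in_f // gs, gs)*scales[..., None] + biases[..., None]).reshape(out_f, in_f)
ref = x @ w.T

# grouped path: never builds the full-precision weight
xg = x.reshape(batch, in_f // gs, gs).transpose(1, 0, 2)  # (n_groups, batch, gs)
qg = q.reshape(out_f, in_f // gs, gs).transpose(1, 0, 2)  # (n_groups, out_f, gs)
term1 = np.einsum('gbk,gok->gbo', xg, qg)*scales.T[:, None, :]
term2 = xg.sum(-1, keepdims=True)*biases.T[:, None, :]
assert np.allclose((term1 + term2).sum(0), ref, atol=1e-4)

The tinygrad version computes exactly these two terms as term1 and term2 before summing over the group axis.
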

class AffineQuantizedEmbedding:
  def __init__(self, vocab_size: int, embed_size: int, group_size: int = 64):
    self.vocab_sz, self.embed_sz, self.group_size = vocab_size, embed_size, group_size
    self.weight = Tensor.ones(vocab_size, embed_size, dtype=dtypes.uint8)
    self.scales = Tensor.ones(vocab_size, embed_size // group_size, dtype=dtypes.half)
    self.biases = Tensor.ones(vocab_size, embed_size // group_size, dtype=dtypes.half)

  def __call__(self, idx: Tensor) -> Tensor:
    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
    full_weight = dequantize_asymmetric(self.weight, self.scales, self.biases, self.group_size)
    big_shp = idx.shape + (self.vocab_sz, self.embed_sz)
    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape + (1, 1)).expand(big_shp), full_weight.expand(big_shp)
    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)

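AffineQuantizedEmbedding dequantizes the whole table and then reuses the one-hot gather from tinygrad's nn.Embedding: (arange == idx) builds a one-hot matrix whose product with the table is a row lookup. A NumPy sketch of that equivalence (toy sizes, not the diff):

import numpy as np

vocab, embed = 16, 8
full_weight = np.arange(vocab*embed, dtype=np.float32).reshape(vocab, embed)  # stands in for the dequantized table
idx = np.array([[1, 5, 2]])  # (batch, seq_len) token ids

one_hot = (np.arange(vocab) == idx[..., None])  # (batch, seq_len, vocab)
assert np.array_equal(one_hot.astype(np.float32) @ full_weight, full_weight[idx])
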

class Attention:
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
self.n_heads = n_heads
@@ -61,7 +111,7 @@ def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)

def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor] = None) -> Tensor:
if getenv("WQKV"):
if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
xqkv = x @ self.wqkv.T
@@ -180,6 +230,7 @@ def __init__(
vocab_size,
shard: Shard = None,
linear=nn.Linear,
embedding=nn.Embedding,
n_kv_heads=None,
rope_theta=10000,
max_context=1024,
@@ -190,10 +241,13 @@
):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
self.norm = nn.RMSNorm(dim, norm_eps)
self.tok_embeddings = nn.Embedding(vocab_size, dim)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.tok_embeddings = embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
if tie_word_embeddings:
self.output.weight = self.tok_embeddings.weight
if (getattr(self.output, "scales", None) is not None):
self.output.scales = self.tok_embeddings.scales
self.output.biases = self.tok_embeddings.biases
self.max_context = max_context
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
self.forward_jit = TinyJit(self.forward_base) if jit else None
@@ -283,19 +337,30 @@ def permute(v: Tensor, n_heads: int):
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

keymap = {
"model.embed_tokens.weight": "tok_embeddings.weight",
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
for l in range(len(model.layers))},
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
for x in ["q", "k", "v", "o"]
for l in range(len(model.layers))},
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
**{f"model.embed_tokens.{_type}": f"tok_embeddings.{_type}"
for _type in ["weight", "scales", "biases"]},
**{f"model.layers.{l}.input_layernorm.{_type}": f"layers.{l}.attention_norm.{_type}"
for _type in ["weight", "scales", "biases"]
for l in range(len(model.layers))},
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
**{
f"model.layers.{l}.self_attn.{x}_proj.{_type}": f"layers.{l}.attention.w{x}.{_type}"
for _type in ["weight", "scales", "biases"]
for x in ["q", "k", "v", "o"]
for l in range(len(model.layers))
},
**{f"model.layers.{l}.post_attention_layernorm.{_type}": f"layers.{l}.ffn_norm.{_type}"
for _type in ["weight", "scales", "biases"]
for l in range(len(model.layers))},
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
**{
f"model.layers.{l}.mlp.{x}_proj.{_type}": f"layers.{l}.feed_forward.w{y}.{_type}"
for _type in ["weight", "scales", "biases"]
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
for l in range(len(model.layers))
},
**{f"model.norm.{_type}": f"norm.{_type}"
for _type in ["weight", "scales", "biases"]},
**{f"lm_head.{_type}": f"output.{_type}"
for _type in ["weight", "scales", "biases"]},
}
sd = {}
for k, v in weights.items():
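
To make the extended keymap concrete: every HF-style tensor name is now remapped for each of weight, scales and biases, so a quantized entry like model.layers.0.self_attn.q_proj.scales lands on layers.0.attention.wq.scales. A toy reproduction of one comprehension (n_layers=2 assumed for the sketch):

n_layers = 2
attn_map = {
  f"model.layers.{l}.self_attn.{x}_proj.{_type}": f"layers.{l}.attention.w{x}.{_type}"
  for _type in ["weight", "scales", "biases"]
  for x in ["q", "k", "v", "o"]
  for l in range(n_layers)
}
assert attn_map["model.layers.0.self_attn.q_proj.scales"] == "layers.0.attention.wq.scales"
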
@@ -325,3 +390,15 @@ def fix_bf16(weights: Dict[Any, Tensor]):
return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
# TODO: check if device supports bf16
return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}


def unpack_quantized(weights: Dict[Any, Tensor], bits: int) -> Dict[Any, Tensor]:
  unpacked_weights = {}
  for k, v in weights.items():
    if v.dtype == dtypes.uint32:
      weights_per_pack = 32 // bits
      unpacked = Tensor.stack(*reversed([v*2**(bits*i) for i in range(weights_per_pack)]), dim=-1).idiv(2**(bits*(weights_per_pack-1))).cast(dtypes.uint8)
      unpacked_weights[k] = unpacked.reshape((*v.shape[:-1], weights_per_pack*v.shape[-1]))
    else:
      unpacked_weights[k] = v
  return unpacked_weights
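
unpack_quantized assumes MLX-style packing: each uint32 carries 32 // bits values, least-significant group first, and the multiply / integer-divide pair stands in for left and right shifts. A pure-Python sketch of the same extraction for bits=4:

def unpack_u32(packed: int, bits: int = 4):
  # Extract 32 // bits unsigned values, lowest bits first (assumed MLX packing order).
  mask = (1 << bits) - 1
  return [(packed >> (bits*i)) & mask for i in range(32 // bits)]

# 0x87654321 packs the nibbles 1..8 from least to most significant.
assert unpack_u32(0x87654321) == [1, 2, 3, 4, 5, 6, 7, 8]
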
4 changes: 2 additions & 2 deletions exo/models.py
@@ -14,14 +14,14 @@
"layers": 16,
"repo": {
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
"TinygradDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
},
},
"llama-3.2-1b-8bit": {
"layers": 16,
"repo": {
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-8bit",
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
"TinygradDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-8bit",
},
},
"llama-3.2-3b": {