Tinygrad quantization support #213

Draft
wants to merge 10 commits into base: main
3 changes: 2 additions & 1 deletion exo/inference/inference_engine.py
@@ -4,6 +4,7 @@
from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
from ..download.shard_download import ShardDownloader


class InferenceEngine(ABC):
@@ -16,7 +17,7 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
pass


def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if inference_engine_name == "mlx":
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine

73 changes: 68 additions & 5 deletions exo/inference/tinygrad/inference.py
@@ -5,14 +5,15 @@
from exo.inference.shard import Shard
from exo.inference.tokenizers import resolve_tokenizer
from tinygrad.nn.state import load_state_dict
from tinygrad import Tensor, nn, Context
from tinygrad import Tensor, nn, Context, dtypes
from exo.inference.inference_engine import InferenceEngine
from typing import Optional, Tuple
import numpy as np
from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
from exo.download.shard_download import ShardDownloader
from concurrent.futures import ThreadPoolExecutor
import asyncio
from functools import partial

Tensor.no_grad = True
# default settings
@@ -27,11 +28,60 @@
}


def select_bits(w, bits, start):
shift_left = 32 - (start + bits)
shift_right = shift_left + start
return (w * (2**shift_left)) // (2**shift_right)


class MLXQuantizedLinear:
def __init__(self, in_features, out_features, bits=4, group_size=64, bias=False):
assert in_features % group_size == 0
assert 32 % bits == 0
assert (in_features * bits) % 32 == 0
self.weight = Tensor.kaiming_uniform(out_features, (in_features * bits) // 32, dtype=dtypes.uint32)
self.scales = Tensor.kaiming_uniform(out_features, in_features // group_size, dtype=dtypes.half)
self.biases = Tensor.kaiming_uniform(out_features, in_features // group_size, dtype=dtypes.half)
self.bias = Tensor.uniform(out_features, low=0, high=1.0) if bias else None
self.bits = bits
self.group_size = group_size

def __call__(self, x):
w_full = Tensor.cat(
*[select_bits(self.weight, self.bits, i)[..., None] for i in range(0, 32, self.bits)], dim=-1
).reshape(len(self.weight), self.scales.shape[-1], -1)
w_full = self.scales[..., None] * w_full + self.biases[..., None]
return x.linear(w_full.reshape(len(self.weight), -1).T, self.bias)


class MLXQuantizedEmbedding:
def __init__(self, vocab_size, embed_size, bits=4, group_size=64):
self.vocab_sz, self.embed_sz = vocab_size, embed_size
self.bits = bits
self.group_size = group_size
self.weight = Tensor.glorot_uniform(vocab_size, (embed_size * bits) // 32)
self.scales = Tensor.glorot_uniform(vocab_size, embed_size // group_size)
self.biases = Tensor.glorot_uniform(vocab_size, embed_size // group_size)

def __call__(self, x):
s = x.shape
x = x.flatten()
w = self.weight[x]
scales = self.scales[x]
biases = self.biases[x]
w_full = Tensor.cat(
*[select_bits(w, self.bits, i)[..., None] for i in range(0, 32, self.bits)], dim=-1
).reshape(len(w), scales.shape[-1], -1)
w_full = scales[..., None] * w_full + biases[..., None]
return w_full.reshape(*s, -1)


def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
# build model
linear = nn.Linear
with Context(THREEFRY=0):
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
try:
with open(model_path/"config.json", "r") as f:
config = json.load(f)
except FileNotFoundError:
raise Exception(f"Config file not found in {model_path}")

# load weights
if model_path.is_dir():
@@ -40,6 +90,19 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
else:
weights = load(str(model_path), shard)

# build model
if (quantization := config.get("quantization", None)) is not None:
linear = partial(MLXQuantizedLinear, **quantization)
else:
linear = nn.Linear
if "model.embed_tokens.scales" in weights.keys():
tok_embeddings = partial(MLXQuantizedEmbedding, **quantization)
else:
tok_embeddings = nn.Embedding

with Context(THREEFRY=0):
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, tok_embeddings=tok_embeddings, max_context=8192, jit=True, shard=shard)
weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
weights = fix_bf16(weights)

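Reviewer note: MLXQuantizedLinear and MLXQuantizedEmbedding follow MLX's group-quantization layout — each uint32 in `weight` packs 32 // bits quantized values, every group of group_size input features shares one half-precision scale and bias, and the dequantized values are recovered as scales[..., None] * q + biases[..., None]. MLX-exported checkpoints typically carry a config.json entry of the form "quantization": {"group_size": 64, "bits": 4} (an assumption about the checkpoint format, not shown in this diff), which build_transformer forwards into the layer constructors. Below is a minimal numpy sketch, not part of the PR, of the arithmetic select_bits appears to rely on: uint32 multiplication wraps modulo 2**32, so the multiply/floor-divide pair is equivalent to a shift-and-mask.

import numpy as np

def select_bits_np(w: np.ndarray, bits: int, start: int) -> np.ndarray:
  # same arithmetic as select_bits above, on plain uint32 arrays
  shift_left = 32 - (start + bits)
  shift_right = shift_left + start
  return (w * np.uint32(2**shift_left)) // np.uint32(2**shift_right)

packed = np.array([0xDEADBEEF, 0x12345678], dtype=np.uint32)
for start in range(0, 32, 4):
  # multiplying by 2**shift_left discards bits above start+bits (mod 2**32);
  # dividing by 2**shift_right brings the target field down to the low bits
  assert (select_bits_np(packed, 4, start) == (packed >> start) & 0xF).all()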
6 changes: 6 additions & 0 deletions exo/inference/tinygrad/models/base.py
@@ -0,0 +1,6 @@
from typing import Union, Optional
from tinygrad import Tensor, Variable

class IdentityBlock:
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
return x
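IdentityBlock exists so the sharded Transformer in llama.py below can keep self.layers at its full length: layers outside the shard are instantiated as IdentityBlocks that simply pass activations through, which keeps per-layer weight indices aligned with the full model and lets the forward loop stay uniform. A quick sanity sketch, assuming exo and tinygrad are importable; the extra arguments are ignored by the block:

from tinygrad import Tensor
from exo.inference.tinygrad.models.base import IdentityBlock

x = Tensor.ones(1, 4, 8)
# an off-shard layer is a pass-through: the output is the input, untouched
assert IdentityBlock()(x, start_pos=0, freqs_cis=None, mask=None) is x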
39 changes: 26 additions & 13 deletions exo/inference/tinygrad/models/llama.py
@@ -2,6 +2,8 @@
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv

from .base import IdentityBlock


# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
@@ -172,16 +174,24 @@ def __init__(
vocab_size,
shard: Shard = None,
linear=nn.Linear,
tok_embeddings=nn.Embedding,
n_kv_heads=None,
rope_theta=10000,
max_context=1024,
jit=True,
feed_forward=FeedForward
):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
self.norm = nn.RMSNorm(dim, norm_eps)
self.tok_embeddings = nn.Embedding(vocab_size, dim)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.layers = []
for i in range(n_layers):
if shard.start_layer <= i <= shard.end_layer:
self.layers.append(TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward))
else:
self.layers.append(IdentityBlock())
if shard.is_first_layer():
self.tok_embeddings = tok_embeddings(vocab_size, dim)
if shard.is_last_layer():
self.norm = nn.RMSNorm(dim, norm_eps)
self.output = linear(dim, vocab_size, bias=False)
self.max_context = max_context
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
self.forward_jit = TinyJit(self.forward) if jit else None
@@ -197,8 +207,7 @@ def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float
else:
h = x

for i in range(self.shard.start_layer, self.shard.end_layer + 1):
layer = self.layers[i]
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)

if self.shard.is_last_layer():
@@ -222,19 +231,23 @@ def permute(v: Tensor, n_heads: int):
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

keymap = {
"model.embed_tokens.weight": "tok_embeddings.weight",
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
**{f"model.embed_tokens.{_type}": f"tok_embeddings.{_type}" for _type in ["weight", "scales", "biases"]},
**{f"model.layers.{l}.input_layernorm.{_type}": f"layers.{l}.attention_norm.{_type}"
for _type in ["weight", "scales", "biases"]
for l in range(len(model.layers))},
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
**{f"model.layers.{l}.self_attn.{x}_proj.{_type}": f"layers.{l}.attention.w{x}.{_type}"
for _type in ["weight", "scales", "biases"]
for x in ["q", "k", "v", "o"]
for l in range(len(model.layers))},
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
**{f"model.layers.{l}.post_attention_layernorm.{_type}": f"layers.{l}.ffn_norm.{_type}"
for _type in ["weight", "scales", "biases"]
for l in range(len(model.layers))},
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
**{f"model.layers.{l}.mlp.{x}_proj.{_type}": f"layers.{l}.feed_forward.w{y}.{_type}"
for _type in ["weight", "scales", "biases"]
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
for l in range(len(model.layers))},
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
**{f"model.norm.{_type}": f"norm.{_type}" for _type in ["weight", "scales", "biases"]},
**{f"lm_head.{_type}": f"output.{_type}" for _type in ["weight", "scales", "biases"]},
}
sd = {}
for k, v in weights.items():
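The keymap changes repeat each existing weight mapping for the scales and biases tensors that MLX quantization stores alongside every quantized weight; the target names are otherwise unchanged. A hypothetical spot-check of the name mapping alone, for one attention projection:

# layer 0's q projection maps identically for all three quantization tensors
for _type in ["weight", "scales", "biases"]:
  print(f"model.layers.0.self_attn.q_proj.{_type}", "->", f"layers.0.attention.wq.{_type}")
# model.layers.0.self_attn.q_proj.weight -> layers.0.attention.wq.weight
# model.layers.0.self_attn.q_proj.scales -> layers.0.attention.wq.scales
# model.layers.0.self_attn.q_proj.biases -> layers.0.attention.wq.biases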