[Bounty] PyTorch & HuggingFace Interface #139
base: main
Changes from 250 commits
@@ -0,0 +1,24 @@
# Helper functions for pytorch inference
# Some code coming from tinygrad but written towards pytorch

import asyncio
import aiohttp
from tqdm import tqdm
from pathlib import Path
from typing import List

async def fetch_file_async(session, url: str, output_path: Path):
    async with session.get(url) as response:
        response.raise_for_status()
        with open(output_path, 'wb') as f:
            async for chunk in response.content.iter_chunked(8192):
                f.write(chunk)

async def download_files(urls: List[str], output_paths: List[Path]):
    async with aiohttp.ClientSession() as session:
        tasks = []
        for url, output_path in zip(urls, output_paths):
            tasks.append(fetch_file_async(session, url, output_path))

        for f in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Downloading files"):
            await f
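For reference, a minimal usage sketch (not part of the diff) showing how download_files could be driven from synchronous code; the URL and output path are placeholders, not real exo endpoints, and download_files is assumed to come from the file above.

import asyncio
from pathlib import Path

# Placeholder inputs; real callers would pass the model's actual file URLs.
urls = ["https://example.com/model-00001-of-00002.safetensors"]
paths = [Path("/tmp/model-00001-of-00002.safetensors")]

asyncio.run(download_files(urls, paths))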
@@ -0,0 +1,137 @@
# experimental, based off of tinygrad/inference.py

import numpy as np
import torch
import numpy as np
import json
from typing import Optional, Callable, Tuple
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel
from exo.api.chatgpt_api import resolve_tokenizer
from exo.helpers import DEBUG
from transformers import DynamicCache

class PyTorchDynamicShardInferenceEngine(InferenceEngine):
    """
    PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models.
    """

    def __init__(self):
        """
        Initialize the inference engine.

        Args:
            debug (bool): If True, enables debug logging. Defaults to False.
        """
        self.shard = None
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Review thread on the device selection line:

Reviewer: Are these the only options? I think supporting e.g. Mac with MPS would be great, since then you can run heterogeneous clusters. One thing to try at some point would be mixing MLX and PyTorch and seeing whether they are interoperable with exactly the same model.

Author: With PyTorch I don't think Mac support is fully rolled out yet. There seem to be some workarounds, but CUDA and CPU are the only options on the PyTorch download site, and PyTorch even stopped ROCm support for AMD. They have a nightly for testing MPS: https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/

Reviewer: What about this in the official "stable" docs: https://pytorch.org/docs/stable/notes/mps.html

Author: Will try that, but I currently have no Mac to test on. Once I get through these other fixes I can add it for you or other Mac users to test.
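Following up on that thread, a minimal sketch (not part of the PR, and untested on a Mac) of how device selection could also probe for MPS; the helper name is hypothetical.

import torch

def pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # MPS is exposed on Apple Silicon builds of PyTorch; guard the attribute
    # so older builds without torch.backends.mps don't raise.
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")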
    async def infer_prompt(
        self,
        request_id: str,
        shard: Optional[Shard] = None,
        prompt: str = "",
        image_str: Optional[str] = None,
        inference_state: Optional[str] = None
    ) -> Tuple[np.ndarray, str, bool]:
        if DEBUG >= 2:
            print("infer_prompt called")

        await self.ensure_shard(shard)

        # need to make this so inference_state is not a string
        # cant use it with dynamic cache

        tokens = self.tokenizer.encode(prompt, return_tensors="pt")

        if DEBUG >= 2:
            print(f"tokens: {tokens}\n")

        output_data = self.model.forward_layers(
            tokens
        )

        is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id]

        if is_finished:
            print(f"token from llm decode: {self.tokenizer.decode(output_data)}")

        if DEBUG >= 2:
            print(f"output_data: {output_data}\n")
            print(f"output_data.size {output_data.size}\n")
            print(f"output_data.item() {output_data.item()}")
            print(f"finished: {is_finished}")
            print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}")
            print(f"output_data[-1] {output_data[-1]}")
            print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")

        return (
            output_data,
            "",
            is_finished
        )
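Regarding the TODO above about inference_state being a string and therefore unusable with DynamicCache: one possible workaround, sketched here only and not part of this PR, is to round-trip the cache through its legacy (key, value) tuples and base64 so it fits the string-typed field.

import base64
import io

import torch
from transformers import DynamicCache

def cache_to_str(cache: DynamicCache) -> str:
    # Serialise the legacy (key, value) tuple form of the cache to a string.
    buf = io.BytesIO()
    torch.save(cache.to_legacy_cache(), buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")

def cache_from_str(state: str) -> DynamicCache:
    buf = io.BytesIO(base64.b64decode(state))
    return DynamicCache.from_legacy_cache(torch.load(buf))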
    async def infer_tensor(
        self,
        request_id: str,
        shard: Shard,
        input_data: np.ndarray,
        inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:

        in_tensor = torch.tensor(input_data)

        # Ensure input_data is 2D: [batch_size, seq_len]
        if in_tensor.dim() == 1:
            in_tensor = in_tensor.unsqueeze(0)  # Add a batch dimension: [1, seq_len]

        if DEBUG >= 2:
            print("infer_tensor called")
            print(f"input_data: {input_data}\n")
            print(f"in_tensor: {in_tensor}\n")

        await self.ensure_shard(shard)

        output_data = self.model.forward_layers(
            in_tensor
        )

        is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id]

        if DEBUG >= 2:
            print(f"output_data: {output_data}\n")
            print(f"output_data.size {output_data.size}\n")
            print(f"output_data.item() {output_data.item()}")
            print(f"finished: {is_finished}")
            print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}")
            print(f"output_data[-1] {output_data[-1]}")
            print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")

        return (
            output_data,
            "",
            is_finished
        )
    async def ensure_shard(self, shard: Optional[Shard]):
        """
        Ensure the model shard is loaded and ready for inference.

        Args:
            shard (Optional[Shard]): Shard information for the model.
        """
        if self.shard == shard:
            return

        if DEBUG >= 2:
            print(f"Loading new shard: {shard}")

        self.model = ShardedHuggingFaceModel(shard)
        self.tokenizer = await resolve_tokenizer(shard.model_id)
        self.shard = shard

        if DEBUG >= 2:
            print(f"Shard loaded successfully: {shard}")
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import numpy as np

Reviewer: I don't think this is used.

from transformers import AutoModelForCausalLM, LlamaConfig, DynamicCache, Cache
from exo.inference.shard import Shard
from exo.helpers import DEBUG
from typing import Tuple

from .utils import sample_logits

class ShardedHuggingFaceModel(torch.nn.Module):
    def __init__(self, shard: Shard):
        super(ShardedHuggingFaceModel, self).__init__()

        if DEBUG >= 2:
            print(f"\nShardedHuggingFaceModel init with shard {shard}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.shard = shard

        # Load the model
        self.full_model = AutoModelForCausalLM.from_pretrained(
            shard.model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )

Review thread on the from_pretrained call:

Reviewer: Will this download the entire model?

Reviewer: Also, this won't work with our download-progress code. We show the download progress of the model in the TUI.

Author: Will look at using that code, because yes, it currently does download the whole model.

Reviewer (on the torch_dtype line): Again, are these the only options? Would want support across other platforms.
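One possible direction for the download question above, sketched here without any claim about exo's actual progress API: pre-fetch the repository with huggingface_hub and point from_pretrained at the local snapshot, so the fetch step can later be swapped for exo's own progress-aware downloader.

import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM

def load_model_locally(model_id: str):
    # Download (or reuse the cached) snapshot first; allow_patterns could
    # later restrict which weight files a partial shard actually needs.
    local_path = snapshot_download(repo_id=model_id)
    return AutoModelForCausalLM.from_pretrained(
        local_path,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )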
        # using LlamaConfig not working; setting layers manually
        layers = []
        for i in range(shard.start_layer, shard.end_layer + 1):
            layer = self.full_model.model.layers[i]

            if DEBUG >= 2:
                print(f"Loading layers[{i}]")

            layers.append(layer)

        self.full_model.model.layers = nn.ModuleList(layers)

Review thread on the layer list:

Reviewer: What does the peak memory usage look like here? I'm not sure of the specifics of Python, whether this is going to hold each layer twice. Not sure, but perhaps setting them in place would be more memory efficient.

Author: They shouldn't be held twice: when the ensure_shard function is called in infer_prompt or infer_tensor, the init function is called, which loads the needed layers each time depending on the shard. Will make sure about memory limits and usage, though.
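On the "setting them in place" suggestion, a sketch (not part of the PR) that trims the layers in place instead of building a second Python list, so the dropped layers lose their last model reference as soon as the ModuleList is replaced.

import gc
import torch.nn as nn

def keep_layer_range(full_model, start: int, end: int) -> None:
    # Slicing an nn.ModuleList returns a new ModuleList over the same modules.
    kept = full_model.model.layers[start:end + 1]
    full_model.model.layers = nn.ModuleList(kept)
    # Layers outside [start, end] are no longer referenced by the model.
    gc.collect()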
        if DEBUG >= 2:
            print(f"full_model.model layer: {len(self.full_model.model.layers)}")

        # Embeddings and final layer norm
        # used for doing what forward LlamaModel does in transformers
        self.embed_tokens = self.full_model.model.embed_tokens
        self.norm = self.full_model.model.norm

        # self.past_key_values = DynamicCache()

Review thread on forward_layers:

Reviewer: I really like this approach of generalising this so it works for other models without having to explicitly implement them. Can you write a test for a model with a different architecture to make sure this generalises, e.g. recurrent Gemma?

(A sketch of such a test follows the end of this file.)

    def forward_layers(
        self,
        input_data: torch.tensor
    ) -> Tuple[np.ndarray, list]:
""" | ||
Forward pass through the specified layers. | ||
|
||
Note: past_key_values not working for model, might be a library bug | ||
""" | ||
if DEBUG >= 2: | ||
print("forward_layer call") | ||
print(f"input_data: {input_data}") | ||
print(f"shard {self.shard.to_dict()}") | ||
|
||
hidden_states = input_data | ||
|
||
# Forward pass through the layer | ||
if DEBUG >= 2: | ||
print(f"\n[layer model] {self.full_model.model}") | ||
print(f"IN hidden_states {hidden_states}") | ||
# print(f"past_kvs {past_kvs}") | ||
|
||
self.full_model.model.layer_idx = 5 | ||
layer_outputs = self.full_model.model( | ||
hidden_states, | ||
# position_ids=position_ids, | ||
# inputs_embeds=position_embeddings, | ||
# past_key_values=self.past_key_values, | ||
use_cache=False # not enough vram for using cache ;_; | ||
) | ||
|
||
if DEBUG >= 2: | ||
print(f"OUT hidden_states {hidden_states}") | ||
# print(f"\nlayer_outputs: {layer_outputs}") | ||
|
||
hidden_states = layer_outputs.last_hidden_state | ||
# self.past_key_values = layer_outputs.past_key_values | ||
|
||
print(f"2 is_last_layer {self.shard.is_last_layer()}") | ||
if self.shard.is_last_layer(): | ||
hs_norm = self.norm(hidden_states) | ||
hs_lm_head = self.full_model.lm_head(hs_norm).float() | ||
|
||
# Use the sampling function with default settings | ||
output_token = sample_logits( | ||
hs_lm_head[:, -1, :]).cpu().numpy().flatten() | ||
|
||
if DEBUG >= 2: | ||
print(f"hs_norm: {hs_norm}") | ||
print(f"hs_lm_head: {hs_lm_head}") | ||
print(f"output_token: {output_token}") | ||
|
||
return output_token | ||
|
||
return hidden_states.cpu().numpy() |
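As a follow-up to the reviewer's request above for a test on a second architecture, a standalone sketch (not part of the diff) that exercises the same "slice the decoder layers generically" idea against GPT-2. Note the layer attribute differs per architecture (transformer.h here versus model.layers for Llama), which is exactly what such a test would catch; the tiny model id is an assumption chosen for speed.

import unittest

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class TestLayerSlicingGeneralises(unittest.TestCase):
    def test_partial_layers_forward(self):
        model_id = "sshleifer/tiny-gpt2"  # hypothetical choice of tiny model
        model = AutoModelForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # GPT-2 keeps its decoder blocks on transformer.h, not model.layers.
        blocks = model.transformer.h
        model.transformer.h = nn.ModuleList(list(blocks)[: max(1, len(blocks) // 2)])

        input_ids = tokenizer("hello world", return_tensors="pt").input_ids
        with torch.no_grad():
            outputs = model(input_ids)

        # A forward pass through the truncated stack should still produce logits.
        self.assertEqual(outputs.logits.shape[0], 1)


if __name__ == "__main__":
    unittest.main()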
@@ -0,0 +1,63 @@
import torch
from torch.nn import functional as F

Review thread on sample_logits:

Reviewer: Can this be imported from somewhere rather than copy-pasta into the codebase? It looks like boilerplate code from somewhere.

Author: Yes, I was testing it with the default values but will clean that part up. I will set it in the Interface class settings to be used.

def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0):
""" | ||
Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. | ||
|
||
Args: | ||
logits (torch.Tensor): The logits distribution to sample from. | ||
temperature (float): Temperature for scaling logits. | ||
top_k (int): The number of top tokens to consider for sampling. | ||
top_p (float): The cumulative probability threshold for nucleus sampling. | ||
alpha_f (float): Penalty factor for repetition frequency. | ||
alpha_p (float): Penalty for repeated selection. | ||
|
||
Returns: | ||
torch.Tensor: The selected token index. | ||
""" | ||
|
||
# Ensure logits are float | ||
logits = logits.float() | ||
|
||
# If temperature is very low, just use argmax | ||
if temperature < 1e-6: | ||
return logits.argmax(dim=-1) | ||
|
||
# Alpha sampling (adjusting logits based on past selections) | ||
if alpha_f > 0.0 or alpha_p > 0.0: | ||
logits -= (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) | ||
|
||
# Replace NaNs with -inf to prevent softmax issues | ||
logits = torch.where(torch.isnan(logits), torch.full_like(logits, -float('inf')), logits) | ||
|
||
# Apply temperature scaling | ||
logits = logits / temperature | ||
|
||
# Top-k sampling | ||
if top_k > 0: | ||
top_k = min(top_k, logits.size(-1)) | ||
top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) | ||
logits = torch.full_like(logits, -float('inf')) | ||
logits.scatter_(-1, top_k_indices, top_k_values) | ||
|
||
# Top-p sampling | ||
if 0 < top_p < 1.0: | ||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) | ||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | ||
|
||
# Remove tokens with cumulative probability above the threshold | ||
sorted_indices_to_remove = cumulative_probs > top_p | ||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | ||
sorted_indices_to_remove[..., 0] = 0 | ||
|
||
sorted_logits[sorted_indices_to_remove] = -float('inf') | ||
logits = sorted_logits | ||
|
||
# Apply softmax to get probabilities | ||
probabilities = F.softmax(logits, dim=-1) | ||
|
||
# Sample from the probabilities | ||
sampled_token = torch.multinomial(probabilities, 1) | ||
|
||
return sampled_token.squeeze() |
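For quick verification, a standalone sketch (not in the diff) that exercises sample_logits with random logits, mirroring how hf.py calls it on the last position of the lm_head output; the vocabulary size is arbitrary.

import torch

# [batch, vocab]-shaped fake logits, as produced by hs_lm_head[:, -1, :].
logits = torch.randn(1, 32000)
token = sample_logits(logits, temperature=0.7, top_k=50, top_p=0.9)
print(token.item())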
@@ -0,0 +1,21 @@
import unittest
from unittest.mock import patch, MagicMock
from pathlib import Path
import torch
from exo.inference.shard import Shard
from exo.inference.pytorch.helpers import build_transformer

class TestBuildTransformer(unittest.TestCase):

    def test_build_transformer(self):
        # Call the build_transformer function
        model = build_transformer(
            "gpt2",
            quantize=True,
            device="cuda"
        )

        self.assertIsNotNone(model)

if __name__ == '__main__':
    unittest.main()

Reviewer: Run this test in CircleCI.
Review thread:

Reviewer: This used?

Author: No, I can remove.

Reviewer: Please remove.

Author: Sorry, forgot to. Will do that now.