diff --git a/entropix/prompts.py b/entropix/prompts.py index 8e40bc7..6b1e5cf 100644 --- a/entropix/prompts.py +++ b/entropix/prompts.py @@ -169,4 +169,4 @@ def create_prompts_from_csv(csv_path: str) -> List[str]: You are a principal devops engineer at google. You are an expert at all things cloud and deployment. Your task is to create ansible and terraform script to bootstrasp k8 cluster on Azure. Be clear and concise. Make sure it is production grade. Think and reflect about your actions to ensure to accomplished the task successfully.<|eot_id|><|start_header_id|>assistant<|end_header_id|> -""" \ No newline at end of file +""" diff --git a/entropix/torch_device.py b/entropix/torch_device.py new file mode 100644 index 0000000..94c78f2 --- /dev/null +++ b/entropix/torch_device.py @@ -0,0 +1,10 @@ +import torch + +def get_device(): + if torch.backends.mps.is_available(): + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + return device diff --git a/entropix/torch_kvcache.py b/entropix/torch_kvcache.py index 6caf2f2..2d458a2 100644 --- a/entropix/torch_kvcache.py +++ b/entropix/torch_kvcache.py @@ -1,13 +1,8 @@ import torch import torch.nn as nn -# Device selection, tree is like first apple silicion, then cuda, fallback is cpu. -if torch.backends.mps.is_available(): - device = torch.device("mps") -elif torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") +from entropix.torch_device import get_device +device = get_device() #print(f"Using device: {device}") diff --git a/entropix/torch_main.py b/entropix/torch_main.py index 64cf1ee..a215a82 100644 --- a/entropix/torch_main.py +++ b/entropix/torch_main.py @@ -17,16 +17,9 @@ from entropix.torch_sampler import sample from entropix.prompts import prompt, bp1 -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +from entropix.torch_device import get_device - -# Device selection, tree is like first apple silicion, then cuda, fallback is cpu. -if torch.backends.mps.is_available(): - device = torch.device("mps") -elif torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") +device = get_device() print(f"Using device: {device}") @@ -107,7 +100,7 @@ def generate(xfmr_weights, model_params, tokens): bsz, seqlen = tokens.shape attn_mask = build_attn_mask(seqlen, cur_pos) freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope) - kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim).to(DEVICE) + kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim).to(device) logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask) next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32) gen_tokens = next_token @@ -127,4 +120,4 @@ def generate(xfmr_weights, model_params, tokens): generate(xfmr_weights, model_params, raw_tokens1) if __name__ == '__main__': - tyro.cli(main) \ No newline at end of file + tyro.cli(main) diff --git a/entropix/torch_model.py b/entropix/torch_model.py index 0ebb3e9..1b75533 100644 --- a/entropix/torch_model.py +++ b/entropix/torch_model.py @@ -10,13 +10,8 @@ DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max) -# Device selection, tree is like first apple silicion, then cuda, fallback is cpu. -if torch.backends.mps.is_available(): - device = torch.device("mps") -elif torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") +from entropix.torch_device import get_device +device = get_device() #print(f"Using device: {device}") @@ -77,4 +72,4 @@ def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: torch.Ten h = h + h_attn h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i]) logits = F.linear(rms_norm(h, xfmr_weights.norm), xfmr_weights.output) - return logits, kvcache, scores, attn_stats \ No newline at end of file + return logits, kvcache, scores, attn_stats diff --git a/entropix/torch_sampler.py b/entropix/torch_sampler.py index 0e28a3f..ac1b7e5 100644 --- a/entropix/torch_sampler.py +++ b/entropix/torch_sampler.py @@ -2,13 +2,8 @@ import torch.nn.functional as F from typing import Tuple, Dict -# Device selection, tree is like first apple silicion, then cuda, fallback is cpu. -if torch.backends.mps.is_available(): - device = torch.device("mps") -elif torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") +from entropix.torch_device import get_device +device = get_device() LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E @@ -168,4 +163,4 @@ def sample(gen_tokens: torch.Tensor, logits: torch.Tensor, attention_scores: tor base_top_p=top_p, base_top_k=top_k, generator=generator - ) \ No newline at end of file + ) diff --git a/entropix/torch_stats.py b/entropix/torch_stats.py index 718783a..54ca263 100644 --- a/entropix/torch_stats.py +++ b/entropix/torch_stats.py @@ -1,12 +1,7 @@ import torch -# Device selection, tree is like first apple silicion, then cuda, fallback is cpu. -if torch.backends.mps.is_available(): - device = torch.device("mps") -elif torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") +from entropix.torch_device import get_device +device = get_device() #print(f"Using device: {device}") @@ -45,4 +40,4 @@ def update(self, scores: torch.Tensor, layer_idx: int): self.entropy[:, layer_idx, :] = new_entropy self.varentropy[:, layer_idx, :] = new_varentropy - return self \ No newline at end of file + return self diff --git a/entropix/torch_weights.py b/entropix/torch_weights.py index ccbf187..96986fc 100644 --- a/entropix/torch_weights.py +++ b/entropix/torch_weights.py @@ -10,13 +10,8 @@ from pathlib import Path -# Device selection, tree is like first apple silicion, then cuda, fallback is cpu. -if torch.backends.mps.is_available(): - device = torch.device("mps") -elif torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") +from entropix.torch_device import get_device +device = get_device() #print(f"Using device: {device}") @@ -80,4 +75,4 @@ def load_weights(ckpt_dir: Path = Path('weights/1B-Instruct'), n_layers: int = 1 layer_weights=layer_weights ) - return xfmr_weights \ No newline at end of file + return xfmr_weights