Fix bugs (#1701)
* Phi 4

* Update llama.py

* Torch.Cuda Is Available Condition and Warning (#1545)

* Check whether torch.cuda and Triton are available; on my machine (Mac M3), CUDA was not available

* Update pyproject.toml

* Update __init__.py

---------

Co-authored-by: Daniel Han <[email protected]>

* Update mistral.py

* Update mistral.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Fix

* Bug fixes

* Update mapper.py

* Add dropout to granite to match HF's implementation (#1557)

Signed-off-by: datta0 <[email protected]>

* Update llama.py

* Update llama.py

* Bug fixes

* fix: flash_attn_detection_error (#1556)

* fix: flash_attn_detection_error

* Update _utils.py

---------

Co-authored-by: Daniel Han <[email protected]>

* Update mapper.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* dim fix

* Update _utils.py

* Torch 2.6 support

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Faster inference?

* Update llama.py

* Update llama.py

* Update utils.py

* Update llama.py

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update mapper.py

* Fast Inference via vLLM

* Update llama.py

* Update llama.py

* Update utils.py

* Create rl.py

* PatchRL

* Update rl.py

* Update rl.py

* Update rl.py

* PatchRLStatistics

* Update rl.py

* Update rl.py

* Update rl.py

* Update utils.py

* Update utils.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* RL metrics

* Update rl.py

* RL metrics

* Update __init__.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update chat_templates.py

* Update mapper.py

* Fp8 cache

* Update llama.py

* Update llama.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update __init__.py

* Update loader.py

* Update rl.py

* Update rl.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Better TRL handling

* Update rl.py

* Update tokenizer_utils.py

* Auto patching

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update rl.py

* Update tokenizer_utils.py

* Update rl.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update tokenizer_utils.py

* Update rl.py

* Update rl.py

* Update rl.py

* max seq length

* Update rl.py

* Update rl.py

* Patching

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* NEFTune

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Extra replacements

* Update rl_replacements.py

* Update rl.py

* extra RL replacements

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update _utils.py

* Update loader_utils.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* autocast

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update rl_replacements.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update pyproject.toml

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update llama.py

* Update _utils.py

---------

Signed-off-by: datta0 <[email protected]>
Co-authored-by: AminWhat <[email protected]>
Co-authored-by: Datta Nimmaturi <[email protected]>
Co-authored-by: Zhe Zhang <[email protected]>
4 people authored Feb 13, 2025
1 parent 027b510 commit 016315e
Showing 7 changed files with 670 additions and 323 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -187,9 +187,9 @@ cu124onlytorch260 = [
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu126onlytorch260 = [
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
5 changes: 4 additions & 1 deletion unsloth/models/_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2025.2.4"
__version__ = "2025.2.5"

__all__ = [
"SUPPORTS_BFLOAT16",
@@ -131,13 +131,16 @@

# Ignore logging messages
class HideLoggingMessage(logging.Filter):
__slots__ = "text",
def __init__(self, text): self.text = text
def filter(self, x): return not (self.text in x.getMessage())
pass

# The speedups for torchdynamo mostly come with GPU Ampere or higher, which is not detected here.
from transformers.training_args import logger as transformers_training_args_logger
transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups"))
# torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED.
transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed"))
del transformers_training_args_logger

# Using the default loss: `ForCausalLMLoss`.
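
For reference, HideLoggingMessage is plain standard-library logging; a minimal standalone sketch of the same pattern, independent of transformers:

import logging

class HideLoggingMessage(logging.Filter):
    # Drop any record whose rendered message contains the given text.
    def __init__(self, text): self.text = text
    def filter(self, record): return self.text not in record.getMessage()

logging.basicConfig(level = logging.WARNING)
logger = logging.getLogger("demo")
logger.addFilter(HideLoggingMessage("The speedups"))
logger.warning("The speedups for torchdynamo mostly come with GPU Ampere ...")  # suppressed
logger.warning("Some other warning")  # still printed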
55 changes: 34 additions & 21 deletions unsloth/models/llama.py
@@ -15,6 +15,7 @@
import torch
import gc
import math
from functools import partial
from typing import Optional, Tuple, List, Union
from ._utils import *
from ._utils import __version__
@@ -447,20 +448,28 @@ def LlamaAttention_fast_forward(
A = flash_attn_func(Q, K, V, causal = True)
else:
# Grouped query attention
if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
pass
# Must be contiguous or else results are False!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
if SDPA_HAS_GQA:
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2)#.contiguous()
else:
if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
pass
# Must be contiguous or else results are False!
# https://github.com/pytorch/pytorch/issues/112577
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass
pass
attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
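
The branch above prefers PyTorch's native grouped-query attention when SDPA_HAS_GQA is set, and otherwise falls back to manually expanding the KV heads. A minimal standalone sketch of the two equivalent paths, assuming PyTorch 2.5 or newer where scaled_dot_product_attention accepts enable_gqa:

import torch
import torch.nn.functional as F

bsz, n_heads, n_kv_heads, seq_len, head_dim = 1, 8, 2, 16, 64
n_groups = n_heads // n_kv_heads
Q = torch.randn(bsz, n_heads,    seq_len, head_dim)
K = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
V = torch.randn(bsz, n_kv_heads, seq_len, head_dim)

# Native GQA path: SDPA broadcasts the KV heads across the query groups internally.
A_native = F.scaled_dot_product_attention(Q, K, V, is_causal = True, enable_gqa = True)

# Manual fallback: expand the KV heads to match the query heads, then run plain SDPA.
K_exp = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, seq_len, head_dim).reshape(bsz, n_heads, seq_len, head_dim)
V_exp = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, seq_len, head_dim).reshape(bsz, n_heads, seq_len, head_dim)
A_manual = F.scaled_dot_product_attention(Q, K_exp.contiguous(), V_exp.contiguous(), is_causal = True)

assert torch.allclose(A_native, A_manual, atol = 1e-5)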
@@ -699,6 +708,7 @@ def LlamaModel_fast_forward(
if attention_mask is None:
padding_mask = None
elif self.training:
# elif attention_mask is not None and self.training:
attention_mask = None
padding_mask = None
else:
@@ -714,6 +724,7 @@
past_key_values_length,
sliding_window = getattr(self.config, "sliding_window", None),
)
attention_mask = attention_mask.to(torch.bool)
pass

hidden_states = inputs_embeds
@@ -1802,8 +1813,6 @@ def from_pretrained(
model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype)
model.vllm_engine = llm
model.fast_generate = model.vllm_engine.generate

from functools import partial
model.fast_generate_batches = partial(generate_batches, model.vllm_engine)
pass
# Return old flag
@@ -1952,13 +1961,13 @@ def from_pretrained(
Trainer._inner_training_loop = _fast_inner_training_loop

# Save max_seq_length
model.max_seq_length = max_position_embeddings
model.max_seq_length = max_seq_length
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_position_embeddings
internal_model.max_seq_length = max_seq_length
internal_model = internal_model.model
pass
internal_model.max_seq_length = max_position_embeddings
internal_model.max_seq_length = max_seq_length

# We check the tokenizer first for errors
if fix_tokenizer:
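
The max_seq_length assignments above now store the user-requested max_seq_length rather than the checkpoint's max_position_embeddings, and propagate it through every nested submodule. A hypothetical helper (not unsloth's actual code) capturing the same pattern:

def set_max_seq_length(model, max_seq_length):
    # Tag the outer model and every nested wrapper (e.g. PeftModel -> *ForCausalLM -> *Model).
    model.max_seq_length = max_seq_length
    internal_model = model
    while hasattr(internal_model, "model"):
        internal_model.max_seq_length = max_seq_length
        internal_model = internal_model.model
    internal_model.max_seq_length = max_seq_length
    return model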
@@ -2146,8 +2155,6 @@ def get_peft_model(
signature = str(inspect.signature(LoraConfig))
SUPPORTS_LOFTQ = "loftq_config" in signature
SUPPORTS_RSLORA = "use_rslora" in signature

assert(max_seq_length <= model.max_seq_length)

if lora_dropout != 0:
logger.warning_once(
@@ -2632,6 +2639,10 @@ def patch_peft_model(
gc.collect()
torch.cuda.empty_cache()
pass

# Add for_inference and for_training
model.for_training = partial(FastLlamaModel.for_training, model)
model.for_inference = partial(FastLlamaModel.for_inference, model)
return model
pass
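
The two partial(...) assignments above bind the model as the first argument of the class-level helpers, so callers can simply run model.for_inference() or model.for_training(). A minimal sketch of the pattern; the class and attribute names below are illustrative, not unsloth's actual internals:

from functools import partial

class FastModelStub:
    @staticmethod
    def for_training(model, use_gradient_checkpointing = True):
        model.gradient_checkpointing = use_gradient_checkpointing
        return model

    @staticmethod
    def for_inference(model):
        model.gradient_checkpointing = False
        return model

class DummyModel: pass
model = DummyModel()
model.for_training  = partial(FastModelStub.for_training,  model)
model.for_inference = partial(FastModelStub.for_inference, model)

model.for_inference()  # equivalent to FastModelStub.for_inference(model)
model.for_training(use_gradient_checkpointing = False)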

@@ -2739,3 +2750,5 @@ def for_training(model, use_gradient_checkpointing = True):
pass
pass

from .rl import PatchFastRL
PatchFastRL(FastLanguageModel = FastLlamaModel)
5 changes: 5 additions & 0 deletions unsloth/models/loader_utils.py
@@ -58,6 +58,11 @@ def __get_model_name(

elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER:

# Return the original -bnb-4bit name if it was requested explicitly,
# since we would otherwise map it to the dynamic 4-bit version.
if lower_model_name.endswith("-bnb-4bit"):
return lower_model_name

new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
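
The early return above keeps an explicitly requested -bnb-4bit checkpoint name instead of remapping it to the dynamic variant. A hypothetical sketch of the lookup logic; the mapper entries are made up for illustration:

FLOAT_TO_INT_MAPPER = {
    "org/model-7b":          "org/model-7b-unsloth-bnb-4bit",
    "org/model-7b-bnb-4bit": "org/model-7b-unsloth-bnb-4bit",
}

def resolve_model_name(lower_model_name, load_in_4bit = True):
    if load_in_4bit and lower_model_name in FLOAT_TO_INT_MAPPER:
        # Respect an explicit -bnb-4bit request rather than swapping in the dynamic version.
        if lower_model_name.endswith("-bnb-4bit"):
            return lower_model_name
        return FLOAT_TO_INT_MAPPER[lower_model_name]
    return lower_model_name

assert resolve_model_name("org/model-7b") == "org/model-7b-unsloth-bnb-4bit"
assert resolve_model_name("org/model-7b-bnb-4bit") == "org/model-7b-bnb-4bit"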