Labels: Bug, Hugging Face Integration, Medium Priority
Description
System Info
bitsandbytes=0.42.0
transformers=4.37.1
Reproduction
import torch
import torch.nn as nn
import transformers
from transformers import BitsAndBytesConfig

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # get the LLaMA model class
        llama_model = transformers.LlamaModel
        bnb_config_4bit = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=False,  # whether to use double quantization
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16)  # 4-bit QLoRA
        # load the base model weights with quantization
        self.model = llama_model.from_pretrained(
            "Llama_weights/llama-2-7b-hf-weights",
            low_cpu_mem_usage=True,
            device_map=0,
            quantization_config=bnb_config_4bit,
            torchscript=True)
        self.model.output_hidden_states = False

    def forward(self, tokens_tensor):
        self.model.eval()
        o = self.model(tokens_tensor, output_hidden_states=False)
        return o[0]

model = Wrapper()
model.eval()
with torch.no_grad():
    dummy_tokens_tensor = torch.randint(0, 1000, (1, 50), dtype=torch.long).to("cuda")
    outputs = model(dummy_tokens_tensor)
    trace_model = torch.jit.trace(model, [dummy_tokens_tensor])  # this works, but with some tracing warnings
    print("traced_model done")
    torch.jit.save(trace_model, "llama_4bit.pt")  # --> error!
The line torch.jit.save(trace_model, "llama_4bit.pt") gives me the following error:
RuntimeError:
Could not export Python function call 'MatMul4Bit'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/autograd/function.py(506): apply
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py(577): matmul_4bit
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/nn/modules.py(256): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(386): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(798): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(1070): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home///test/test_torchscript_llama.py(62): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(1056): trace_module
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(794): trace
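If I read the traceback correctly, the root cause is that MatMul4Bit is a custom torch.autograd.Function: the tracer records it as a Python function call in the graph, and torch.jit.save cannot serialize Python function calls. A minimal sketch without bitsandbytes (my own toy Function, hypothetical names) that I believe hits the same error:

import torch
import torch.nn as nn

class MyMatMulLike(torch.autograd.Function):
    # toy stand-in for bitsandbytes' MatMul4Bit
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

class Toy(nn.Module):
    def forward(self, x):
        return MyMatMulLike.apply(x)

toy = Toy().eval()
x = torch.randn(1, 4)
traced = torch.jit.trace(toy, x)   # tracing succeeds; the Function is recorded as a PythonOp
torch.jit.save(traced, "toy.pt")   # RuntimeError: Could not export Python function call 'MyMatMulLike'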
Expected behavior
torch.jit.save should work properly for a quantized model.
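A workaround I am considering (untested sketch; it assumes bitsandbytes.nn.Linear4bit and bitsandbytes.functional.dequantize_4bit from version 0.42.0, and the helper name is mine) is to dequantize the 4-bit weights back into plain fp16 nn.Linear layers before tracing, which gives up the memory savings but removes the MatMul4Bit calls from the graph:

import bitsandbytes as bnb
import torch
import torch.nn as nn

def dequantize_linear4bit(module):
    # hypothetical helper: replace every Linear4bit with an fp16 nn.Linear
    # holding the dequantized weights, recursing into submodules
    for name, child in module.named_children():
        if isinstance(child, bnb.nn.Linear4bit):
            w = bnb.functional.dequantize_4bit(child.weight.data, child.weight.quant_state)
            linear = nn.Linear(child.in_features, child.out_features, bias=child.bias is not None)
            linear.weight = nn.Parameter(w.to(torch.float16), requires_grad=False)
            if child.bias is not None:
                linear.bias = nn.Parameter(child.bias.data.to(torch.float16), requires_grad=False)
            setattr(module, name, linear.to(child.weight.device))
        else:
            dequantize_linear4bit(child)

# dequantize_linear4bit(model)  # then torch.jit.trace / torch.jit.save should no longer hit MatMul4Bit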