
Why did Gemma-7B perform poorly? #365

Open
@JuJoker

Description of the bug:

I ran the Gemma-7B model using the code from the example and found that the model's answers were poor: it didn't seem to understand my question at all. Is this normal? My device is an NVIDIA RTX 4090 GPU. My code is as follows:

import os
import sys
import torch
sys.path.append('gemma_pytorch')

from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

VARIANT = '7b'
MACHINE_TYPE = 'cuda'

# Set up model config.
model_config = get_config_for_7b()

# Ensure that the tokenizer is present
tokenizer_path = os.path.join('./gemma-7b', 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join('./gemma-7b', f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

model_config.tokenizer = tokenizer_path

model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'
MODEL_CHAT_START = '<start_of_turn>model\n'
MODEL_CHAT_END = '<end_of_turn>\n'

MULTI_CHAT = ''

# while True:
#     input_text = input("User input(press q to exit):")
#     if input_text != 'q':
#         user_prompt = USER_CHAT_TEMPLATE.format(prompt=input_text)
#         MULTI_CHAT = MULTI_CHAT + user_prompt + MODEL_CHAT_START
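#         # Send the accumulated history (ending with the open model turn),
#         # not just the latest user turn.
#         model_response = model.generate(
#             MULTI_CHAT,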
#             device=device,
#             output_len=64,
#         )
#         print(f'Model reply: {model_response}')
#         MULTI_CHAT = MULTI_CHAT + model_response + MODEL_CHAT_END
#     else:
#         break

prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

res = model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)
print(res)
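One detail in the script worth checking: `prompt` already contains the full chat template and ends with the open `<start_of_turn>model\n` turn, but `model.generate` is called with `USER_CHAT_TEMPLATE.format(prompt=prompt)`, which wraps everything in a second user turn and closes it with `<end_of_turn>\n`. The minimal sketch below uses only plain Python string formatting (no Gemma code) to show the string the model actually receives. `build_chat_prompt` is a hypothetical helper written for illustration, not part of gemma_pytorch, and whether this alone explains the difference from the Kaggle run is not established here.

```python
# Chat templates copied from the script above.
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'
MODEL_CHAT_START = '<start_of_turn>model\n'


def build_chat_prompt(turns):
    """Hypothetical helper: format (role, text) turns, then open a model turn.

    `turns` is a list of ('user' | 'model', text) tuples.
    """
    parts = []
    for role, text in turns:
        template = USER_CHAT_TEMPLATE if role == 'user' else MODEL_CHAT_TEMPLATE
        parts.append(template.format(prompt=text))
    # Leave a model turn open so generation continues as the model.
    parts.append(MODEL_CHAT_START)
    return ''.join(parts)


# Same conversation as in the script.
prompt = build_chat_prompt([
    ('user', 'What is a good place for travel in the US?'),
    ('model', 'California.'),
    ('user', 'What can I do in California?'),
])

# What the script actually passes to model.generate: the already-templated
# prompt nested inside another user turn, which is then closed.
double_wrapped = USER_CHAT_TEMPLATE.format(prompt=prompt)

print(prompt.endswith(MODEL_CHAT_START))          # True: model turn left open
print(double_wrapped.endswith(MODEL_CHAT_START))  # False: turn closed by <end_of_turn>
print(repr(double_wrapped[:40]))
```

If this matters, passing the already-built string directly, e.g. `model.generate(prompt, device=device, output_len=100)` with the same keyword arguments the script uses, may be worth trying so the model sees a single, open model turn.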

The corresponding response was attached as a screenshot.

Actual vs expected behavior:

However, when the same code runs on Kaggle, the result appears correct (also attached as a screenshot).

Any other information you'd like to share?

No response

Metadata

Assignees: No one assigned

Labels: component:demos, status:triaged, type:bug
