72 changes: 40 additions & 32 deletions tests/models/lfm2_moe/test_modeling_lfm2_moe.py
@@ -18,7 +18,9 @@
 from transformers import AutoTokenizer, is_torch_available, set_seed
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
+    require_deterministic_for_xpu,
     require_read_token,
     require_torch,
     require_torch_accelerator,
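
Context for reviewers: Expectations lets one test body carry per-accelerator golden values. As a rough mental model only — a minimal stand-in with simplified semantics, not the real transformers.testing_utils.Expectations, which also matches on device major version and has fuller fallback logic — the lookup behaves roughly like this:

import torch


class SimpleExpectations:
    """Toy stand-in for Expectations: keys are (device_type, version) tuples."""

    def __init__(self, data):
        self.data = data

    def get_expectation(self):
        # Detect the accelerator the test runs on; None in a key means "any version".
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = "xpu"
        else:
            device = "cpu"
        # Exact (device, None) match first, else fall back to the CUDA entry.
        return self.data.get((device, None), self.data[("cuda", None)])


# Usage mirrors the test changes below:
expected = SimpleExpectations({("cuda", None): 1.0, ("xpu", None): 1.1}).get_expectation()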
@@ -170,36 +172,30 @@ def test_model_1a8b_logits(self):
         input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
         with torch.no_grad():
             out = model(input_ids).logits.float().cpu()
         # fmt: off
         # Expected mean on dim = -1
-        EXPECTED_MEAN = torch.tensor(
-            [
-                [
-                    -1.3855,
-                    -0.5123,
-                    -1.3143,
-                    -1.2144,
-                    -1.0791,
-                    -1.2117,
-                    -1.4704,
-                    -0.7648,
-                    -0.6175,
-                    -1.2402,
-                    -1.1459,
-                    -1.0083,
-                    -1.0247,
-                    -0.8830,
-                    -1.5643,
-                    -1.7266,
-                    -1.6254,
-                ]
-            ]
+        EXPECTED_MEANS = Expectations(
+            {
+                ("cuda", None): torch.tensor([[-1.3855, -0.5123, -1.3143, -1.2144, -1.0791, -1.2117, -1.4704, -0.7648, -0.6175, -1.2402, -1.1459, -1.0083, -1.0247, -0.8830, -1.5643, -1.7266, -1.6254]]),
+                ("xpu", None): torch.tensor([[-1.3863, -0.4653, -1.3246, -1.3199, -1.0940, -1.2254, -1.4716, -0.8852, -0.5920, -1.2182, -1.1782, -1.0268, -1.0114, -0.8816, -1.5774, -1.7408, -1.6147]]),
+            }
         )
-        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
         # fmt: on
+        EXPECTED_MEAN = EXPECTED_MEANS.get_expectation()
+        out_mean = out.mean(-1)
+        torch.testing.assert_close(out_mean, EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
         # fmt: off
         # Expected portion of the logits
-        EXPECTED_SLICE = torch.tensor(
-            [-1.2656, 2.4844, 5.5000, -1.3359, -1.3203, -1.3438, 1.9375, 5.8438, -0.6523, -1.2891]
+        EXPECTED_SLICES = Expectations(
+            {
+                ("cuda", None): torch.tensor([-1.2656, 2.4844, 5.5000, -1.3359, -1.3203, -1.3438, 1.9375, 5.8438, -0.6523, -1.2891]),
+                ("xpu", None): torch.tensor([-1.2656, 2.4531, 5.4375, -1.3438, -1.3203, -1.3516, 1.9297, 5.7812, -0.6719, -1.3203]),
+            }
         )
-        torch.testing.assert_close(out[0, 0, :10], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
         # fmt: on
+        EXPECTED_SLICE = EXPECTED_SLICES.get_expectation()
+        out_slice = out[0, 0, :10]
+        torch.testing.assert_close(out_slice, EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
 
     @slow
     def test_model_1a8b_generation(self):
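
Worth noting: the per-position means are compared at a loose rtol/atol of 1e-2, while the first ten raw logits are held to 1e-4, so backend numerical drift surfaces first in the slice check. If the goldens ever need regenerating for a new backend, a hypothetical helper (not part of this PR; names are illustrative) could print both asserted values in one pass:

import torch


def dump_goldens(model, input_ids):
    """Print the two golden values the logits test asserts on (illustrative only)."""
    with torch.no_grad():
        out = model(input_ids).logits.float().cpu()
    # Paste this into the EXPECTED_MEANS entry for the current device.
    print("mean over vocab:", out.mean(-1))
    # Paste this into the EXPECTED_SLICES entry for the current device.
    print("first 10 logits:", out[0, 0, :10])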
@@ -217,13 +213,25 @@ def test_model_1a8b_generation(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
 
     @slow
+    @require_deterministic_for_xpu
     def test_model_1a8b_batched_chat_generation(self):
         prompts = ["Who are you?", "Complete the text: Lorem ipsum dolor ", "The Meji Restoration in Japan ended"]
-        EXPECTED_TEXT_COMPLETIONS = [
-            "Who are you?, a language model designed to assist with information and tasks? \nI am",
-            "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
-            "The Meji Restoration in Japan ended or the Meiji Restoration (1868–1912) marked a pivotal",
-        ]
+        # fmt: off
+        EXPECTED_TEXT_COMPLETIONS = Expectations(
+            {
+                ("cuda", None): [
+                    "Who are you?, a language model designed to assist with information and tasks? \nI am",
+                    "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
+                    "The Meji Restoration in Japan ended or the Meiji Restoration (1868–1912) marked a pivotal",
+                ],
+                ("xpu", None): [
+                    "Who are you? (AI) designed to assist? \nI am an AI assistant developed to",
+                    "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
+                    "The Meji Restoration in Japan ended** \n* **Key Event:** The overthrow of the Tokugawa",
+                ],
+            }
+        )
+        # fmt: on
+        EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
+
         set_seed(1789)
         tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-8B-A1B", use_fast=False)
         model = self.get_model()
@@ -233,4 +241,4 @@ def test_model_1a8b_batched_chat_generation(self):
         with torch.no_grad():
             generated_ids = model.generate(**batched_input_ids, max_new_tokens=15, do_sample=False)
         text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETIONS, text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
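
The batched test still asserts exact string equality after greedy decoding (do_sample=False), so any kernel-level nondeterminism on XPU would flake it; that is presumably what @require_deterministic_for_xpu guards against. A stand-in sketch of such a guard, assuming simplified behavior — the real decorator lives in transformers.testing_utils and its exact semantics may differ:

import functools

import torch


def require_deterministic(fn):
    """Stand-in guard: run the wrapped test with deterministic algorithms enabled."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        previous = torch.are_deterministic_algorithms_enabled()
        torch.use_deterministic_algorithms(True)
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore the caller's setting so other tests are unaffected.
            torch.use_deterministic_algorithms(previous)

    return wrapper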