diff --git a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py
index 126a19be9411..9a106e9172ce 100644
--- a/tests/models/lfm2_moe/test_modeling_lfm2_moe.py
+++ b/tests/models/lfm2_moe/test_modeling_lfm2_moe.py
@@ -18,7 +18,9 @@ from transformers import AutoTokenizer, is_torch_available, set_seed
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
+    require_deterministic_for_xpu,
     require_read_token,
     require_torch,
     require_torch_accelerator,
@@ -170,36 +172,30 @@ def test_model_1a8b_logits(self):
         input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
         with torch.no_grad():
             out = model(input_ids).logits.float().cpu()
+        # fmt: off
         # Expected mean on dim = -1
-        EXPECTED_MEAN = torch.tensor(
-            [
-                [
-                    -1.3855,
-                    -0.5123,
-                    -1.3143,
-                    -1.2144,
-                    -1.0791,
-                    -1.2117,
-                    -1.4704,
-                    -0.7648,
-                    -0.6175,
-                    -1.2402,
-                    -1.1459,
-                    -1.0083,
-                    -1.0247,
-                    -0.8830,
-                    -1.5643,
-                    -1.7266,
-                    -1.6254,
-                ]
-            ]
+        EXPECTED_MEANS = Expectations(
+            {
+                ("cuda", None): torch.tensor([[-1.3855, -0.5123, -1.3143, -1.2144, -1.0791, -1.2117, -1.4704, -0.7648, -0.6175, -1.2402, -1.1459, -1.0083, -1.0247, -0.8830, -1.5643, -1.7266, -1.6254,]]),
+                ("xpu", None): torch.tensor([[-1.3863, -0.4653, -1.3246, -1.3199, -1.0940, -1.2254, -1.4716, -0.8852, -0.5920, -1.2182, -1.1782, -1.0268, -1.0114, -0.8816, -1.5774, -1.7408, -1.6147,]]),
+            }
         )
-        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
+        # fmt: on
+        EXPECTED_MEAN = EXPECTED_MEANS.get_expectation()
+        out_mean = out.mean(-1)
+        torch.testing.assert_close(out_mean, EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
+        # fmt: off
         # Expected portion of the logits
-        EXPECTED_SLICE = torch.tensor(
-            [-1.2656, 2.4844, 5.5000, -1.3359, -1.3203, -1.3438, 1.9375, 5.8438, -0.6523, -1.2891]
+        EXPECTED_SLICES = Expectations(
+            {
+                ("cuda", None): torch.tensor([-1.2656, 2.4844, 5.5000, -1.3359, -1.3203, -1.3438, 1.9375, 5.8438, -0.6523, -1.2891]),
+                ("xpu", None): torch.tensor([-1.2656, 2.4531, 5.4375, -1.3438, -1.3203, -1.3516, 1.9297, 5.7812, -0.6719, -1.3203]),
+            }
         )
-        torch.testing.assert_close(out[0, 0, :10], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
+        # fmt: on
+        EXPECTED_SLICE = EXPECTED_SLICES.get_expectation()
+        out_slice = out[0, 0, :10]
+        torch.testing.assert_close(out_slice, EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
 
     @slow
     def test_model_1a8b_generation(self):
@@ -217,13 +213,25 @@ def test_model_1a8b_batched_chat_generation(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
 
     @slow
+    @require_deterministic_for_xpu
     def test_model_1a8b_batched_chat_generation(self):
         prompts = ["Who are you?", "Complete the text: Lorem ipsum dolor ", "The Meji Restoration in Japan ended"]
-        EXPECTED_TEXT_COMPLETIONS = [
-            "Who are you?, a language model designed to assist with information and tasks? \nI am",
-            "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
-            "The Meji Restoration in Japan ended or the Meiji Restoration (1868–1912) marked a pivotal",
-        ]
+        # fmt: off
+        EXPECTED_TEXT_COMPLETIONS = Expectations(
+            {
+                ("cuda", None): ["Who are you?, a language model designed to assist with information and tasks? \nI am",
+                    "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor",
+                    "The Meji Restoration in Japan ended or the Meiji Restoration (1868–1912) marked a pivotal",
+                ],
+                ("xpu", None): ['Who are you? (AI) designed to assist? \nI am an AI assistant developed to',
+                    'Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor',
+                    'The Meji Restoration in Japan ended** \n* **Key Event:** The overthrow of the Tokugawa'
+                ],
+            }
+        )
+        # fmt: on
+        EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
+
         set_seed(1789)
         tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-8B-A1B", use_fast=False)
         model = self.get_model()
@@ -233,4 +241,4 @@ def test_model_1a8b_batched_chat_generation(self):
         with torch.no_grad():
             generated_ids = model.generate(**batched_input_ids, max_new_tokens=15, do_sample=False)
         text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETIONS, text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
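For context, the diff replaces single hard-coded reference values with device-keyed `Expectations` mappings from `transformers.testing_utils`, so one test can assert against CUDA- or XPU-specific outputs via `get_expectation()`. The sketch below is a minimal, self-contained approximation of that selection logic under stated assumptions: the `pick_expectation` helper and its fallback rules are illustrative inventions, not the library implementation, which also considers device major versions and other properties.

import torch

def pick_expectation(expectations: dict):
    """Return the expected value for the current accelerator.

    ``expectations`` maps ``(device_type, major_version)`` keys such as
    ``("cuda", None)`` or ``("xpu", None)`` to reference values. This is a
    simplified stand-in for ``Expectations.get_expectation()``, not the real API.
    """
    # Detect the active accelerator; the guards keep this safe on CPU-only builds.
    if torch.cuda.is_available():
        device_type = "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device_type = "xpu"
    else:
        device_type = "cpu"
    # Prefer an exact (device, None) match, then fall back to the CUDA entry
    # (an assumed default for illustration).
    if (device_type, None) in expectations:
        return expectations[(device_type, None)]
    return expectations[("cuda", None)]

# Usage, mirroring the structure added in the diff above.
EXPECTED_SLICES = {
    ("cuda", None): torch.tensor([-1.2656, 2.4844, 5.5000]),
    ("xpu", None): torch.tensor([-1.2656, 2.4531, 5.4375]),
}
EXPECTED_SLICE = pick_expectation(EXPECTED_SLICES)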