1818
1919from transformers import AutoTokenizer , is_torch_available , set_seed
2020from transformers .testing_utils import (
21+ Expectations ,
2122 cleanup ,
23+ require_deterministic_for_xpu ,
2224 require_read_token ,
2325 require_torch ,
2426 require_torch_accelerator ,
@@ -170,36 +172,30 @@ def test_model_1a8b_logits(self):
170172 input_ids = torch .tensor ([input_ids ]).to (model .model .embed_tokens .weight .device )
171173 with torch .no_grad ():
172174 out = model (input_ids ).logits .float ().cpu ()
175+ # fmt: off
173176 # Expected mean on dim = -1
174- EXPECTED_MEAN = torch .tensor (
175- [
176- [
177- - 1.3855 ,
178- - 0.5123 ,
179- - 1.3143 ,
180- - 1.2144 ,
181- - 1.0791 ,
182- - 1.2117 ,
183- - 1.4704 ,
184- - 0.7648 ,
185- - 0.6175 ,
186- - 1.2402 ,
187- - 1.1459 ,
188- - 1.0083 ,
189- - 1.0247 ,
190- - 0.8830 ,
191- - 1.5643 ,
192- - 1.7266 ,
193- - 1.6254 ,
194- ]
195- ]
177+ EXPECTED_MEANS = Expectations (
178+ {
179+ ("cuda" , None ): torch .tensor ([[- 1.3855 , - 0.5123 , - 1.3143 , - 1.2144 , - 1.0791 , - 1.2117 , - 1.4704 , - 0.7648 , - 0.6175 , - 1.2402 , - 1.1459 , - 1.0083 , - 1.0247 , - 0.8830 , - 1.5643 , - 1.7266 , - 1.6254 ,]]),
180+ ("xpu" , None ): torch .tensor ([[- 1.3863 , - 0.4653 , - 1.3246 , - 1.3199 , - 1.0940 , - 1.2254 , - 1.4716 , - 0.8852 , - 0.5920 , - 1.2182 , - 1.1782 , - 1.0268 , - 1.0114 , - 0.8816 , - 1.5774 , - 1.7408 , - 1.6147 ,]]),
181+ }
196182 )
197- torch .testing .assert_close (out .mean (- 1 ), EXPECTED_MEAN , rtol = 1e-2 , atol = 1e-2 )
183+ # fmt: on
184+ EXPECTED_MEAN = EXPECTED_MEANS .get_expectation ()
185+ out_mean = out .mean (- 1 )
186+ torch .testing .assert_close (out_mean , EXPECTED_MEAN , rtol = 1e-2 , atol = 1e-2 )
187+ # fmt: off
198188 # Expected portion of the logits
199- EXPECTED_SLICE = torch .tensor (
200- [- 1.2656 , 2.4844 , 5.5000 , - 1.3359 , - 1.3203 , - 1.3438 , 1.9375 , 5.8438 , - 0.6523 , - 1.2891 ]
189+ EXPECTED_SLICES = Expectations (
190+ {
191+ ("cuda" , None ): torch .tensor ([- 1.2656 , 2.4844 , 5.5000 , - 1.3359 , - 1.3203 , - 1.3438 , 1.9375 , 5.8438 , - 0.6523 , - 1.2891 ]),
192+ ("xpu" , None ): torch .tensor ([- 1.2656 , 2.4531 , 5.4375 , - 1.3438 , - 1.3203 , - 1.3516 , 1.9297 , 5.7812 , - 0.6719 , - 1.3203 ]),
193+ }
201194 )
202- torch .testing .assert_close (out [0 , 0 , :10 ], EXPECTED_SLICE , rtol = 1e-4 , atol = 1e-4 )
195+ # fmt: on
196+ EXPECTED_SLICE = EXPECTED_SLICES .get_expectation ()
197+ out_slice = out [0 , 0 , :10 ]
198+ torch .testing .assert_close (out_slice , EXPECTED_SLICE , rtol = 1e-4 , atol = 1e-4 )
203199
204200 @slow
205201 def test_model_1a8b_generation (self ):
@@ -217,13 +213,25 @@ def test_model_1a8b_generation(self):
217213 self .assertEqual (EXPECTED_TEXT_COMPLETION , text )
218214
219215 @slow
216+ @require_deterministic_for_xpu
220217 def test_model_1a8b_batched_chat_generation (self ):
221218 prompts = ["Who are you?" , "Complete the text: Lorem ipsum dolor " , "The Meji Restoration in Japan ended" ]
222- EXPECTED_TEXT_COMPLETIONS = [
223- "Who are you?, a language model designed to assist with information and tasks? \n I am" ,
224- "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor" ,
225- "The Meji Restoration in Japan ended or the Meiji Restoration (1868–1912) marked a pivotal" ,
226- ]
219+ # fmt: off
220+ EXPECTED_TEXT_COMPLETIONS = Expectations (
221+ {
222+ ("cuda" , None ): ["Who are you?, a language model designed to assist with information and tasks? \n I am" ,
223+ "Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor" ,
224+ "The Meji Restoration in Japan ended or the Meiji Restoration (1868–1912) marked a pivotal" ,
225+ ],
226+ ("xpu" , None ): ['Who are you? (AI) designed to assist? \n I am an AI assistant developed to' ,
227+ 'Complete the text: Lorem ipsum dolor ipsum dolor ipsum dolor ipsum dolor ipsum dolor' ,
228+ 'The Meji Restoration in Japan ended** \n * **Key Event:** The overthrow of the Tokugawa'
229+ ],
230+ }
231+ )
232+ # fmt: on
233+ EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS .get_expectation ()
234+
227235 set_seed (1789 )
228236 tokenizer = AutoTokenizer .from_pretrained ("LiquidAI/LFM2-8B-A1B" , use_fast = False )
229237 model = self .get_model ()
@@ -233,4 +241,4 @@ def test_model_1a8b_batched_chat_generation(self):
233241 with torch .no_grad ():
234242 generated_ids = model .generate (** batched_input_ids , max_new_tokens = 15 , do_sample = False )
235243 text = tokenizer .batch_decode (generated_ids , skip_special_tokens = True )
236- self .assertEqual (EXPECTED_TEXT_COMPLETIONS , text )
244+ self .assertEqual (EXPECTED_TEXT_COMPLETION , text )
0 commit comments