 import random
 from typing import Optional, Tuple

+import pytest
 import torch
 from torch.nn import functional as F
-import pytest
-
-from litgpt.config import Config
-from litgpt.model import (
-    apply_rope,
-    CausalSelfAttention,
-    GPT,
-    build_rope_cache,
-)
-from litgpt.kvcache import KVCache
-from litgpt.utils import batched_index_select

 from litgpt.attention import (
+    DefaultKeysAndValues,
+    MultiHeadSelfAttention,
     build_mask_cache,
     build_mask_slice,
-    DefaultKeysAndValues,
     do_softcapping,
-    MultiHeadSelfAttention,
     scaled_dot_product_attention,
 )
+from litgpt.config import Config
+from litgpt.kvcache import KVCache
+from litgpt.model import (
+    GPT,
+    CausalSelfAttention,
+    apply_rope,
+    build_rope_cache,
+)
+from litgpt.utils import batched_index_select


 @pytest.mark.parametrize(
@@ -126,7 +125,8 @@ def test_build_mask_slice(
     for bs in range(batch_size):
         for nq in range(n_query_groups):
             token_positions[bs, nq, :] = torch.randperm(
-                seq_len, device=device,
+                seq_len,
+                device=device,
             )[:cache_length]
     mask = build_mask_slice(
         input_pos=input_pos,
@@ -137,15 +137,16 @@ def test_build_mask_slice(
         sliding_window_size=sliding_window_size,
     )
     mask_cmp = batched_index_select(
-        full_mask[input_pos : (input_pos + num), :],
+        full_mask[input_pos : (input_pos + num), :],
         dim=1,
         idx=token_positions,
     )
     torch.testing.assert_close(mask, mask_cmp)


 @pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float16, torch.bfloat16],
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
 )
 def test_mask_sliding_window(dtype):
     """
@@ -329,9 +330,9 @@ def scaled_dot_product_attention(
         # with softcapping we cannot use SDPA
         if self.config.attention_logit_softcapping is not None:
             scores = q @ k.mT * scale
-            #self.debug_intermediates["scores1"] = scores
+            # self.debug_intermediates["scores1"] = scores
             scores = do_softcapping(scores, self.config.attention_logit_softcapping)
-            #self.debug_intermediates["scores2"] = scores
+            # self.debug_intermediates["scores2"] = scores
             if mask is None:
                 mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
                 mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
@@ -347,7 +348,8 @@ def scaled_dot_product_attention(


 def rope_cache_OLD(
-    config: Config, device: Optional[torch.device] = None,
+    config: Config,
+    device: Optional[torch.device] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     if config.rope_adjustments is None:
         extra_config = None
@@ -368,9 +370,7 @@ def rope_cache_OLD(
             extra_config = {name: config.rope_adjustments[name] for name in adjusted_params_required}
         else:
             # Some but not all parameters are specified; raise an error
-            missing_params = [
-                param for param, present in zip(adjusted_params_required, params_present) if not present
-            ]
+            missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present]
             raise ValueError(
                 f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
                 "All adjusted RoPE parameters must be specified together."
@@ -387,12 +387,13 @@ def rope_cache_OLD(
     )


-
 @pytest.mark.parametrize(
-    "model_name", ["gemma-2-27b", "gemma-3-27b-it"],
+    "model_name",
+    ["gemma-2-27b", "gemma-3-27b-it"],
 )
 @pytest.mark.parametrize(
-    "dtype", [torch.float32, torch.float16, torch.bfloat16],
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
 )
 def test_multi_head_attention_for_gemma(model_name, dtype):
     """
@@ -414,7 +415,7 @@ def test_multi_head_attention_for_gemma(model_name, dtype):
         n_embd=32,
         intermediate_size=86,
         rotary_percentage=1.0,
-        rope_indices=[0, 1] if is_gemma_3 else None,
+        rope_indices=[0, 1] if is_gemma_3 else None,
     )

     # Obtain RoPE parameters and compare
@@ -433,10 +434,12 @@ def test_multi_head_attention_for_gemma(model_name, dtype):
     for rep in range(num_repeats):
         block_idx = rep % 2
         attn_new = CausalSelfAttention(
-            config, block_idx=block_idx,
+            config,
+            block_idx=block_idx,
         ).to(dtype=dtype)
         attn_old = CausalSelfAttention_OLD(
-            config, block_idx=block_idx,
+            config,
+            block_idx=block_idx,
         ).to(dtype=dtype)
         # Ensure they have the same weights
         attn_old.load_state_dict(attn_new.state_dict())