@@ -45,32 +45,37 @@ def test_normal_logprobs(self):
4545 self .assertEqual (result [0 ][1 ].logprob , - 2.5 )
4646 self .assertEqual (result [0 ][2 ].logprob , - 1.0 )
4747
48- def test_negative_inf_logprobs_raises_error (self ):
49- """Test that logprobs containing -inf raises AttributeError """
48+ def test_negative_inf_logprobs_gets_clamped (self ):
49+ """Test that logprobs containing -inf get clamped to -9999.0 """
5050 logprob_dict = {
5151 1 : Logprob (logprob = float ("-inf" ), rank = 1 , decoded_token = "hello" ),
5252 2 : Logprob (logprob = - 1.0 , rank = 2 , decoded_token = "world" ),
5353 }
5454 prompt_logprobs = [logprob_dict ]
5555
56- # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
57- with self .assertRaises (AttributeError ) as context :
58- clamp_prompt_logprobs (prompt_logprobs )
56+ # Since Logprob is now a dataclass, its fields can be modified
57+ result = clamp_prompt_logprobs (prompt_logprobs )
5958
60- self .assertIn ("can't set attribute" , str (context .exception ))
59+ # The -inf value should be clamped to -9999.0
60+ self .assertEqual (result [0 ][1 ].logprob , - 9999.0 )
61+ self .assertEqual (result [0 ][2 ].logprob , - 1.0 ) # unchanged
6162
62- def test_multiple_negative_inf_raises_error (self ):
63- """Test that multiple -inf logprobs values raise AttributeError """
63+ def test_multiple_negative_inf_gets_clamped (self ):
64+ """Test that multiple -inf logprobs values get clamped to -9999.0 """
6465 logprob_dict = {
6566 1 : Logprob (logprob = float ("-inf" ), rank = 1 , decoded_token = "hello" ),
6667 2 : Logprob (logprob = float ("-inf" ), rank = 2 , decoded_token = "world" ),
6768 3 : Logprob (logprob = - 0.5 , rank = 3 , decoded_token = "test" ),
6869 }
6970 prompt_logprobs = [logprob_dict ]
7071
71- # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
72- with self .assertRaises (AttributeError ):
73- clamp_prompt_logprobs (prompt_logprobs )
72+ # Since Logprob is now a dataclass, its fields can be modified
73+ result = clamp_prompt_logprobs (prompt_logprobs )
74+
75+ # All -inf values should be clamped to -9999.0
76+ self .assertEqual (result [0 ][1 ].logprob , - 9999.0 )
77+ self .assertEqual (result [0 ][2 ].logprob , - 9999.0 )
78+ self .assertEqual (result [0 ][3 ].logprob , - 0.5 ) # unchanged
7479
7580 def test_none_dict_in_list (self ):
7681 """Test case when list contains None"""
@@ -116,15 +121,15 @@ def test_mixed_values_without_inf(self):
116121 self .assertEqual (result [0 ][4 ].logprob , - 1.5 )
117122
118123 def test_return_same_object (self ):
119- """Test that function returns the same object (in-place modification attempt )"""
124+ """Test that function returns the same object (in-place modification)"""
120125 logprob_dict = {
121126 1 : Logprob (logprob = - 2.0 , rank = 1 , decoded_token = "hello" ),
122127 }
123128 prompt_logprobs = [logprob_dict ]
124129
125130 result = clamp_prompt_logprobs (prompt_logprobs )
126131
127- # Should return the same object (function attempts in-place modification)
132+ # Should return the same object (function performs in-place modification)
128133 self .assertIs (result , prompt_logprobs )
129134 self .assertIs (result [0 ], prompt_logprobs [0 ])
130135
0 commit comments