@@ -176,8 +176,10 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
176176 one_hot_activation = torch .zeros (1 , 1 , d_latent , device = sae_device )
177177
178178 if feature_id >= d_latent :
179- print (f"""DEBUG: ERROR - Feature ID { feature_id } is out of bounds
180- for d_latent { d_latent } """ )
179+ print (
180+ f"""DEBUG: ERROR - Feature ID { feature_id } is out of bounds
181+ for d_latent { d_latent } """
182+ )
181183 return torch .zeros (1 )
182184
183185 one_hot_activation [0 , 0 , feature_id ] = 1.0
@@ -582,16 +584,19 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
582584
583585 if candidate is None :
584586 # This will raise an error if the key isn't found
585- raise ValueError (f"ERROR: Surprisal scorer could not find an SAE "
586- f"for hookpoint '{ hookpoint_str } ' in self.explainer_model" )
587+ raise ValueError (
588+ f"ERROR: Surprisal scorer could not find an SAE "
589+ f"for hookpoint '{ hookpoint_str } ' in self.explainer_model"
590+ )
587591
588592 if isinstance (candidate , functools .partial ):
589593 # As shown in load_sparsify.py, the SAE is in the 'sae' keyword.
590594 if candidate .keywords and "sae" in candidate .keywords :
591595 return candidate .keywords ["sae" ] # Unwrapped successfully
592596 else :
593597 # This will raise an error if the partial is missing the keyword
594- raise ValueError (f"""ERROR: Found a partial for
598+ raise ValueError (
599+ f"""ERROR: Found a partial for
595600 { hookpoint_str } but could not
596601 find the 'sae' keyword.
597602 func: { candidate .func }
@@ -600,8 +605,10 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
600605 )
601606
602607 # This will raise an error if the candidate isn't a partial
603- raise ValueError (f"""ERROR: Candidate for { hookpoint_str } was not a partial
604- object, which was not expected. Type: { type (candidate )} """ )
608+ raise ValueError (
609+ f"""ERROR: Candidate for { hookpoint_str } was not a partial
610+ object, which was not expected. Type: { type (candidate )} """
611+ )
605612
606613 def _get_intervention_direction (self , record : LatentRecord ) -> torch .Tensor :
607614 hookpoint_str = self .hookpoint_str or getattr (record , "hookpoint" , None )
0 commit comments