Skip to content

Commit c81b861

Browse files
committed
plot for intervention scoring
2 parents c0ef075 + cabb151 commit c81b861

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

delphi/scorers/intervention/surprisal_intervention_scorer.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
176176
one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device)
177177

178178
if feature_id >= d_latent:
179-
print(f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds
180-
for d_latent {d_latent}""")
179+
print(
180+
f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds
181+
for d_latent {d_latent}"""
182+
)
181183
return torch.zeros(1)
182184

183185
one_hot_activation[0, 0, feature_id] = 1.0
@@ -582,16 +584,19 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
582584

583585
if candidate is None:
584586
# This will raise an error if the key isn't found
585-
raise ValueError(f"ERROR: Surprisal scorer could not find an SAE "
586-
f"for hookpoint '{hookpoint_str}' in self.explainer_model")
587+
raise ValueError(
588+
f"ERROR: Surprisal scorer could not find an SAE "
589+
f"for hookpoint '{hookpoint_str}' in self.explainer_model"
590+
)
587591

588592
if isinstance(candidate, functools.partial):
589593
# As shown in load_sparsify.py, the SAE is in the 'sae' keyword.
590594
if candidate.keywords and "sae" in candidate.keywords:
591595
return candidate.keywords["sae"] # Unwrapped successfully
592596
else:
593597
# This will raise an error if the partial is missing the keyword
594-
raise ValueError(f"""ERROR: Found a partial for
598+
raise ValueError(
599+
f"""ERROR: Found a partial for
595600
{hookpoint_str} but could not
596601
find the 'sae' keyword.
597602
func: {candidate.func}
@@ -600,8 +605,10 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
600605
)
601606

602607
# This will raise an error if the candidate isn't a partial
603-
raise ValueError(f"""ERROR: Candidate for {hookpoint_str} was not a partial
604-
object, which was not expected. Type: {type(candidate)}""")
608+
raise ValueError(
609+
f"""ERROR: Candidate for {hookpoint_str} was not a partial
610+
object, which was not expected. Type: {type(candidate)}"""
611+
)
605612

606613
def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor:
607614
hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None)

0 commit comments

Comments
 (0)