Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions sae_lens/toolkit/pretrained_sae_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,174 @@
return cfg_dict, state_dict


def get_dictionary_learning_config_from_hf(
repo_id: str,
folder_name: str,
device: str,
force_download: bool = False,
cfg_overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
config_filename = f"{folder_name}/config.json"
config_path = hf_hub_download(

Check warning on line 1050 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1049-L1050

Added lines #L1049 - L1050 were not covered by tests
repo_id, filename=config_filename, force_download=force_download
)
sae_path = Path(config_path).parent
return get_dictionary_learning_config_from_disk(

Check warning on line 1054 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1053-L1054

Added lines #L1053 - L1054 were not covered by tests
sae_path, device=device, cfg_overrides=cfg_overrides
)


def get_dictionary_learning_config_from_disk(
path: str | Path,
device: str | None = None,
cfg_overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
path = Path(path)

Check warning on line 1064 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1064

Added line #L1064 was not covered by tests

with open(path / "config.json") as f:
config = json.load(f)

Check warning on line 1067 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1066-L1067

Added lines #L1066 - L1067 were not covered by tests

trainer = config["trainer"]
buffer = config.get("buffer", {})
trainer_class = trainer["trainer_class"]

Check warning on line 1071 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1069-L1071

Added lines #L1069 - L1071 were not covered by tests

if trainer_class in {
"StandardTrainer",
"PAnnealTrainer",
"StandardTrainerAprilUpdate",
"BatchTopKTrainer",
"MatryoshkaBatchTopKTrainer",
"TopKTrainer",
}:
architecture = "standard" if "TopK" not in trainer_class else "topk"

Check warning on line 1081 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1081

Added line #L1081 was not covered by tests
elif trainer_class == "GatedSAETrainer":
architecture = "gated"

Check warning on line 1083 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1083

Added line #L1083 was not covered by tests
elif trainer_class == "JumpReluTrainer":
architecture = "jumprelu"

Check warning on line 1085 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1085

Added line #L1085 was not covered by tests
else:
architecture = "standard"

Check warning on line 1087 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1087

Added line #L1087 was not covered by tests

if "TopK" in trainer_class:
activation_fn_str = "topk"
activation_fn_kwargs = {"k": trainer["k"]}

Check warning on line 1091 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1090-L1091

Added lines #L1090 - L1091 were not covered by tests
else:
activation_fn_str = "relu"
activation_fn_kwargs = {}

Check warning on line 1094 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1093-L1094

Added lines #L1093 - L1094 were not covered by tests

hook_name = f"blocks.{trainer['layer']}.hook_resid_post"

Check warning on line 1096 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1096

Added line #L1096 was not covered by tests

cfg_dict: dict[str, Any] = {

Check warning on line 1098 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1098

Added line #L1098 was not covered by tests
"architecture": architecture,
"d_in": trainer["activation_dim"],
"d_sae": trainer["dict_size"],
"dtype": "float32",
"device": device or "cpu",
"model_name": trainer["lm_name"].split("/")[-1],
"hook_name": hook_name,
"hook_layer": trainer["layer"],
"hook_head_index": None,
"activation_fn_str": activation_fn_str,
**(
{"activation_fn_kwargs": activation_fn_kwargs}
if activation_fn_kwargs
else {}
),
"apply_b_dec_to_input": True,
"finetuning_scaling_factor": False,
"sae_lens_training_version": None,
"prepend_bos": True,
"dataset_path": "monology/pile-uncopyrighted",
"context_size": buffer.get("ctx_len", 128),
"normalize_activations": "none",
"neuronpedia_id": None,
"dataset_trust_remote_code": True,
}

if cfg_overrides:
cfg_dict.update(cfg_overrides)

Check warning on line 1126 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1126

Added line #L1126 was not covered by tests

return cfg_dict

Check warning on line 1128 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1128

Added line #L1128 was not covered by tests


def dictionary_learning_huggingface_loader(
repo_id: str,
folder_name: str,
device: str = "cpu",
force_download: bool = False,
cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
sae_path = hf_hub_download(

Check warning on line 1138 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1138

Added line #L1138 was not covered by tests
repo_id, filename=f"{folder_name}/ae.pt", force_download=force_download
)
config_path = f"{folder_name}/config.json"
hf_hub_download(repo_id, filename=config_path, force_download=force_download)
cfg_dict, state_dict = dictionary_learning_disk_loader(

Check warning on line 1143 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1141-L1143

Added lines #L1141 - L1143 were not covered by tests
Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
)
return cfg_dict, state_dict, None

Check warning on line 1146 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1146

Added line #L1146 was not covered by tests


def dictionary_learning_disk_loader(
path: str | Path,
device: str = "cpu",
cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
cfg_dict = get_dictionary_learning_config_from_disk(path, device, cfg_overrides)

Check warning on line 1154 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1154

Added line #L1154 was not covered by tests

weight_path = Path(path) / "ae.pt"
state_dict_loaded = (

Check warning on line 1157 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1156-L1157

Added lines #L1156 - L1157 were not covered by tests
load_file(weight_path, device=device)
if weight_path.suffix == ".safetensors"
else torch.load(weight_path, map_location=device)
)
state_dict_loaded.pop("group_sizes", None)

Check warning on line 1162 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1162

Added line #L1162 was not covered by tests

dtype = DTYPE_MAP[cfg_dict["dtype"]]

Check warning on line 1164 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1164

Added line #L1164 was not covered by tests

W_enc = (

Check warning on line 1166 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1166

Added line #L1166 was not covered by tests
state_dict_loaded["W_enc"]
if "W_enc" in state_dict_loaded
else state_dict_loaded["encoder.weight"].T
).to(dtype)

W_dec = (

Check warning on line 1172 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1172

Added line #L1172 was not covered by tests
state_dict_loaded["W_dec"]
if "W_dec" in state_dict_loaded
else state_dict_loaded["decoder.weight"].T
).to(dtype)

if "b_enc" in state_dict_loaded:
b_enc = state_dict_loaded["b_enc"].to(dtype)

Check warning on line 1179 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1179

Added line #L1179 was not covered by tests
elif "encoder.bias" in state_dict_loaded:
b_enc = state_dict_loaded["encoder.bias"].to(dtype)

Check warning on line 1181 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1181

Added line #L1181 was not covered by tests
else:
b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)

Check warning on line 1183 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1183

Added line #L1183 was not covered by tests

if "b_dec" in state_dict_loaded:
b_dec = state_dict_loaded["b_dec"].to(dtype)

Check warning on line 1186 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1186

Added line #L1186 was not covered by tests
elif "bias" in state_dict_loaded:
b_dec = state_dict_loaded["bias"].to(dtype)

Check warning on line 1188 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1188

Added line #L1188 was not covered by tests
elif "decoder.bias" in state_dict_loaded:
b_dec = state_dict_loaded["decoder.bias"].to(dtype)

Check warning on line 1190 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1190

Added line #L1190 was not covered by tests
else:
b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)

Check warning on line 1192 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1192

Added line #L1192 was not covered by tests

state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}

Check warning on line 1194 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1194

Added line #L1194 was not covered by tests

architecture = cfg_dict["architecture"]

Check warning on line 1196 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1196

Added line #L1196 was not covered by tests
if architecture == "jumprelu" and "threshold" in state_dict_loaded:
state_dict["threshold"] = state_dict_loaded["threshold"].to(dtype)

Check warning on line 1198 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1198

Added line #L1198 was not covered by tests
if architecture == "gated":
state_dict.pop("b_enc", None)
state_dict["r_mag"] = state_dict_loaded["r_mag"].to(dtype)
state_dict["b_mag"] = state_dict_loaded.get(

Check warning on line 1202 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1200-L1202

Added lines #L1200 - L1202 were not covered by tests
"b_mag", state_dict_loaded["mag_bias"]
).to(dtype)
state_dict["b_gate"] = state_dict_loaded["gate_bias"].to(dtype)

Check warning on line 1205 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1205

Added line #L1205 was not covered by tests

return cfg_dict, state_dict

Check warning on line 1207 in sae_lens/toolkit/pretrained_sae_loaders.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/toolkit/pretrained_sae_loaders.py#L1207

Added line #L1207 was not covered by tests


NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
"sae_lens": sae_lens_huggingface_loader,
"connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
Expand All @@ -1048,6 +1216,7 @@
"dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
"deepseek_r1": deepseek_r1_sae_huggingface_loader,
"sparsify": sparsify_huggingface_loader,
"dictionary_learning": dictionary_learning_huggingface_loader,
}


Expand All @@ -1060,4 +1229,5 @@
"dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
"deepseek_r1": get_deepseek_r1_config_from_hf,
"sparsify": get_sparsify_config_from_hf,
"dictionary_learning": get_dictionary_learning_config_from_hf,
}