diff --git a/sae_lens/toolkit/pretrained_sae_loaders.py b/sae_lens/toolkit/pretrained_sae_loaders.py
index 396272d12..eac539ba4 100644
--- a/sae_lens/toolkit/pretrained_sae_loaders.py
+++ b/sae_lens/toolkit/pretrained_sae_loaders.py
@@ -1039,6 +1039,174 @@ def sparsify_disk_loader(
     return cfg_dict, state_dict
 
 
+def get_dictionary_learning_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    config_filename = f"{folder_name}/config.json"
+    config_path = hf_hub_download(
+        repo_id, filename=config_filename, force_download=force_download
+    )
+    sae_path = Path(config_path).parent
+    return get_dictionary_learning_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_dictionary_learning_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / "config.json") as f:
+        config = json.load(f)
+
+    trainer = config["trainer"]
+    buffer = config.get("buffer", {})
+    trainer_class = trainer["trainer_class"]
+
+    if trainer_class in {
+        "StandardTrainer",
+        "PAnnealTrainer",
+        "StandardTrainerAprilUpdate",
+        "BatchTopKTrainer",
+        "MatryoshkaBatchTopKTrainer",
+        "TopKTrainer",
+    }:
+        architecture = "standard" if "TopK" not in trainer_class else "topk"
+    elif trainer_class == "GatedSAETrainer":
+        architecture = "gated"
+    elif trainer_class == "JumpReluTrainer":
+        architecture = "jumprelu"
+    else:
+        architecture = "standard"
+
+    if "TopK" in trainer_class:
+        activation_fn_str = "topk"
+        activation_fn_kwargs = {"k": trainer["k"]}
+    else:
+        activation_fn_str = "relu"
+        activation_fn_kwargs = {}
+
+    hook_name = f"blocks.{trainer['layer']}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": architecture,
+        "d_in": trainer["activation_dim"],
+        "d_sae": trainer["dict_size"],
+        "dtype": "float32",
+        "device": device or "cpu",
+        "model_name": trainer["lm_name"].split("/")[-1],
+        "hook_name": hook_name,
+        "hook_layer": trainer["layer"],
+        "hook_head_index": None,
+        "activation_fn_str": activation_fn_str,
+        **(
+            {"activation_fn_kwargs": activation_fn_kwargs}
+            if activation_fn_kwargs
+            else {}
+        ),
+        "apply_b_dec_to_input": True,
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": buffer.get("ctx_len", 128),
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+        "dataset_trust_remote_code": True,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def dictionary_learning_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    sae_path = hf_hub_download(
+        repo_id, filename=f"{folder_name}/ae.pt", force_download=force_download
+    )
+    config_path = f"{folder_name}/config.json"
+    hf_hub_download(repo_id, filename=config_path, force_download=force_download)
+    cfg_dict, state_dict = dictionary_learning_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def dictionary_learning_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_dictionary_learning_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / "ae.pt"
+    state_dict_loaded = (
+        load_file(weight_path, device=device)
+        if weight_path.suffix == ".safetensors"
+        else torch.load(weight_path, map_location=device)
+    )
+    state_dict_loaded.pop("group_sizes", None)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    W_dec = (
+        state_dict_loaded["W_dec"]
+        if "W_dec" in state_dict_loaded
+        else state_dict_loaded["decoder.weight"].T
+    ).to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["bias"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+
+    architecture = cfg_dict["architecture"]
+    if architecture == "jumprelu" and "threshold" in state_dict_loaded:
+        state_dict["threshold"] = state_dict_loaded["threshold"].to(dtype)
+    if architecture == "gated":
+        state_dict.pop("b_enc", None)
+        state_dict["r_mag"] = state_dict_loaded["r_mag"].to(dtype)
+        # NOTE: .get(k, default) evaluates the default eagerly, so the previous
+        # form raised KeyError when "b_mag" existed but "mag_bias" did not.
+        state_dict["b_mag"] = (state_dict_loaded["b_mag"] if "b_mag" in state_dict_loaded else state_dict_loaded["mag_bias"]).to(dtype)
+        state_dict["b_gate"] = state_dict_loaded["gate_bias"].to(dtype)
+
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1048,6 +1216,7 @@ def sparsify_disk_loader(
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
     "sparsify": sparsify_huggingface_loader,
+    "dictionary_learning": dictionary_learning_huggingface_loader,
 }
 
 
@@ -1060,4 +1229,5 @@ def sparsify_disk_loader(
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
     "sparsify": get_sparsify_config_from_hf,
+    "dictionary_learning": get_dictionary_learning_config_from_hf,
 }