|
24 | 24 | import torch.nn.functional as F
|
25 | 25 | from crossfit import op
|
26 | 26 | from crossfit.backend.torch.hf.model import HFModel
|
27 |
| -from huggingface_hub import hf_hub_download |
| 27 | +from huggingface_hub import PyTorchModelHubMixin |
28 | 28 | from peft import PeftModel
|
29 |
| -from safetensors.torch import load_file |
30 | 29 | from torch.nn import Dropout, Linear
|
31 | 30 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
32 | 31 |
|
@@ -75,7 +74,7 @@ class AegisConfig:
|
75 | 74 | ]
|
76 | 75 |
|
77 | 76 |
|
78 |
| -class InstructionDataGuardNet(torch.nn.Module): |
| 77 | +class InstructionDataGuardNet(torch.nn.Module, PyTorchModelHubMixin): |
79 | 78 | def __init__(self, input_dim, dropout=0.7):
|
80 | 79 | super().__init__()
|
81 | 80 | self.input_dim = input_dim
|
@@ -180,12 +179,14 @@ def load_model(self, device: str = "cuda"):
|
180 | 179 | add_instruction_data_guard=self.config.add_instruction_data_guard,
|
181 | 180 | )
|
182 | 181 | if self.config.add_instruction_data_guard:
|
183 |
| - weights_path = hf_hub_download( |
184 |
| - repo_id=self.config.instruction_data_guard_path, |
185 |
| - filename="model.safetensors", |
| 182 | + model.instruction_data_guard_net = ( |
| 183 | + model.instruction_data_guard_net.from_pretrained( |
| 184 | + self.config.instruction_data_guard_path |
| 185 | + ) |
| 186 | + ) |
| 187 | + model.instruction_data_guard_net = model.instruction_data_guard_net.to( |
| 188 | + device |
186 | 189 | )
|
187 |
| - state_dict = load_file(weights_path) |
188 |
| - model.instruction_data_guard_net.load_state_dict(state_dict) |
189 | 190 | model.instruction_data_guard_net.eval()
|
190 | 191 |
|
191 | 192 | model = model.to(device)
|
|
0 commit comments