Skip to content

Commit 87d0cc7

Browse files
authored
Use PyTorchModelHubMixin for InstructionDataGuardNet (#416)
Signed-off-by: Sarah Yurick <[email protected]>
1 parent b4c67b5 commit 87d0cc7

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

nemo_curator/classifiers/aegis.py

+9 -8
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,8 @@
2424
import torch.nn.functional as F
2525
from crossfit import op
2626
from crossfit.backend.torch.hf.model import HFModel
27-
from huggingface_hub import hf_hub_download
27+
from huggingface_hub import PyTorchModelHubMixin
2828
from peft import PeftModel
29-
from safetensors.torch import load_file
3029
from torch.nn import Dropout, Linear
3130
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
3231

@@ -75,7 +74,7 @@ class AegisConfig:
7574
]
7675

7776

78-
class InstructionDataGuardNet(torch.nn.Module):
77+
class InstructionDataGuardNet(torch.nn.Module, PyTorchModelHubMixin):
7978
def __init__(self, input_dim, dropout=0.7):
8079
super().__init__()
8180
self.input_dim = input_dim
@@ -180,12 +179,14 @@ def load_model(self, device: str = "cuda"):
180179
add_instruction_data_guard=self.config.add_instruction_data_guard,
181180
)
182181
if self.config.add_instruction_data_guard:
183-
weights_path = hf_hub_download(
184-
repo_id=self.config.instruction_data_guard_path,
185-
filename="model.safetensors",
182+
model.instruction_data_guard_net = (
183+
model.instruction_data_guard_net.from_pretrained(
184+
self.config.instruction_data_guard_path
185+
)
186+
)
187+
model.instruction_data_guard_net = model.instruction_data_guard_net.to(
188+
device
186189
)
187-
state_dict = load_file(weights_path)
188-
model.instruction_data_guard_net.load_state_dict(state_dict)
189190
model.instruction_data_guard_net.eval()
190191

191192
model = model.to(device)

0 commit comments

Comments
 (0)