Skip to content

Commit 37058a9

Browse files
authored
Change NSFW Model (#307)
* Change download for NSFW model
  Signed-off-by: Ryan Wolf <[email protected]>
* Fix model init
  Signed-off-by: Ryan Wolf <[email protected]>
* Fix embedding size
  Signed-off-by: Ryan Wolf <[email protected]>
---------
Signed-off-by: Ryan Wolf <[email protected]>
1 parent 017ff97 commit 37058a9

File tree

1 file changed

+35
-29
lines changed
  • nemo_curator/image/classifiers

1 file changed

+35
-29
lines changed

nemo_curator/image/classifiers/nsfw.py

+35-29
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import zipfile
1516
from typing import Optional
1617

1718
import requests
@@ -23,33 +24,35 @@
2324

2425

2526
# MLP code taken from LAION's CLIP-based-NSFW-Detector
26-
# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/h14_nsfw_model.py
27-
class H14_NSFW_Detector(nn.Module):
28-
def __init__(self, input_size=1024):
27+
# https://github.com/LAION-AI/CLIP-based-NSFW-Detector/issues/7
28+
class Normalization(nn.Module):
29+
def __init__(self, shape):
2930
super().__init__()
30-
self.input_size = input_size
31-
self.layers = nn.Sequential(
32-
nn.Linear(self.input_size, 1024),
33-
nn.ReLU(),
34-
nn.Dropout(0.2),
35-
nn.Linear(1024, 2048),
36-
nn.ReLU(),
37-
nn.Dropout(0.2),
38-
nn.Linear(2048, 1024),
39-
nn.ReLU(),
40-
nn.Dropout(0.2),
41-
nn.Linear(1024, 256),
42-
nn.ReLU(),
43-
nn.Dropout(0.2),
44-
nn.Linear(256, 128),
45-
nn.ReLU(),
46-
nn.Dropout(0.2),
47-
nn.Linear(128, 16),
48-
nn.Linear(16, 1),
49-
)
31+
self.register_buffer("mean", torch.zeros(shape))
32+
self.register_buffer("variance", torch.ones(shape))
33+
34+
def forward(self, x):
35+
return (x - self.mean) / self.variance.sqrt()
36+
37+
38+
class NSFWModel(nn.Module):
39+
def __init__(self):
40+
super().__init__()
41+
self.norm = Normalization([768])
42+
self.linear_1 = nn.Linear(768, 64)
43+
self.linear_2 = nn.Linear(64, 512)
44+
self.linear_3 = nn.Linear(512, 256)
45+
self.linear_4 = nn.Linear(256, 1)
46+
self.act = nn.ReLU()
47+
self.act_out = nn.Sigmoid()
5048

5149
def forward(self, x):
52-
return self.layers(x)
50+
x = self.norm(x)
51+
x = self.act(self.linear_1(x))
52+
x = self.act(self.linear_2(x))
53+
x = self.act(self.linear_3(x))
54+
x = self.act_out(self.linear_4(x))
55+
return x
5356

5457

5558
class NsfwClassifier(ImageClassifier):
@@ -66,7 +69,7 @@ def __init__(
6669
pred_column=pred_column,
6770
pred_type=float,
6871
batch_size=batch_size,
69-
embedding_size=1024,
72+
embedding_size=768,
7073
)
7174

7275
if model_path is None:
@@ -76,21 +79,24 @@ def __init__(
7679

7780
@staticmethod
7881
def _get_default_model():
79-
weights_name = "h14_nsfw.pth"
82+
weights_name = "clip_autokeras_binary_nsfw.pth"
8083
model_path = os.path.join(NEMO_CURATOR_HOME, weights_name)
8184
os.makedirs(NEMO_CURATOR_HOME, exist_ok=True)
8285

8386
if not os.path.exists(model_path):
84-
url = f"https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/{weights_name}?raw=true"
87+
url = "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/files/10250461/clip_autokeras_binary_nsfw.zip"
8588
r = requests.get(url)
8689

87-
with open(model_path, "wb") as f:
90+
raw_zip_path = os.path.join(NEMO_CURATOR_HOME, "nsfw.zip")
91+
with open(raw_zip_path, "wb") as f:
8892
f.write(r.content)
93+
with zipfile.ZipFile(raw_zip_path, "r") as f:
94+
f.extractall(NEMO_CURATOR_HOME)
8995

9096
return model_path
9197

9298
def load_model(self, device):
93-
model = H14_NSFW_Detector(input_size=self.embedding_size).to(device)
99+
model = NSFWModel().to(device)
94100
weights = torch.load(self.model_path, map_location=torch.device("cpu"))
95101
model.load_state_dict(weights)
96102
model.eval()

0 commit comments

Comments
 (0)