12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import os
15
+ import zipfile
15
16
from typing import Optional
16
17
17
18
import requests
23
24
24
25
25
26
# MLP code taken from LAION's CLIP-based-NSFW-Detector
26
- # https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/h14_nsfw_model.py
27
- class H14_NSFW_Detector (nn .Module ):
28
- def __init__ (self , input_size = 1024 ):
27
+ # https://github.com/LAION-AI/CLIP-based-NSFW-Detector/issues/7
28
+ class Normalization (nn .Module ):
29
+ def __init__ (self , shape ):
29
30
super ().__init__ ()
30
- self .input_size = input_size
31
- self .layers = nn .Sequential (
32
- nn .Linear (self .input_size , 1024 ),
33
- nn .ReLU (),
34
- nn .Dropout (0.2 ),
35
- nn .Linear (1024 , 2048 ),
36
- nn .ReLU (),
37
- nn .Dropout (0.2 ),
38
- nn .Linear (2048 , 1024 ),
39
- nn .ReLU (),
40
- nn .Dropout (0.2 ),
41
- nn .Linear (1024 , 256 ),
42
- nn .ReLU (),
43
- nn .Dropout (0.2 ),
44
- nn .Linear (256 , 128 ),
45
- nn .ReLU (),
46
- nn .Dropout (0.2 ),
47
- nn .Linear (128 , 16 ),
48
- nn .Linear (16 , 1 ),
49
- )
31
+ self .register_buffer ("mean" , torch .zeros (shape ))
32
+ self .register_buffer ("variance" , torch .ones (shape ))
33
+
34
+ def forward (self , x ):
35
+ return (x - self .mean ) / self .variance .sqrt ()
36
+
37
+
38
+ class NSFWModel (nn .Module ):
39
+ def __init__ (self ):
40
+ super ().__init__ ()
41
+ self .norm = Normalization ([768 ])
42
+ self .linear_1 = nn .Linear (768 , 64 )
43
+ self .linear_2 = nn .Linear (64 , 512 )
44
+ self .linear_3 = nn .Linear (512 , 256 )
45
+ self .linear_4 = nn .Linear (256 , 1 )
46
+ self .act = nn .ReLU ()
47
+ self .act_out = nn .Sigmoid ()
50
48
51
49
def forward (self , x ):
52
- return self .layers (x )
50
+ x = self .norm (x )
51
+ x = self .act (self .linear_1 (x ))
52
+ x = self .act (self .linear_2 (x ))
53
+ x = self .act (self .linear_3 (x ))
54
+ x = self .act_out (self .linear_4 (x ))
55
+ return x
53
56
54
57
55
58
class NsfwClassifier (ImageClassifier ):
@@ -66,7 +69,7 @@ def __init__(
66
69
pred_column = pred_column ,
67
70
pred_type = float ,
68
71
batch_size = batch_size ,
69
- embedding_size = 1024 ,
72
+ embedding_size = 768 ,
70
73
)
71
74
72
75
if model_path is None :
@@ -76,21 +79,24 @@ def __init__(
76
79
77
80
@staticmethod
78
81
def _get_default_model ():
79
- weights_name = "h14_nsfw .pth"
82
+ weights_name = "clip_autokeras_binary_nsfw .pth"
80
83
model_path = os .path .join (NEMO_CURATOR_HOME , weights_name )
81
84
os .makedirs (NEMO_CURATOR_HOME , exist_ok = True )
82
85
83
86
if not os .path .exists (model_path ):
84
- url = f "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/blob/main/ { weights_name } ?raw=true "
87
+ url = "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/files/10250461/clip_autokeras_binary_nsfw.zip "
85
88
r = requests .get (url )
86
89
87
- with open (model_path , "wb" ) as f :
90
+ raw_zip_path = os .path .join (NEMO_CURATOR_HOME , "nsfw.zip" )
91
+ with open (raw_zip_path , "wb" ) as f :
88
92
f .write (r .content )
93
+ with zipfile .ZipFile (raw_zip_path , "r" ) as f :
94
+ f .extractall (NEMO_CURATOR_HOME )
89
95
90
96
return model_path
91
97
92
98
def load_model (self , device ):
93
- model = H14_NSFW_Detector ( input_size = self . embedding_size ).to (device )
99
+ model = NSFWModel ( ).to (device )
94
100
weights = torch .load (self .model_path , map_location = torch .device ("cpu" ))
95
101
model .load_state_dict (weights )
96
102
model .eval ()
0 commit comments