@@ -167,11 +167,16 @@ def cross_entropy_loss(output, target, ignore_index=None):
 def mse_loss(output, target):
     return (output - target).sqr().mean()
 
-def bce_loss(output, target, size_average=True):
+def bce_loss(output, target, weight=None, size_average=True):
+    loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
+
+    if weight is not None:
+        loss *= weight
+
     if size_average:
-        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
+        return loss.mean()
     else:
-        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
+        return loss.sum()
 
 
 def l1_loss(output, target):
     return (output - target).abs().mean()
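The hunk above refactors bce_loss so the element-wise term is computed once and is optionally rescaled by a per-element weight before reduction. A minimal sketch of the resulting behaviour, assuming the patched function is importable from jittor.nn; the tensor values below are made up for illustration:

import jittor as jt
from jittor.nn import bce_loss

output = jt.array([0.9, 0.1, 0.6])   # predicted probabilities in (0, 1)
target = jt.array([1.0, 0.0, 1.0])   # binary ground-truth labels
weight = jt.array([1.0, 2.0, 0.5])   # hypothetical per-element rescaling weights

# Element-wise BCE term, clamped away from log(0) exactly as in the patch:
# -[y*log(max(p, 1e-20)) + (1-y)*log(max(1-p, 1e-20))]
elementwise = -(target * jt.log(jt.maximum(output, 1e-20))
                + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))

# With weight given and size_average=True, bce_loss should equal the mean
# of the weighted element-wise terms.
print(bce_loss(output, target, weight, size_average=True))
print((weight * elementwise).mean())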
@@ -189,10 +194,11 @@ def execute(self, output, target):
         return mse_loss(output, target)
 
 class BCELoss(Module):
-    def __init__(self):
-        pass
-    def execute(self, output, target, size_average=True):
-        return bce_loss(output, target, size_average)
+    def __init__(self, weight=None, size_average=True):
+        self.weight = weight
+        self.size_average = size_average
+    def execute(self, output, target):
+        return bce_loss(output, target, self.weight, self.size_average)
 
 class L1Loss(Module):
     def __init__(self):
@@ -201,14 +207,17 @@ def execute(self, output, target):
         return l1_loss(output, target)
 
 class BCEWithLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, weight=None, size_average=True):
         self.sigmoid = Sigmoid()
-        self.bce = BCELoss()
-    def execute(self, output, target, size_average=True):
+        self.bce = BCELoss(weight, size_average)
+    def execute(self, output, target):
         output = self.sigmoid(output)
-        output = self.bce(output, target, size_average)
+        output = self.bce(output, target)
         return output
 
+def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
+    return BCEWithLogitsLoss(weight, size_average)(input, target)
+
 def softmax(x, dim=None):
     if dim is None:
         x = (x - x.max()).exp()
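With these changes the weight and reduction options move from execute() into the constructors, and a functional wrapper around BCEWithLogitsLoss is added. A short usage sketch, assuming the classes live in jittor.nn as the rest of this file does; tensor values are illustrative only:

import jittor as jt
from jittor import nn

logits = jt.array([1.2, -0.8, 0.3])   # raw scores; sigmoid is applied internally
target = jt.array([1.0, 0.0, 1.0])
weight = jt.array([1.0, 2.0, 0.5])    # optional per-element rescaling

# Module form: weight / size_average are now constructor arguments,
# so execute() only takes (output, target).
loss_fn = nn.BCEWithLogitsLoss(weight=weight, size_average=True)
loss = loss_fn(logits, target)

# Functional form added by this patch; it simply wraps the module above.
same = nn.binary_cross_entropy_with_logits(logits, target, weight, size_average=True)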