forked from PaddlePaddle/PaddleOCR
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
551a682
commit ed02b91
Showing
17 changed files
with
405 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
#Licensed under the Apache License, Version 2.0 (the "License"); | ||
#you may not use this file except in compliance with the License. | ||
#You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
#Unless required by applicable law or agreed to in writing, software | ||
#distributed under the License is distributed on an "AS IS" BASIS, | ||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
#See the License for the specific language governing permissions and | ||
#limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
|
||
from paddle.nn import L1Loss | ||
from paddle.nn import MSELoss as L2Loss | ||
from paddle.nn import SmoothL1Loss | ||
|
||
|
||
class CELoss(nn.Layer): | ||
def __init__(self, name="loss_ce", epsilon=None): | ||
super().__init__() | ||
self.name = name | ||
if epsilon is not None and (epsilon <= 0 or epsilon >= 1): | ||
epsilon = None | ||
self.epsilon = epsilon | ||
|
||
def _labelsmoothing(self, target, class_num): | ||
if target.shape[-1] != class_num: | ||
one_hot_target = F.one_hot(target, class_num) | ||
else: | ||
one_hot_target = target | ||
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon) | ||
soft_target = paddle.reshape(soft_target, shape=[-1, class_num]) | ||
return soft_target | ||
|
||
def forward(self, x, label): | ||
loss_dict = {} | ||
if self.epsilon is not None: | ||
class_num = x.shape[-1] | ||
label = self._labelsmoothing(label, class_num) | ||
x = -F.log_softmax(x, axis=-1) | ||
loss = paddle.sum(x * label, axis=-1) | ||
else: | ||
if label.shape[-1] == x.shape[-1]: | ||
label = F.softmax(label, axis=-1) | ||
soft_label = True | ||
else: | ||
soft_label = False | ||
loss = F.cross_entropy(x, label=label, soft_label=soft_label) | ||
|
||
loss_dict[self.name] = paddle.mean(loss) | ||
return loss_dict | ||
|
||
|
||
class DMLLoss(nn.Layer): | ||
""" | ||
DMLLoss | ||
""" | ||
|
||
def __init__(self, name="loss_dml"): | ||
super().__init__() | ||
self.name = name | ||
|
||
def forward(self, out1, out2): | ||
loss_dict = {} | ||
soft_out1 = F.softmax(out1, axis=-1) | ||
log_soft_out1 = paddle.log(soft_out1) | ||
soft_out2 = F.softmax(out2, axis=-1) | ||
log_soft_out2 = paddle.log(soft_out2) | ||
loss = (F.kl_div( | ||
log_soft_out1, soft_out2, reduction='batchmean') + F.kl_div( | ||
log_soft_out2, soft_out1, reduction='batchmean')) / 2.0 | ||
loss_dict[self.name] = loss | ||
return loss_dict | ||
|
||
|
||
class DistanceLoss(nn.Layer): | ||
""" | ||
DistanceLoss: | ||
mode: loss mode | ||
name: loss key in the output dict | ||
""" | ||
|
||
def __init__(self, mode="l2", name="loss_dist", **kargs): | ||
assert mode in ["l1", "l2", "smooth_l1"] | ||
if mode == "l1": | ||
self.loss_func = nn.L1Loss(**kargs) | ||
elif mode == "l1": | ||
self.loss_func = nn.MSELoss(**kargs) | ||
elif mode == "smooth_l1": | ||
self.loss_func = nn.SmoothL1Loss(**kargs) | ||
|
||
self.name = "{}_{}".format(name, mode) | ||
|
||
def forward(self, x, y): | ||
return {self.name: self.loss_func(x, y)} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
|
||
from .distillation_loss import DistillationCTCLoss | ||
from .distillation_loss import DistillationDMLLoss | ||
|
||
|
||
class CombinedLoss(nn.Layer): | ||
""" | ||
CombinedLoss: | ||
a combionation of loss function | ||
""" | ||
|
||
def __init__(self, loss_config_list=None): | ||
super().__init__() | ||
self.loss_func = [] | ||
self.loss_weight = [] | ||
assert isinstance(loss_config_list, list), ( | ||
'operator config should be a list') | ||
for config in loss_config_list: | ||
assert isinstance(config, | ||
dict) and len(config) == 1, "yaml format error" | ||
name = list(config)[0] | ||
param = config[name] | ||
assert "weight" in param, "weight must be in param, but param just contains {}".format( | ||
param.keys()) | ||
self.loss_weight.append(param.pop("weight")) | ||
self.loss_func.append(eval(name)(**param)) | ||
|
||
def forward(self, input, batch, **kargs): | ||
loss_dict = {} | ||
for idx, loss_func in enumerate(self.loss_func): | ||
loss = loss_func(input, batch, **kargs) | ||
if isinstance(loss, paddle.Tensor): | ||
loss = {"loss_{}_{}".format(str(loss), idx): loss} | ||
weight = self.loss_weight[idx] | ||
loss = { | ||
"{}_{}".format(key, idx): loss[key] * weight | ||
for key in loss | ||
} | ||
loss_dict.update(loss) | ||
loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) | ||
return loss_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
#Licensed under the Apache License, Version 2.0 (the "License"); | ||
#you may not use this file except in compliance with the License. | ||
#You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
#Unless required by applicable law or agreed to in writing, software | ||
#distributed under the License is distributed on an "AS IS" BASIS, | ||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
#See the License for the specific language governing permissions and | ||
#limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
|
||
from .rec_ctc_loss import CTCLoss | ||
from .basic_loss import DMLLoss | ||
|
||
|
||
class DistillationDMLLoss(DMLLoss): | ||
""" | ||
""" | ||
|
||
def __init__(self, | ||
model_name_list1=[], | ||
model_name_list2=[], | ||
key=None, | ||
name="loss_dml"): | ||
super().__init__(name=name) | ||
if not isinstance(model_name_list1, (list, )): | ||
model_name_list1 = [model_name_list1] | ||
if not isinstance(model_name_list2, (list, )): | ||
model_name_list2 = [model_name_list2] | ||
|
||
assert len(model_name_list1) == len(model_name_list2) | ||
self.model_name_list1 = model_name_list1 | ||
self.model_name_list2 = model_name_list2 | ||
self.key = key | ||
|
||
def forward(self, predicts, batch): | ||
loss_dict = dict() | ||
for idx in range(len(self.model_name_list1)): | ||
out1 = predicts[self.model_name_list1[idx]] | ||
out2 = predicts[self.model_name_list2[idx]] | ||
if self.key is not None: | ||
out1 = out1[self.key] | ||
out2 = out2[self.key] | ||
loss = super().forward(out1, out2) | ||
if isinstance(loss, dict): | ||
assert len(loss) == 1 | ||
loss = list(loss.values())[0] | ||
loss_dict["{}_{}".format(self.name, idx)] = loss | ||
return loss_dict | ||
|
||
|
||
class DistillationCTCLoss(CTCLoss): | ||
def __init__(self, model_name_list=[], key=None, name="loss_ctc"): | ||
super().__init__() | ||
self.model_name_list = model_name_list | ||
self.key = key | ||
self.name = name | ||
|
||
def forward(self, predicts, batch): | ||
loss_dict = dict() | ||
for model_name in self.model_name_list: | ||
out = predicts[model_name] | ||
if self.key is not None: | ||
out = out[self.key] | ||
loss = super().forward(out, batch) | ||
if isinstance(loss, dict): | ||
assert len(loss) == 1 | ||
loss = list(loss.values())[0] | ||
loss_dict["{}_{}".format(self.name, model_name)] = loss | ||
return loss_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.