Commit
add distillation function
littletomatodonkey committed Jun 2, 2021
1 parent 551a682 commit ed02b91
Showing 17 changed files with 405 additions and 79 deletions.
39 changes: 24 additions & 15 deletions ppocr/losses/__init__.py
@@ -13,28 +13,37 @@
 # limitations under the License.
 
 import copy
+import paddle
+import paddle.nn as nn
+
+# det loss
+from .det_db_loss import DBLoss
+from .det_east_loss import EASTLoss
+from .det_sast_loss import SASTLoss
 
-def build_loss(config):
-    # det loss
-    from .det_db_loss import DBLoss
-    from .det_east_loss import EASTLoss
-    from .det_sast_loss import SASTLoss
-    # rec loss
-    from .rec_ctc_loss import CTCLoss
-    from .rec_att_loss import AttentionLoss
-    from .rec_srn_loss import SRNLoss
+# rec loss
+from .rec_ctc_loss import CTCLoss
+from .rec_att_loss import AttentionLoss
+from .rec_srn_loss import SRNLoss
 
-    # cls loss
-    from .cls_loss import ClsLoss
+# cls loss
+from .cls_loss import ClsLoss
 
-    # e2e loss
-    from .e2e_pg_loss import PGLoss
+# e2e loss
+from .e2e_pg_loss import PGLoss
 
+# basic loss function
+from .basic_loss import DistanceLoss
+
+# combined loss function
+from .combined_loss import CombinedLoss
+
+
+def build_loss(config):
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss']
-
+        'SRNLoss', 'PGLoss', 'CombinedLoss'
+    ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('loss only support {}'.format(
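For orientation (not part of the commit), a minimal usage sketch of this registry; the config dict is an illustrative assumption mirroring what the trainer passes in:

# Hedged sketch: 'name' picks a class out of support_dict and the
# remaining keys become constructor kwargs for that class.
from ppocr.losses import build_loss

loss_fn = build_loss({"name": "CTCLoss"})  # returns an nn.Layer instance
# A distillation run would instead name CombinedLoss and carry its
# sub-loss configs, as sketched further below.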
101 changes: 101 additions & 0 deletions ppocr/losses/basic_loss.py
@@ -0,0 +1,101 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss


class CELoss(nn.Layer):
    def __init__(self, name="loss_ce", epsilon=None):
        super().__init__()
        self.name = name
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon

    def _labelsmoothing(self, target, class_num):
        if target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def forward(self, x, label):
        loss_dict = {}
        if self.epsilon is not None:
            class_num = x.shape[-1]
            label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(x, axis=-1)
            loss = paddle.sum(x * label, axis=-1)
        else:
            if label.shape[-1] == x.shape[-1]:
                label = F.softmax(label, axis=-1)
                soft_label = True
            else:
                soft_label = False
            loss = F.cross_entropy(x, label=label, soft_label=soft_label)

        loss_dict[self.name] = paddle.mean(loss)
        return loss_dict


class DMLLoss(nn.Layer):
    """
    DMLLoss
    """

    def __init__(self, name="loss_dml"):
        super().__init__()
        self.name = name

    def forward(self, out1, out2):
        loss_dict = {}
        soft_out1 = F.softmax(out1, axis=-1)
        log_soft_out1 = paddle.log(soft_out1)
        soft_out2 = F.softmax(out2, axis=-1)
        log_soft_out2 = paddle.log(soft_out2)
        loss = (F.kl_div(
            log_soft_out1, soft_out2, reduction='batchmean') + F.kl_div(
                log_soft_out2, soft_out1, reduction='batchmean')) / 2.0
        loss_dict[self.name] = loss
        return loss_dict


class DistanceLoss(nn.Layer):
    """
    DistanceLoss:
        mode: loss mode
        name: loss key in the output dict
    """

    def __init__(self, mode="l2", name="loss_dist", **kargs):
        super().__init__()
        assert mode in ["l1", "l2", "smooth_l1"]
        if mode == "l1":
            self.loss_func = nn.L1Loss(**kargs)
        elif mode == "l2":
            self.loss_func = nn.MSELoss(**kargs)
        elif mode == "smooth_l1":
            self.loss_func = nn.SmoothL1Loss(**kargs)

        self.name = "{}_{}".format(name, mode)

    def forward(self, x, y):
        return {self.name: self.loss_func(x, y)}
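A quick sanity-check sketch for the losses above; the shapes and values are illustrative assumptions, and the import path simply mirrors this file's location:

import paddle
from ppocr.losses.basic_loss import CELoss, DMLLoss, DistanceLoss

x = paddle.randn([8, 37])  # e.g. 8 samples, 37-class logits
y = paddle.randn([8, 37])

print(DMLLoss()(x, y))                       # {'loss_dml': symmetric KL}
print(DistanceLoss(mode="smooth_l1")(x, y))  # {'loss_dist_smooth_l1': ...}
print(CELoss(epsilon=0.1)(x, paddle.randint(0, 37, [8])))  # smoothed CE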
2 changes: 1 addition & 1 deletion ppocr/losses/cls_loss.py
@@ -24,7 +24,7 @@ def __init__(self, **kwargs):
         super(ClsLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(reduction='mean')
 
-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         label = batch[1]
         loss = self.loss_func(input=predicts, label=label)
         return {'loss': loss}
57 changes: 57 additions & 0 deletions ppocr/losses/combined_loss.py
@@ -0,0 +1,57 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss


class CombinedLoss(nn.Layer):
    """
    CombinedLoss:
        a combination of loss functions
    """

    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list, list), (
            'operator config should be a list')
        for config in loss_config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            self.loss_func.append(eval(name)(**param))

    def forward(self, input, batch, **kargs):
        loss_dict = {}
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                # wrap a bare tensor into a dict so the weighting below
                # treats every sub-loss uniformly
                loss = {"loss_{}_{}".format(str(loss), idx): loss}
            weight = self.loss_weight[idx]
            loss = {
                "{}_{}".format(key, idx): loss[key] * weight
                for key in loss
            }
            loss_dict.update(loss)
        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
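To make the expected input concrete, a hedged sketch of a loss_config_list: each entry is a single-key dict whose value holds the mandatory weight plus the wrapped loss's own kwargs. The sub-model names and the "head_out" key below are illustrative assumptions, not taken from this commit:

# Hypothetical config, mirroring what CombinedLoss.__init__ parses.
loss_config_list = [
    {"DistillationDMLLoss": {"weight": 1.0,
                             "model_name_list1": ["Student"],
                             "model_name_list2": ["Teacher"],
                             "key": "head_out"}},
    {"DistillationCTCLoss": {"weight": 1.0,
                             "model_name_list": ["Student"],
                             "key": "head_out"}},
]
loss_fn = CombinedLoss(loss_config_list=loss_config_list)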
76 changes: 76 additions & 0 deletions ppocr/losses/distillation_loss.py
@@ -0,0 +1,76 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle
import paddle.nn as nn

from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss


class DistillationDMLLoss(DMLLoss):
    """
    DMLLoss applied between paired sub-model outputs of a distillation model
    """

    def __init__(self,
                 model_name_list1=[],
                 model_name_list2=[],
                 key=None,
                 name="loss_dml"):
        super().__init__(name=name)
        if not isinstance(model_name_list1, (list, )):
            model_name_list1 = [model_name_list1]
        if not isinstance(model_name_list2, (list, )):
            model_name_list2 = [model_name_list2]

        assert len(model_name_list1) == len(model_name_list2)
        self.model_name_list1 = model_name_list1
        self.model_name_list2 = model_name_list2
        self.key = key

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx in range(len(self.model_name_list1)):
            out1 = predicts[self.model_name_list1[idx]]
            out2 = predicts[self.model_name_list2[idx]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                assert len(loss) == 1
                loss = list(loss.values())[0]
            loss_dict["{}_{}".format(self.name, idx)] = loss
        return loss_dict


class DistillationCTCLoss(CTCLoss):
    def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
        super().__init__()
        self.model_name_list = model_name_list
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for model_name in self.model_name_list:
            out = predicts[model_name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)
            if isinstance(loss, dict):
                assert len(loss) == 1
                loss = list(loss.values())[0]
            loss_dict["{}_{}".format(self.name, model_name)] = loss
        return loss_dict
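A hedged usage sketch, assuming predicts is a dict keyed by sub-model name as a distillation architecture would emit; the names, shapes and the "head_out" key are illustrative assumptions:

import paddle
from ppocr.losses.distillation_loss import DistillationDMLLoss

predicts = {
    "Student": {"head_out": paddle.randn([8, 25, 37])},
    "Teacher": {"head_out": paddle.randn([8, 25, 37])},
}
loss_fn = DistillationDMLLoss(model_name_list1=["Student"],
                              model_name_list2=["Teacher"],
                              key="head_out")
print(loss_fn(predicts, None))  # {'loss_dml_0': ...}; batch is unused here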
2 changes: 1 addition & 1 deletion ppocr/losses/rec_ctc_loss.py
@@ -25,7 +25,7 @@ def __init__(self, **kwargs):
         super(CTCLoss, self).__init__()
         self.loss_func = nn.CTCLoss(blank=0, reduction='none')
 
-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
         preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
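Renaming __call__ to forward here (and in ClsLoss above) is what lets the new distillation losses reuse these classes: DistillationCTCLoss calls super().forward(out, batch), which only resolves when the loss logic lives in forward, and paddle's nn.Layer.__call__ machinery then dispatches to it as usual.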
16 changes: 12 additions & 4 deletions ppocr/modeling/architectures/__init__.py
@@ -13,12 +13,20 @@
 # limitations under the License.
 
 import copy
+import importlib
+
+from .base_model import BaseModel
+from .distillation_model import DistillationModel
 
 __all__ = ['build_model']
 
 
 def build_model(config):
-    from .base_model import BaseModel
-
     config = copy.deepcopy(config)
-    module_class = BaseModel(config)
-    return module_class
+    if "name" not in config:
+        arch = BaseModel(config)
+    else:
+        name = config.pop("name")
+        mod = importlib.import_module(__name__)
+        arch = getattr(mod, name)(config)
+    return arch
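The name-based dispatch used above can be sketched in isolation; this is a toy re-implementation of the pattern, not code from the repo:

import importlib


class Dummy:
    def __init__(self, config):
        self.config = config


def build(config):
    # mirror build_model: resolve the class named in the config from
    # the current module, then instantiate it with the remaining config
    name = config.pop("name")
    mod = importlib.import_module(__name__)
    return getattr(mod, name)(config)


obj = build({"name": "Dummy", "model_type": "rec"})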
1 change: 0 additions & 1 deletion ppocr/modeling/architectures/base_model.py
@@ -32,7 +32,6 @@ def __init__(self, config):
             config (dict): the super parameters for module.
         """
         super(BaseModel, self).__init__()
-
         in_channels = config.get('in_channels', 3)
         model_type = config['model_type']
         # build transform,