-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
176 lines (157 loc) · 6.42 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
import inspect
import paddle
from paddle import nn
from paddle.nn import functional as F
class CTCLoss(nn.Layer):
    """CTC loss built on top of ``paddle.nn.CTCLoss``.

    On top of the wrapped loss this layer optionally divides the result by
    the batch size and forwards a gradient-normalisation flag to the wrapped
    loss, but only when the wrapped ``forward`` actually accepts that flag.
    """

    def __init__(self,
                 blank=0,
                 reduction='sum',
                 batch_average=False,
                 grad_norm_type=None):
        """Create the wrapped loss and select the gradient normalisation.

        Args:
            blank (int): id of the CTC blank token. Defaults to 0.
            reduction (str): reduction passed to ``nn.CTCLoss``.
                Defaults to 'sum'.
            batch_average (bool): if True, ``forward`` divides the loss by
                the batch size. Defaults to False.
            grad_norm_type (str|None): 'instance', 'batch', 'frame' or
                None (no gradient normalisation). Defaults to None.

        Raises:
            ValueError: if ``grad_norm_type`` is not one of the supported
                values.
        """
        super().__init__()
        # The blank id is whatever the caller passes (0 by default).
        self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
        self.batch_average = batch_average

        # Validate with an explicit exception instead of ``assert``: asserts
        # are removed under ``python -O`` and previously made this error
        # path unreachable in normal runs.
        if grad_norm_type not in ('instance', 'batch', 'frame', None):
            raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")

        # At most one flag is True; all stay False when no grad norm is
        # requested.
        self.norm_by_times = grad_norm_type == 'instance'
        self.norm_by_batchsize = grad_norm_type == 'batch'
        self.norm_by_total_logits_len = grad_norm_type == 'frame'

        kwargs = {
            "norm_by_times": self.norm_by_times,
            "norm_by_batchsize": self.norm_by_batchsize,
            "norm_by_total_logits_len": self.norm_by_total_logits_len,
        }

        # Keep only the arguments that the wrapped ``forward`` declares, so
        # the layer works with signatures that lack some of these flags.
        try:
            param = inspect.signature(self.loss.forward).parameters
        except ValueError:
            # Some callables (e.g. built-in functions) expose no signature.
            param = {}
        self._kwargs = {k: v for k, v in kwargs.items() if k in param}
        _notin = {k: v for k, v in kwargs.items() if k not in param}
        print(
            f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")

    def forward(self, logits, ys_pad, hlens, ys_lens):
        """Compute CTC loss.

        Args:
            logits ([paddle.Tensor]): [B, Tmax, D]
            ys_pad ([paddle.Tensor]): [B, Tmax]
            hlens ([paddle.Tensor]): [B]
            ys_lens ([paddle.Tensor]): [B]

        Returns:
            [paddle.Tensor]: scalar. If reduction is 'none', then (N),
                where N is the batch size.
        """
        B = logits.shape[0]
        # The wrapped loss takes raw logits (per the original author's note
        # it applies softmax itself) in time-major layout:
        # logits: (B, L, D) -> (L, B, D)
        logits = logits.transpose([1, 0, 2])
        # Labels are passed to the wrapped loss as int32.
        ys_pad = ys_pad.astype(paddle.int32)
        loss = self.loss(logits, ys_pad, hlens, ys_lens, **self._kwargs)
        if self.batch_average:
            # Batch-size average
            loss = loss / B
        return loss
class LabelSmoothingLoss(nn.Layer):
    """KL-divergence loss against a label-smoothed target distribution.

    Plain cross entropy uses a one-hot target per label:
        [0,1,2] ->
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]

    With label smoothing, part of the probability mass of the true label
    (1.0) is removed and shared equally among the remaining labels.
    e.g.
        smoothing=0.1
        [0,1,2] ->
        [
            [0.9, 0.05, 0.05],
            [0.05, 0.9, 0.05],
            [0.05, 0.05, 0.9],
        ]
    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool=False):
        """Store the smoothing configuration and build the KL criterion.

        Args:
            size (int): number of classes.
            padding_idx (int): class id used for padding; positions holding
                it do not contribute to the loss.
            smoothing (float): smoothing rate (0.0 gives conventional CE).
            normalize_length (bool):
                True, divide the loss by the number of non-padding tokens;
                False, divide the loss by the batch size.
                Defaults to False.
        """
        super().__init__()
        self.size = size
        self.padding_idx = padding_idx
        self.normalize_length = normalize_length
        self.smoothing = smoothing
        # Probability kept on the true label.
        self.confidence = 1.0 - smoothing
        # Element-wise KL so padding positions can be masked before summing.
        self.criterion = nn.KLDivLoss(reduction="none")

    def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:
        """Compute the label-smoothed loss between ``x`` and ``target``.

        Predictions and labels are flattened to (batch*seqlen, class) and
        (batch*seqlen,), and padding positions are masked out of the sum.

        Args:
            x (paddle.Tensor): prediction (batch, seqlen, class)
            target (paddle.Tensor):
                labels, padded with self.padding_idx (batch, seqlen)
        Returns:
            loss (paddle.Tensor) : the masked KL sum divided by either the
                non-padding token count or the batch size.
        """
        batch_size, _, num_classes = x.shape
        assert num_classes == self.size

        flat_x = x.reshape((-1, self.size))
        flat_target = target.reshape([-1])

        # Start with the off-target probability in every cell.
        smoothed = paddle.full_like(flat_x, self.smoothing / (self.size - 1))

        is_pad = flat_target == self.padding_idx  # (batch*seqlen,)
        # Send padding positions to class 0 so one_hot never receives the
        # padding id (e.g. -1); the bool mask is cast before arithmetic.
        safe_target = flat_target * (1 - is_pad.astype(flat_target.dtype))

        # Overwrite the true-label cell of each row with the confidence.
        one_hot = F.one_hot(safe_target, self.size)
        smoothed *= (1 - one_hot)
        smoothed += one_hot * self.confidence

        kl = self.criterion(F.log_softmax(flat_x, axis=1), smoothed)

        # The bool mask is cast to the label dtype before summing.
        num_valid = len(flat_target) - int(
            is_pad.astype(flat_target.dtype).sum())
        if self.normalize_length:
            denom = num_valid
        else:
            denom = batch_size

        # Zero out the rows that belong to padding positions.
        keep = 1 - is_pad.unsqueeze(1).astype(kl.dtype)
        return (kl * keep).sum() / denom