diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml index c89de02b8c..635c392d70 100644 --- a/configs/rec/rec_mtb_nrtr.yml +++ b/configs/rec/rec_mtb_nrtr.yml @@ -43,7 +43,7 @@ Architecture: name: MTB cnn_num: 2 Head: - name: TransformerOptim + name: Transformer d_model: 512 num_encoder_layers: 6 beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. @@ -69,8 +69,9 @@ Train: img_mode: BGR channel_first: False - NRTRLabelEncode: # Class handling label - - PILResize: + - NRTRRecResizeImg: image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV - KeepKeys: keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order loader: @@ -88,8 +89,9 @@ Eval: img_mode: BGR channel_first: False - NRTRLabelEncode: # Class handling label - - PILResize: + - NRTRRecResizeImg: image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV - KeepKeys: keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order loader: diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 5c384c1d31..4418d075cb 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .operators import * diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 13a5c71de4..e914d38446 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -42,30 +42,21 @@ def __call__(self, data): data['image'] = norm_img return data -class PILResize(object): - def __init__(self, image_shape, **kwargs): - self.image_shape = image_shape - - def __call__(self, data): - img = data['image'] - image_pil = Image.fromarray(np.uint8(img)) - norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS) - norm_img = np.array(norm_img) - norm_img = np.expand_dims(norm_img, -1) - norm_img = norm_img.transpose((2, 0, 1)) - data['image'] = norm_img.astype(np.float32) / 128. - 1. - return data - -class CVResize(object): - def __init__(self, image_shape, **kwargs): +class NRTRRecResizeImg(object): + def __init__(self, image_shape, resize_type, **kwargs): self.image_shape = image_shape + self.resize_type = resize_type def __call__(self, data): img = data['image'] - #print(img) - norm_img = cv2.resize(img,self.image_shape) - norm_img = np.expand_dims(norm_img, -1) + if self.resize_type == 'PIL': + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize(self.image_shape, Image.ANTIALIAS) + img = np.array(img) + if self.resize_type == 'OpenCV': + img = cv2.resize(img, self.image_shape) + norm_img = np.expand_dims(img, -1) norm_img = norm_img.transpose((2, 0, 1)) data['image'] = norm_img.astype(np.float32) / 128. - 1. return data diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py index 915f506ded..41714dd2a3 100644 --- a/ppocr/losses/rec_nrtr_loss.py +++ b/ppocr/losses/rec_nrtr_loss.py @@ -3,34 +3,26 @@ import paddle.nn.functional as F -def cal_performance(pred, tgt): - - pred = pred.max(1)[1] - tgt = tgt.contiguous().view(-1) - non_pad_mask = tgt.ne(0) - n_correct = pred.eq(tgt) - n_correct = n_correct.masked_select(non_pad_mask).sum().item() - return n_correct - - class NRTRLoss(nn.Layer): - def __init__(self,smoothing=True, **kwargs): + def __init__(self, smoothing=True, **kwargs): super(NRTRLoss, self).__init__() - self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0) + self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0) self.smoothing = smoothing def forward(self, pred, batch): pred = pred.reshape([-1, pred.shape[2]]) max_len = batch[2].max() - tgt = batch[1][:,1:2+max_len] - tgt = tgt.reshape([-1] ) + tgt = batch[1][:, 1:2 + max_len] + tgt = tgt.reshape([-1]) if self.smoothing: eps = 0.1 n_class = pred.shape[1] one_hot = F.one_hot(tgt, pred.shape[1]) one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) log_prb = F.log_softmax(pred, axis=1) - non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64')) + non_pad_mask = paddle.not_equal( + tgt, paddle.zeros( + tgt.shape, dtype='int64')) loss = -(one_hot * log_prb).sum(axis=1) loss = loss.masked_select(non_pad_mask).mean() else: diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 11fd4b26d5..572ec4aa8a 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -26,13 +26,13 @@ def build_head(config): from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead - from .rec_nrtr_optim_head import TransformerOptim + from .rec_nrtr_head import Transformer # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead' + 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead' ] #table head diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py index 4be3702552..651d4f577d 100755 --- a/ppocr/modeling/heads/multiheadAttention.py +++ b/ppocr/modeling/heads/multiheadAttention.py @@ -24,7 +24,7 @@ ones_ = constant_(value=1.) -class MultiheadAttentionOptim(nn.Layer): +class MultiheadAttention(nn.Layer): """Allows the model to jointly attend to information from different representation subspaces. See reference: Attention Is All You Need @@ -46,7 +46,7 @@ def __init__(self, bias=True, add_bias_kv=False, add_zero_attn=False): - super(MultiheadAttentionOptim, self).__init__() + super(MultiheadAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_head.py similarity index 98% rename from ppocr/modeling/heads/rec_nrtr_optim_head.py rename to ppocr/modeling/heads/rec_nrtr_head.py index 63473c1184..05dba677b4 100644 --- a/ppocr/modeling/heads/rec_nrtr_optim_head.py +++ b/ppocr/modeling/heads/rec_nrtr_head.py @@ -21,7 +21,7 @@ from paddle.nn.initializer import XavierNormal as xavier_uniform_ from paddle.nn import Dropout, Linear, LayerNorm, Conv2D import numpy as np -from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim +from ppocr.modeling.heads.multiheadAttention import MultiheadAttention from paddle.nn.initializer import Constant as constant_ from paddle.nn.initializer import XavierNormal as xavier_normal_ @@ -29,7 +29,7 @@ ones_ = constant_(value=1.) -class TransformerOptim(nn.Layer): +class Transformer(nn.Layer): """A transformer model. User is able to modify the attributes as needed. The architechture is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and @@ -63,7 +63,7 @@ def __init__(self, out_channels=0, dst_vocab_size=99, scale_embedding=True): - super(TransformerOptim, self).__init__() + super(Transformer, self).__init__() self.embedding = Embeddings( d_model=d_model, vocab=dst_vocab_size, @@ -215,8 +215,7 @@ def collect_active_part(beamed_tensor, curr_active_inst_idx, n_curr_active_inst = len(curr_active_inst_idx) new_shape = (n_curr_active_inst * n_bm, *d_hs) - beamed_tensor = beamed_tensor.reshape( - [n_prev_active_inst, -1]) + beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1]) beamed_tensor = beamed_tensor.index_select( paddle.to_tensor(curr_active_inst_idx), axis=0) beamed_tensor = beamed_tensor.reshape([*new_shape]) @@ -486,7 +485,7 @@ def __init__(self, attention_dropout_rate=0.0, residual_dropout_rate=0.1): super(TransformerEncoderLayer, self).__init__() - self.self_attn = MultiheadAttentionOptim( + self.self_attn = MultiheadAttention( d_model, nhead, dropout=attention_dropout_rate) self.conv1 = Conv2D( @@ -555,9 +554,9 @@ def __init__(self, attention_dropout_rate=0.0, residual_dropout_rate=0.1): super(TransformerDecoderLayer, self).__init__() - self.self_attn = MultiheadAttentionOptim( + self.self_attn = MultiheadAttention( d_model, nhead, dropout=attention_dropout_rate) - self.multihead_attn = MultiheadAttentionOptim( + self.multihead_attn = MultiheadAttention( d_model, nhead, dropout=attention_dropout_rate) self.conv1 = Conv2D(