Skip to content

Commit

Permalink
fix nrtr export inference model
Browse files Browse the repository at this point in the history
  • Loading branch information
Topdu committed Oct 13, 2021
1 parent 6c6f19d commit 6594bc2
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 166 deletions.
6 changes: 3 additions & 3 deletions configs/rec/rec_mtb_nrtr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Architecture:
name: Transformer
d_model: 512
num_encoder_layers: 6
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.
beam_size: -1 # When Beam size is greater than 0, it means to use beam search when evaluation.


Loss:
Expand All @@ -65,7 +65,7 @@ Train:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- NRTRDecodeImage: # load image
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
Expand All @@ -85,7 +85,7 @@ Eval:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- NRTRDecodeImage: # load image
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
Expand Down
7 changes: 6 additions & 1 deletion ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,21 +174,26 @@ def __init__(self,
super(NRTRLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)

def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
data['length'] = np.array(len(text))
text.insert(0, 2)
text.append(3)
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data

def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character


class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

Expand Down
23 changes: 22 additions & 1 deletion ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,33 @@ def __call__(self, data):


class NRTRRecResizeImg(object):
def __init__(self, image_shape, resize_type, **kwargs):
def __init__(self, image_shape, resize_type, padding=False, **kwargs):
self.image_shape = image_shape
self.resize_type = resize_type
self.padding = padding

def __call__(self, data):
img = data['image']
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
image_shape = self.image_shape
if self.padding:
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
norm_img = np.expand_dims(resized_image, -1)
norm_img = norm_img.transpose((2, 0, 1))
resized_image = norm_img.astype(np.float32) / 128. - 1.
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
data['image'] = padding_im
return data
if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
Expand Down
1 change: 0 additions & 1 deletion ppocr/data/simple_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import os
import random
from paddle.io import Dataset

from .imaug import transform, create_operators


Expand Down
8 changes: 5 additions & 3 deletions ppocr/modeling/backbones/rec_nrtr_mtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from paddle import nn
import paddle


class MTB(nn.Layer):
Expand Down Expand Up @@ -40,7 +41,8 @@ def forward(self, images):
x = self.block(images)
if self.cnn_num == 2:
# (b, w, h, c)
x = x.transpose([0, 3, 2, 1])
x_shape = x.shape
x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
x = paddle.transpose(x, [0, 3, 2, 1])
x_shape = paddle.shape(x)
x = paddle.reshape(
x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
return x
89 changes: 37 additions & 52 deletions ppocr/modeling/heads/multiheadAttention.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def forward(self,
value,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None):
"""
Inputs of forward function
Expand All @@ -88,46 +86,42 @@ def forward(self,
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
assert list(query.shape) == [tgt_len, bsz, embed_dim]
assert key.shape == value.shape

q_shape = paddle.shape(query)
src_shape = paddle.shape(key)
q = self._in_proj_q(query)
k = self._in_proj_k(key)
v = self._in_proj_v(value)
q *= self.scaling

q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])

src_len = k.shape[1]

q = paddle.transpose(
paddle.reshape(
q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
k = paddle.transpose(
paddle.reshape(
k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
v = paddle.transpose(
paddle.reshape(
v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
if key_padding_mask is not None:
assert key_padding_mask.shape[0] == bsz
assert key_padding_mask.shape[1] == src_len

attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
assert list(attn_output_weights.
shape) == [bsz * self.num_heads, tgt_len, src_len]

assert key_padding_mask.shape[0] == q_shape[1]
assert key_padding_mask.shape[1] == src_shape[0]
attn_output_weights = paddle.matmul(q,
paddle.transpose(k, [0, 1, 3, 2]))
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
attn_output_weights = paddle.reshape(
attn_output_weights,
[q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
key = paddle.cast(key, 'float32')
y = paddle.full(
shape=paddle.shape(key), dtype='float32', fill_value='-inf')
y = paddle.where(key == 0., key, y)
attn_output_weights += y
attn_output_weights = attn_output_weights.reshape(
[bsz * self.num_heads, tgt_len, src_len])

attn_output_weights = F.softmax(
attn_output_weights.astype('float32'),
axis=-1,
Expand All @@ -136,43 +130,34 @@ def forward(self,
attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training)

attn_output = paddle.bmm(attn_output_weights, v)
assert list(attn_output.
shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn_output = attn_output.transpose([1, 0, 2]).reshape(
[tgt_len, bsz, embed_dim])
attn_output = paddle.matmul(attn_output_weights, v)
attn_output = paddle.reshape(
paddle.transpose(attn_output, [2, 0, 1, 3]),
[q_shape[0], q_shape[1], self.embed_dim])
attn_output = self.out_proj(attn_output)

if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(
axis=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
return attn_output

def _in_proj_q(self, query):
query = query.transpose([1, 2, 0])
query = paddle.transpose(query, [1, 2, 0])
query = paddle.unsqueeze(query, axis=2)
res = self.conv1(query)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
res = paddle.transpose(res, [2, 0, 1])
return res

def _in_proj_k(self, key):
key = key.transpose([1, 2, 0])
key = paddle.transpose(key, [1, 2, 0])
key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
res = paddle.transpose(res, [2, 0, 1])
return res

def _in_proj_v(self, value):
value = value.transpose([1, 2, 0]) #(1, 2, 0)
value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
value = paddle.unsqueeze(value, axis=2)
res = self.conv3(value)
res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1])
res = paddle.transpose(res, [2, 0, 1])
return res
Loading

0 comments on commit 6594bc2

Please sign in to comment.