Skip to content

Commit

Permalink
add srn for dygraph
Browse files Browse the repository at this point in the history
  • Loading branch information
tink2123 committed Dec 30, 2020
1 parent de3e2e7 commit c1fd466
Show file tree
Hide file tree
Showing 28 changed files with 1,594 additions and 70 deletions.
6 changes: 3 additions & 3 deletions configs/rec/rec_mv3_none_bilstm_ctc.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Global:
use_gpu: true
use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
Expand Down Expand Up @@ -59,7 +59,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -78,7 +78,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_mv3_none_none_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -77,7 +77,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_mv3_tps_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -82,7 +82,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_r34_vd_none_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -77,7 +77,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_r34_vd_none_none_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -75,7 +75,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Metric:

Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
Expand All @@ -81,7 +81,7 @@ Train:

Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
Expand Down
106 changes: 106 additions & 0 deletions configs/rec/rec_r50_fpn_srn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
Global:
use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/rec/srn
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 5000]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
num_heads: 8
infer_mode: False
use_space_char: False


Optimizer:
name: Adam
lr:
name: Cosine
learning_rate: 0.0001

Architecture:
model_type: rec
algorithm: SRN
in_channels: 1
Transform:
Backbone:
name: ResNetFPN
Head:
name: SRNHead
max_text_length: 25
num_heads: 8
num_encoder_TUs: 2
num_decoder_TUs: 4
hidden_dims: 512

Loss:
name: SRNLoss

PostProcess:
name: SRNLabelDecode

Metric:
name: RecMetric
main_indicator: acc

Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/srn_train_data_duiqi
#label_file_list: ["./train_data/ic15_data/1.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2'] # dataloader will return list in this order
loader:
shuffle: False
batch_size_per_card: 64
drop_last: True
num_workers: 4

Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 32
num_workers: 4
4 changes: 2 additions & 2 deletions ppocr/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet
from ppocr.data.lmdb_dataset import LMDBDataSet

__all__ = ['build_dataloader', 'transform', 'create_operators']

Expand All @@ -54,7 +54,7 @@ def term_mp(sig_num, frame):
def build_dataloader(config, mode, device, logger):
config = copy.deepcopy(config)

support_dict = ['SimpleDataSet', 'LMDBDateSet']
support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
Expand Down
2 changes: 1 addition & 1 deletion ppocr/data/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment
from .operators import *
from .label_ops import *
Expand Down
48 changes: 48 additions & 0 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def __init__(self,
support_character_type, character_type)

self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
Expand Down Expand Up @@ -213,3 +215,49 @@ def get_beg_end_flag_idx(self, beg_or_end):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx


class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """

def __init__(self,
max_text_length=25,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)

def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character

def __call__(self, data):
text = data['label']
text = self.encode(text)
char_num = len(self.character_str)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [char_num] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data

def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]

def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
Loading

0 comments on commit c1fd466

Please sign in to comment.