Skip to content

Commit

Permalink
fix config add e2e_ch.md e2e_res_img623_pg
Browse files Browse the repository at this point in the history
  • Loading branch information
JetHong committed Mar 10, 2021
1 parent 8452816 commit 051fe64
Show file tree
Hide file tree
Showing 20 changed files with 243 additions and 479 deletions.
61 changes: 26 additions & 35 deletions configs/e2e/e2e_r50_vd_pg.yml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
Global:
use_gpu: False
use_gpu: True
epoch_num: 600
log_smooth_window: 20
print_batch_step: 2
print_batch_step: 10
save_model_dir: ./output/pg_r50_vd_tt/
save_epoch_step: 1
# evaluation is run every 5000 iterationss after the 4000th iteration
save_epoch_step: 10
# evaluation is run every 0 iterationss after the 1000th iteration
eval_batch_step: [ 0, 1000 ]
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: False
# 1. If pretrained_model is saved in static mode, such as classification pretrained model
# from static branch, load_static_weights must be set as True.
# 2. If you want to finetune the pretrained models we provide in the docs,
# you should set load_static_weights as False.
load_static_weights: True
cal_metric_during_train: False
pretrained_model:
checkpoints:
Expand All @@ -19,7 +22,7 @@ Global:

Architecture:
model_type: e2e
algorithm: PG
algorithm: PGNet
Transform:
Backbone:
name: ResNet
Expand All @@ -34,28 +37,16 @@ Architecture:
Loss:
name: PGLoss

#Optimizer:
# name: Adam
# beta1: 0.9
# beta2: 0.999
# lr:
# name: Cosine
# learning_rate: 0.001
# warmup_epoch: 1
# regularizer:
# name: 'L2'
# factor: 0

Optimizer:
name: RMSProp
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
learning_rate: 0.001
decay_epochs: [ 40, 80, 120, 160, 200 ]
values: [ 0.001, 0.00033, 0.0001, 0.000033, 0.00001 ]
regularizer:
name: 'L2'
factor: 0.00005
factor: 0


PostProcess:
name: PGPostProcess
Expand All @@ -65,45 +56,45 @@ PostProcess:

Metric:
name: E2EMetric
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
main_indicator: f_score_e2e

Train:
dataset:
name: PGDateSet
label_file_list:
ratio_list:
data_format: textnet # textnet/partvgg
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
label_file_list: [./train_data/total_text/train/]
ratio_list: [1.0]
data_format: icdar
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- PGProcessTrain:
batch_size: 14
data_format: icdar
tcl_len: 64
min_crop_size: 24
min_text_size: 4
max_text_size: 512
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
shuffle: True
drop_last: True
batch_size_per_card: 1
num_workers: 8
batch_size_per_card: 14
num_workers: 16

Eval:
dataset:
name: PGDateSet
name: PGDataSet
data_dir: ./train_data/
label_file_list:
label_file_list: [./train_data/total_text/test/]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- E2ELabelEncode:
label_list: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
max_len: 50
- E2EResizeForTest:
valid_set: totaltext
max_side_len: 768
Expand Down
120 changes: 120 additions & 0 deletions doc/doc_ch/e2e.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# 端到端文字识别

本节以partvgg/totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。

## 数据准备
支持两种不同的数据形式textnet / icdar ,分别为四点标注数据和十四点标注数据,十四点标注数据效果要比四点标注效果好
###数据形式为textnet

解压数据集和下载标注文件后,PaddleOCR/train_data/part_vgg_synth/train/ 有一个文件夹和一个文件,分别是:
```
/PaddleOCR/train_data/part_vgg_synth/train/
└─ image/ partvgg数据集的训练数据
└─ train_annotation_info.txt partvgg数据集的测试标注
```

提供的标注文件格式如下,中间用"\t"分隔:
```
" 图像文件名 图像标注信息--四点标注 图像标注信息--识别标注
119_nile_110_31 140.2 222.5 266.0 194.6 278.7 251.8 152.9 279.7 Path: 32.9 133.1 106.0 130.8 106.4 143.8 33.3 146.1 were 21.8 81.9 106.9 80.4 107.7 123.2 22.6 124.7 why
```
标注文件txt当中,其中每一行代表一组数据,以第一行为例。第一个代表同级目录image/下面的文件名, 后面每9个代表一组标注信息,前8个代表文本框的四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**


###数据形式为icdar
解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是:
```
/PaddleOCR/train_data/total_text/train/
└─ rgb/ total_text数据集的训练数据
└─ poly/ total_text数据集的测试标注
```

提供的标注文件格式如下,中间用"\t"分隔:
```
" 图像标注信息--十四点标注数据 图像标注信息--识别标注
1004.0,689.0,1019.0,698.0,1034.0,708.0,1049.0,718.0,1064.0,728.0,1079.0,738.0,1095.0,748.0,1094.0,774.0,1079.0,765.0,1065.0,756.0,1050.0,747.0,1036.0,738.0,1021.0,729.0,1007.0,721.0 EST
1102.0,755.0,1116.0,764.0,1131.0,773.0,1146.0,783.0,1161.0,792.0,1176.0,801.0,1191.0,811.0,1193.0,837.0,1178.0,828.0,1164.0,819.0,1150.0,810.0,1135.0,801.0,1121.0,792.0,1107.0,784.0 1972
```
标注文件当中,其中每一个txt文件代表一组数据,文件名同级目录rgb/下面的文件名。以第一行为例,前面28个代表文本框的十四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。

## 快速启动训练

首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet_vd系列,
您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。
```shell
cd PaddleOCR/
下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar

# 解压预训练模型文件,以ResNet50_vd为例
tar -xf ./pretrain_models/ResNet50_vd_ssld_pretrained.tar ./pretrain_models/

# 注:正确解压backbone预训练权重文件后,文件夹下包含众多以网络层命名的权重文件,格式如下:
./pretrain_models/ResNet50_vd_ssld_pretrained/
└─ conv_last_bn_mean
└─ conv_last_bn_offset
└─ conv_last_bn_scale
└─ conv_last_bn_variance
└─ ......

```

#### 启动训练

*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*

```shell
# 单机单卡训练 e2e 模型
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \
-o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \
-o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True
```


上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)

您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
```

#### 断点训练

如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
```

**注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。

## 指标评估

PaddleOCR计算三个OCR端到端相关的指标,分别是:Precision、Recall、Hmean。

运行如下代码,根据配置文件`e2e_r50_vd_pg.yml``save_res_path`指定的测试集检测结果文件,计算评估指标。

评估时设置后处理参数`max_side_len=768`,使用不同数据集、不同模型训练,可调整参数进行优化
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
```shell
python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
```



## 测试端到端效果

测试单张图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```

测试文件夹下所有图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
Binary file added doc/imgs_results/e2e_res_img623_pg.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 2 additions & 3 deletions ppocr/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDateSet
from ppocr.data.pgnet_dataset import PGDataSet

__all__ = ['build_dataloader', 'transform', 'create_operators']

Expand All @@ -55,8 +55,7 @@ def term_mp(sig_num, frame):
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)


support_dict = ['SimpleDataSet', 'LMDBDateSet', 'PGDateSet']
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
Expand Down
14 changes: 8 additions & 6 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def __call__(self, data):


class E2ELabelEncode(object):
def __init__(self, label_list, **kwargs):
self.label_list = label_list
self.max_len = 50
def __init__(self, Lexicon_Table, max_len, **kwargs):
self.Lexicon_Table = Lexicon_Table
self.max_len = max_len
self.pad_num = len(self.Lexicon_Table)

def __call__(self, data):
text_label_index_list, temp_text = [], []
Expand All @@ -46,9 +47,10 @@ def __call__(self, data):
text = text.upper()
temp_text = []
for c_ in text:
if c_ in self.label_list:
temp_text.append(self.label_list.index(c_))
temp_text = temp_text + [36] * (self.max_len - len(temp_text))
if c_ in self.Lexicon_Table:
temp_text.append(self.Lexicon_Table.index(c_))
temp_text = temp_text + [self.pad_num] * (self.max_len -
len(temp_text))
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data
Expand Down
6 changes: 1 addition & 5 deletions ppocr/data/imaug/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def resize_image_type0(self, img):
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
# return img, np.array([h, w])
return img, [ratio_h, ratio_w]

def resize_image_type2(self, img):
Expand All @@ -206,7 +205,6 @@ def resize_image_type2(self, img):
resize_w = w
resize_h = h

# Fix the longer side
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
Expand Down Expand Up @@ -245,10 +243,8 @@ def __call__(self, data):
return data

def resize_image_for_totaltext(self, im, max_side_len=512):
"""
"""
h, w, _ = im.shape

h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
Expand Down
Loading

0 comments on commit 051fe64

Please sign in to comment.