99import logging
1010import os
1111import sys
12+ from time import time
1213
14+ import cv2
15+ import numpy as np
1316import yaml
1417from addict import Dict
1518from PIL import Image
2225from mindspore import Tensor , get_context , set_auto_parallel_context , set_context
2326from mindspore .communication import get_group_size , get_rank , init
2427
28+ from deploy .py_infer .src .infer_args import str2bool # noqa
2529from mindocr .data import build_dataset
2630from mindocr .data .transforms import create_transforms , run_transforms
2731from mindocr .models import build_model
2832from mindocr .postprocess import build_postprocess
2933from mindocr .utils .visualize import draw_boxes , show_imgs
3034from tools .arg_parser import _merge_options , _parse_options
35+ from tools .infer .text .utils import get_image_paths
3136from tools .modelarts_adapter .modelarts import modelarts_setup
3237
3338__dir__ = os .path .dirname (os .path .abspath (__file__ ))
@@ -155,21 +160,7 @@ def predict_single_step(cfg, save_res=True):
155160 )
156161
157162 # 3.Build model
158- amp_level = cfg .system .get ("amp_level_infer" , "O0" )
159- if get_context ("device_target" ) == "GPU" and amp_level == "O3" :
160- logger .warning (
161- "Model evaluation does not support amp_level O3 on GPU currently. "
162- "The program has switched to amp_level O2 automatically."
163- )
164- amp_level = "O2"
165- cfg .model .backbone .pretrained = False
166- if cfg .predict .ckpt_load_path is None :
167- logger .warning (
168- f"No ckpt is available for { cfg .model .task } , "
169- "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
170- )
171- network = build_model (cfg .model , ckpt_load_path = cfg .predict .ckpt_load_path , amp_level = amp_level )
172- network .set_train (False )
163+ network = build_model_from_config (cfg )
173164
174165 # 4.Build postprocessor for network output
175166 postprocessor = build_postprocess (cfg .postprocess )
@@ -230,72 +221,220 @@ def predict_single_step(cfg, save_res=True):
230221 return preds_list
231222
232223
233- def predict_system (args , det_cfg , rec_cfg ):
234- """Run predict for both det and rec task"""
235- # merge image_dir option in model config
236- det_cfg .predict .dataset .dataset_root = ""
237- det_cfg .predict .dataset .data_dir = args .image_dir
238- output_save_dir = det_cfg .predict .output_save_dir or "./output"
239-
240- # get det result from predict
241- preds_list = predict_single_step (det_cfg , save_res = False )
242-
243- # set amp level
244- amp_level = det_cfg .system .get ("amp_level_infer" , "O0" )
224+ def build_model_from_config (cfg ):
225+ amp_level = cfg .system .get ("amp_level_infer" , "O0" )
245226 if get_context ("device_target" ) == "GPU" and amp_level == "O3" :
246227 logger .warning (
247228 "Model evaluation does not support amp_level O3 on GPU currently. "
248229 "The program has switched to amp_level O2 automatically."
249230 )
250231 amp_level = "O2"
251-
252- # create preprocess and postprocess for rec task
253- transforms = create_transforms (rec_cfg .predict .dataset .transform_pipeline )
254- postprocessor = build_postprocess (rec_cfg .postprocess )
255-
256- # build rec model from yaml
257- rec_cfg .model .backbone .pretrained = False
258- if rec_cfg .predict .ckpt_load_path is None :
232+ cfg .model .backbone .pretrained = False
233+ if cfg .predict .ckpt_load_path is None :
259234 logger .warning (
260- f"No ckpt is available for { rec_cfg .model .type } , "
235+ f"No ckpt is available for { cfg .model .task } , "
261236 "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
262237 )
263- rec_network = build_model (rec_cfg .model , ckpt_load_path = rec_cfg .predict .ckpt_load_path , amp_level = amp_level )
264-
265- # start rec task
266- logger .info ("Start rec" )
267- img_list = [] # list of img_path
268- boxes_all = [] # list of boxes of all image
269- text_scores_all = [] # list of text and scores of all image
270- for preds_batch in tqdm (preds_list ):
271- # preds_batch is a dictionary of det prediction output, which contains det information of a batch
272- preds_batch ["texts" ] = []
273- preds_batch ["confs" ] = []
274- for i , crops in enumerate (preds_batch ["crops" ]):
275- # A batch may contain multiple images
276- img_path = preds_batch ["img_path" ][i ]
277- img_box = []
278- img_text_scores = []
279- for j , crop in enumerate (crops ):
280- # For each image, it may contain several crops
281- data = {"image" : crop }
282- data ["image_ori" ] = crop .copy ()
283- data ["image_shape" ] = crop .shape
284- data = run_transforms (data , transforms [1 :])
285- data = rec_network (Tensor (data ["image" ]).expand_dims (0 ))
286- out = postprocessor (data )
287- confs = out ["confs" ][0 ]
288- if confs > 0.5 :
289- # Keep text with a confidence greater than 0.5
290- box = preds_batch ["polys" ][i ][j ]
291- text = out ["texts" ][0 ]
292- img_box .append (box )
293- img_text_scores .append ((text , confs ))
294- # Each image saves its path, box and texts_scores
295- img_list .append (img_path )
296- boxes_all .append (img_box )
297- text_scores_all .append (img_text_scores )
298- save_res (boxes_all , text_scores_all , img_list , save_path = os .path .join (output_save_dir , "system_results.txt" ))
238+ network = build_model (cfg .model , ckpt_load_path = cfg .predict .ckpt_load_path , amp_level = amp_level )
239+ network .set_train (False )
240+ return network
241+
242+
243+ def sort_polys (polys ):
244+ return sorted (polys , key = lambda points : (points [0 ][1 ], points [0 ][0 ]))
245+
246+
247+ def concat_crops (crops : list ):
248+ max_height = max (crop .shape [0 ] for crop in crops )
249+ resized_crops = []
250+ for crop in crops :
251+ h , w , c = crop .shape
252+ new_h = max_height
253+ new_w = int ((w / h ) * new_h )
254+
255+ resized_img = cv2 .resize (crop , (new_w , new_h ), interpolation = cv2 .INTER_LINEAR )
256+ resized_crops .append (resized_img )
257+ crops = np .concatenate (resized_crops , axis = 1 )
258+ return crops
259+
260+
261+ class Predict_System :
262+ def __init__ (self , det_cfg , rec_cfg , is_concat = False ):
263+ for transform in det_cfg .predict .dataset .transform_pipeline :
264+ if "DecodeImage" in transform :
265+ transform ["DecodeImage" ].update ({"keep_ori" : True })
266+ break
267+ self .det_transforms = create_transforms (det_cfg .predict .dataset .transform_pipeline )
268+ self .det_model = build_model_from_config (det_cfg )
269+ self .det_postprocess = build_postprocess (det_cfg .postprocess )
270+
271+ self .rec_batch_size = rec_cfg .predict .loader .batch_size
272+ self .rec_preprocess = create_transforms (rec_cfg .predict .dataset .transform_pipeline )
273+ self .rec_model = build_model_from_config (rec_cfg )
274+ self .rec_postprocess = build_postprocess (rec_cfg .postprocess )
275+
276+ self .is_concat = is_concat
277+
278+ def predict_rec (self , crops : list ):
279+ """
280+ Run text recognition serially for input images
281+
282+ Args:
283+ img_or_path_list: list of str for img path or np.array for RGB image
284+ do_visualize: visualize preprocess and final result and save them
285+
286+ Return:
287+ rec_res: list of tuple, where each tuple is (text, score) - text recognition result for each input image
288+ in order.
289+ where text is the predicted text string, score is its confidence score.
290+ e.g. [('apple', 0.9), ('bike', 1.0)]
291+ """
292+ rec_res = []
293+ num_crops = len (crops )
294+
295+ for idx in range (0 , num_crops , self .rec_batch_size ): # batch begin index i
296+ batch_begin = idx
297+ batch_end = min (idx + self .rec_batch_size , num_crops )
298+ logger .info (f"Rec img idx range: [{ batch_begin } , { batch_end } )" )
299+ # TODO: set max_wh_ratio to the maximum wh ratio of images in the batch. and update it for resize,
300+ # which may improve recognition accuracy in batch-mode
301+ # especially for long text image. max_wh_ratio=max(max_wh_ratio, img_w / img_h).
302+ # The short ones should be scaled with a.r. unchanged and padded to max width in batch.
303+
304+ # preprocess
305+ # TODO: run in parallel with multiprocessing
306+ img_batch = []
307+ for j in range (batch_begin , batch_end ): # image index j
308+ data = run_transforms ({"image" : crops [j ]}, self .rec_preprocess [1 :])
309+ img_batch .append (data ["image" ])
310+
311+ img_batch = np .stack (img_batch ) if len (img_batch ) > 1 else np .expand_dims (img_batch [0 ], axis = 0 )
312+
313+ # infer
314+ net_pred = self .rec_model (Tensor (img_batch ))
315+
316+ # postprocess
317+ batch_res = self .rec_postprocess (net_pred )
318+ rec_res .extend (list (zip (batch_res ["texts" ], batch_res ["confs" ])))
319+
320+ return rec_res
321+
322+ def predict (self , img_path ):
323+ """
324+ Detect and recognize texts in an image
325+
326+ Args:
327+ img_or_path (str or np.ndarray): path to image or image rgb values as a numpy array
328+
329+ Return:
330+ boxes (list): detected text boxes, in shape [num_boxes, num_points, 2], where the point coordinate (x, y)
331+ follows: x - horizontal (image width direction), y - vertical (image height)
332+ texts (list[tuple]): list of (text, score) where text is the recognized text string for each box,
333+ and score is the confidence score.
334+ time_profile (dict): record the time cost for each sub-task.
335+ """
336+
337+ time_profile = {}
338+ start = time ()
339+
340+ # detect text regions on an image
341+ data = {"img_path" : img_path }
342+ data = run_transforms (data , self .det_transforms )
343+ input_np = np .expand_dims (data ["image" ], axis = 0 )
344+ logits = self .det_model (Tensor (input_np ))
345+ pred = self .det_postprocess (logits , shape_list = np .expand_dims (data ["shape_list" ], axis = 0 ))
346+ polys = pred ["polys" ][0 ]
347+ scores = pred ["scores" ][0 ]
348+ pred = dict (polys = polys , scores = scores )
349+ det_res = validate_det_res (pred , data ["image_ori" ].shape [:2 ], min_poly_points = 3 , min_area = 3 )
350+ det_res ["img_ori" ] = data ["image_ori" ]
351+
352+ time_profile ["det" ] = time () - start
353+ polys = det_res ["polys" ].copy ()
354+ if len (polys ) == 0 :
355+ logger .warning (f"No text detected in { img_path } " )
356+ time_profile ["rec" ] = 0.0
357+ time_profile ["all" ] = time_profile ["det" ]
358+ return [], [], time_profile
359+ polys = sort_polys (polys )
360+ logger .info (f"Num detected text boxes: { len (polys )} \n Det time: { time_profile ['det' ]} " )
361+ if self .is_concat :
362+ logger .info ("After concatenating, 1 croped image will be recognized." )
363+
364+ # crop text regions
365+ crops = []
366+ for i in range (len (polys )):
367+ poly = polys [i ].astype (np .float32 )
368+ cropped_img = crop_text_region (data ["image_ori" ], poly , box_type = det_cfg .postprocess .box_type )
369+ crops .append (cropped_img )
370+
371+ # if self.save_crop_res:
372+ # cv2.imwrite(os.path.join(self.crop_res_save_dir, f"{fn}_crop_{i}.jpg"), cropped_img)
373+ # show_imgs(crops, is_bgr_img=False)
374+
375+ # recognize cropped images
376+ rs = time ()
377+ if self .is_concat :
378+ crops = [concat_crops (crops )]
379+ rec_res_all_crops = self .predict_rec (crops )
380+ time_profile ["rec" ] = time () - rs
381+
382+ logger .info (
383+ "Recognized texts: \n "
384+ + "\n " .join ([f"{ text } \t { score } " for text , score in rec_res_all_crops ])
385+ + f"\n Rec time: { time_profile ['rec' ]} "
386+ )
387+
388+ # filter out low-score texts and merge detection and recognition results
389+ boxes , text_scores = [], []
390+ for i in range (len (polys )):
391+ box = det_res ["polys" ][i ]
392+ if self .is_concat :
393+ text = rec_res_all_crops [0 ][0 ]
394+ text_score = rec_res_all_crops [0 ][1 ]
395+ else :
396+ text = rec_res_all_crops [i ][0 ]
397+ text_score = rec_res_all_crops [i ][1 ]
398+
399+ if text_score >= 0.5 :
400+ boxes .append (box )
401+ text_scores .append ((text , text_score ))
402+ time_profile ["all" ] = time () - start
403+ return boxes , text_scores , time_profile
404+
405+
406+ def predict_both_step (args , det_cfg , rec_cfg ):
407+ # parse args
408+ set_logger (name = "mindocr" )
409+ pred_sys = Predict_System (det_cfg = det_cfg , rec_cfg = rec_cfg , is_concat = args .is_concat )
410+ output_save_dir = det_cfg .predict .output_save_dir or "./output"
411+ img_paths = get_image_paths (args .image_dir )
412+
413+ set_context (mode = det_cfg .system .mode )
414+
415+ tot_time = {} # {'det': 0, 'rec': 0, 'all': 0}
416+ boxes_all , text_scores_all = [], []
417+ for i , img_path in enumerate (img_paths ):
418+ logger .info (f"Infering [{ i + 1 } /{ len (img_paths )} ]: { img_path } " )
419+ boxes , text_scores , time_prof = pred_sys .predict (img_path )
420+ boxes_all .append (boxes )
421+ text_scores_all .append (text_scores )
422+
423+ for k in time_prof :
424+ if k not in tot_time :
425+ tot_time [k ] = time_prof [k ]
426+ else :
427+ tot_time [k ] += time_prof [k ]
428+
429+ fps = len (img_paths ) / tot_time ["all" ]
430+ logger .info (f"Total time:{ tot_time ['all' ]} " )
431+ logger .info (f"Average FPS: { fps } " )
432+ avg_time = {k : tot_time [k ] / len (img_paths ) for k in tot_time }
433+ logger .info (f"Averge time cost: { avg_time } " )
434+
435+ # save result
436+ save_res (boxes_all , text_scores_all , img_paths , save_path = os .path .join (output_save_dir , "system_results.txt" ))
437+ logger .info (f"Done! Results saved in { os .path .join (output_save_dir , 'system_results.txt' )} " )
299438
300439
301440def create_parser ():
@@ -314,6 +453,7 @@ def create_parser():
314453 default = "configs/rec/crnn/crnn_resnet34.yaml" ,
315454 help = 'YAML config file specifying default arguments for rec (default="configs/rec/crnn/crnn_resnet34.yaml")' ,
316455 )
456+ parser .add_argument ("--is_concat" , type = str2bool , default = False , help = "image path or image directory" )
317457 parser .add_argument (
318458 "-o" ,
319459 "--opt" ,
@@ -323,7 +463,9 @@ def create_parser():
323463 )
324464 # modelarts
325465 group = parser .add_argument_group ("modelarts" )
326- group .add_argument ("--enable_modelarts" , type = bool , default = False , help = "Run on modelarts platform (default=False)" )
466+ group .add_argument (
467+ "--enable_modelarts" , type = str2bool , default = False , help = "Run on modelarts platform (default=False)"
468+ )
327469 group .add_argument (
328470 "--device_target" ,
329471 type = str ,
@@ -337,8 +479,6 @@ def create_parser():
337479 group .add_argument ("--pretrain_url" , type = str , default = "" , help = "pre_train_model paths in obs" )
338480 group .add_argument ("--train_url" , type = str , default = "" , help = "model folder to save/load" )
339481
340- # args = parser.parse_args()
341-
342482 return parser
343483
344484
@@ -378,4 +518,4 @@ def parse_args_and_config():
378518 elif args .task_mode == "system" :
379519 rec_cfg = Dict (rec_cfg )
380520 det_cfg = Dict (det_cfg )
381- predict_system (args , det_cfg , rec_cfg )
521+ predict_both_step (args , det_cfg , rec_cfg )
0 commit comments