You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
126 lines
4.5 KiB
126 lines
4.5 KiB
|
6 months ago
|
import argparse
|
||
|
|
import os
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
import onnxruntime
|
||
|
|
from paddleocr.ppocr.data.imaug.operators import (E2EResizeForTest, KeepKeys,
|
||
|
|
NormalizeImage, ToCHWImage)
|
||
|
|
from paddleocr.ppocr.postprocess.pg_postprocess import PGPostProcess
|
||
|
|
|
||
|
|
# from pgnet.chr_dct import chr_dct_list
|
||
|
|
|
||
|
|
|
||
|
|
class PGNetPredictor:
|
||
|
|
def __init__(self, img_path, cpu):
|
||
|
|
self.img_path = img_path
|
||
|
|
self.dict_path = "ic15_dict.txt"
|
||
|
|
# if not os.path.exists(self.dict_path):
|
||
|
|
# with open(self.dict_path, "w") as f:
|
||
|
|
# f.writelines(chr_dct_list)
|
||
|
|
if not cpu:
|
||
|
|
providers = ["CUDAExecutionProvider"]
|
||
|
|
else:
|
||
|
|
providers = ["CPUExecutionProvider"]
|
||
|
|
self.sess = onnxruntime.InferenceSession(
|
||
|
|
args.model_path, providers=providers)
|
||
|
|
|
||
|
|
def preprocess(self, img_path):
|
||
|
|
img = cv2.imread(img_path)
|
||
|
|
self.ori_im = img.copy()
|
||
|
|
data = {"image": img}
|
||
|
|
transforms = [
|
||
|
|
E2EResizeForTest(max_side_len=768, valid_set="totaltext"),
|
||
|
|
NormalizeImage(
|
||
|
|
scale=1.0 / 255.0,
|
||
|
|
mean=[0.485, 0.456, 0.406],
|
||
|
|
std=[0.229, 0.224, 0.225],
|
||
|
|
order="hwc",
|
||
|
|
),
|
||
|
|
ToCHWImage(),
|
||
|
|
KeepKeys(keep_keys=["image", "shape"]),
|
||
|
|
]
|
||
|
|
for transform in transforms:
|
||
|
|
data = transform(data)
|
||
|
|
img, shape_list = data
|
||
|
|
img = np.expand_dims(img, axis=0)
|
||
|
|
shape_list = np.expand_dims(shape_list, axis=0)
|
||
|
|
return img, shape_list
|
||
|
|
|
||
|
|
def predict(self, img):
|
||
|
|
ort_inputs = {self.sess.get_inputs()[0].name: img}
|
||
|
|
outputs = self.sess.run(None, ort_inputs)
|
||
|
|
preds = {}
|
||
|
|
preds["f_border"] = outputs[0]
|
||
|
|
preds["f_char"] = outputs[1]
|
||
|
|
preds["f_direction"] = outputs[2]
|
||
|
|
preds["f_score"] = outputs[3]
|
||
|
|
return preds
|
||
|
|
|
||
|
|
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||
|
|
img_height, img_width = image_shape[0:2]
|
||
|
|
dt_boxes_new = []
|
||
|
|
for box in dt_boxes:
|
||
|
|
box = self.clip_det_res(box, img_height, img_width)
|
||
|
|
dt_boxes_new.append(box)
|
||
|
|
dt_boxes = np.array(dt_boxes_new)
|
||
|
|
return dt_boxes
|
||
|
|
|
||
|
|
def clip_det_res(self, points, img_height, img_width):
|
||
|
|
for pno in range(points.shape[0]):
|
||
|
|
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||
|
|
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||
|
|
return points
|
||
|
|
|
||
|
|
def postprocess(self, preds, shape_list):
|
||
|
|
pgpostprocess = PGPostProcess(
|
||
|
|
character_dict_path=self.dict_path,
|
||
|
|
valid_set="totaltext",
|
||
|
|
score_thresh=0.5,
|
||
|
|
mode="fast",
|
||
|
|
)
|
||
|
|
post_result = pgpostprocess(preds, shape_list)
|
||
|
|
points, strs = post_result["points"], post_result["texts"]
|
||
|
|
dt_boxes = self.filter_tag_det_res_only_clip(points, self.ori_im.shape)
|
||
|
|
return dt_boxes, strs
|
||
|
|
|
||
|
|
def __call__(self):
|
||
|
|
img, shape_list = self.preprocess(self.img_path)
|
||
|
|
preds = self.predict(img)
|
||
|
|
dt_boxes, strs = self.postprocess(preds, shape_list)
|
||
|
|
return dt_boxes, strs
|
||
|
|
|
||
|
|
def draw(self, dt_boxes, strs, img_path):
|
||
|
|
src_im = cv2.imread(img_path)
|
||
|
|
width, height, _ = src_im.shape
|
||
|
|
for box, str in zip(dt_boxes, strs):
|
||
|
|
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||
|
|
cv2.polylines(src_im, [box], True, color=(
|
||
|
|
255, 255, 0), thickness=2)
|
||
|
|
cv2.putText(
|
||
|
|
src_im,
|
||
|
|
str,
|
||
|
|
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
|
||
|
|
fontFace=cv2.FONT_HERSHEY_COMPLEX,
|
||
|
|
fontScale=0.7 / 500 * width / 2,
|
||
|
|
color=(0, 255, 0),
|
||
|
|
thickness=int(1 / 1000 * width),
|
||
|
|
)
|
||
|
|
img_out_name = os.path.basename(img_path).split(".")[0]
|
||
|
|
img_out_name = f"{img_out_name}_pgnet.jpg"
|
||
|
|
cv2.imwrite(img_out_name, src_im)
|
||
|
|
return src_im
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
parser = argparse.ArgumentParser(description="PGPNET inference")
|
||
|
|
parser.add_argument("model_path", type=str, help="onnxmodel path")
|
||
|
|
parser.add_argument("img_path", type=str, help="image path")
|
||
|
|
parser.add_argument(
|
||
|
|
"--cpu", action="store_true", help="cpu inference, default device is gpu"
|
||
|
|
)
|
||
|
|
args = parser.parse_args()
|
||
|
|
pgnetpredictor = PGNetPredictor(args.img_path, args.cpu)
|
||
|
|
dt_boxes, strs = pgnetpredictor()
|
||
|
|
print(f"Predict string:{strs}")
|
||
|
|
pgnetpredictor.draw(dt_boxes, strs, args.img_path)
|