You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
4.1 KiB

import argparse
import os
import cv2
import numpy as np
import onnxruntime
from paddleocr.ppocr.data.imaug.operators import (
E2EResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
)
# from ppocr.data.imaug.operators import (
# E2EResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
# )
from paddleocr.ppocr.postprocess.pg_postprocess import PGPostProcess
# from ppocr.postprocess.pg_postprocess import PGPostProcess
from pgnet.chr_dct import chr_dct_list
class PGNetPredictor:
    """PGNet end-to-end text spotter backed by an ONNX Runtime session.

    Bundles the ONNX inference session with the PaddleOCR pre-processing
    operators and ``PGPostProcess`` so that :meth:`infer` maps one image
    array to ``(polygons, texts)``.
    """

    def __init__(self, model_path, cpu=False):
        """Build the inference session and the pre/post-processing pipeline.

        Args:
            model_path: Path of the ONNX model given to ``InferenceSession``.
            cpu: When true, only ``CPUExecutionProvider`` is requested.
                Otherwise CUDA is requested first with the CPU provider
                listed as an explicit fallback.
        """
        self.model_path = model_path
        self.dict_path = "ic15_dict.txt"

        # PGPostProcess is configured with a dictionary *file path*, so the
        # bundled character list is written to the working directory the
        # first time it is needed. An existing file is reused as-is.
        if not os.path.exists(self.dict_path):
            # NOTE(review): writelines() adds no separators; this assumes
            # each entry of chr_dct_list already ends with a newline.
            with open(self.dict_path, "w", encoding="utf-8") as f:
                f.writelines(chr_dct_list)

        # Providers are tried in priority order; keeping the CPU provider
        # last lets the session still load when CUDA is unavailable.
        if cpu:
            providers = ["CPUExecutionProvider"]
        else:
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        self.sess = onnxruntime.InferenceSession(model_path, providers=providers)

        self.transforms = [
            E2EResizeForTest(max_side_len=768, valid_set="totaltext"),
            NormalizeImage(
                scale=1 / 255.0,
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                order="hwc",
            ),
            ToCHWImage(),
            KeepKeys(keep_keys=["image", "shape"]),
        ]
        self.pgpostprocess = PGPostProcess(
            character_dict_path=self.dict_path,
            valid_set="totaltext",
            score_thresh=0.5,
            mode="fast",
        )

    def preprocess(self, img):
        """Run the transform pipeline on *img* and add a leading batch axis.

        A copy of the input is kept in ``self.ori_im``; :meth:`postprocess`
        later reads its shape to clip the predicted polygons.

        Returns:
            ``(image_batch, shape_batch)``: the two values kept by the last
            transform, each expanded with ``np.expand_dims(..., axis=0)``.
        """
        self.ori_im = img.copy()
        data = {"image": img}
        for transform in self.transforms:
            data = transform(data)
        img, shape_list = data
        return np.expand_dims(img, axis=0), np.expand_dims(shape_list, axis=0)

    def predict(self, img):
        """Feed *img* to the session's first input and name its outputs.

        Returns:
            Dict with keys ``f_border``, ``f_char``, ``f_direction`` and
            ``f_score`` bound to session outputs 0-3 in that order.
            NOTE(review): the order is assumed to match the exported
            model's output order - confirm against the export.
        """
        ort_inputs = {self.sess.get_inputs()[0].name: img}
        outputs = self.sess.run(None, ort_inputs)
        return {
            "f_border": outputs[0],
            "f_char": outputs[1],
            "f_direction": outputs[2],
            "f_score": outputs[3],
        }

    def clip_boxes(self, boxes, shape):
        """Clamp polygon coordinates to the image bounds.

        Args:
            boxes: Iterable of ``(n_points, 2)`` arrays in ``(x, y)`` column
                order. Polygons may have different point counts.
            shape: Image shape; only ``shape[:2]`` (height, width) is used.

        Returns:
            ``np.ndarray`` of clipped polygons. When every polygon has the
            same shape (or there are none) this is the regular stacked
            array. Otherwise it is a 1-D ``object`` array holding one
            ``(n_points, 2)`` array per polygon. The inputs are not
            modified.
        """
        h, w = shape[:2]
        clipped = []
        for box in boxes:
            pts = np.array(box)  # copy, so the caller's polygon stays intact
            pts[:, 0] = np.clip(pts[:, 0], 0, w - 1)
            pts[:, 1] = np.clip(pts[:, 1], 0, h - 1)
            clipped.append(pts)

        if len({pts.shape for pts in clipped}) <= 1:
            return np.array(clipped)

        # Polygons with differing point counts cannot be stacked; recent
        # NumPy raises ValueError for np.array() on such ragged input, so
        # build the object array explicitly.
        ragged = np.empty(len(clipped), dtype=object)
        for idx, pts in enumerate(clipped):
            ragged[idx] = pts
        return ragged

    def postprocess(self, preds, shape_list):
        """Decode raw predictions into clipped polygons and their texts.

        Requires :meth:`preprocess` to have run first, because the clip
        bounds come from ``self.ori_im``.

        Returns:
            ``(polygons, texts)`` taken from the ``"points"`` and
            ``"texts"`` entries of the ``PGPostProcess`` result.
        """
        result = self.pgpostprocess(preds, shape_list)
        pts, texts = result["points"], result["texts"]
        return self.clip_boxes(pts, self.ori_im.shape), texts

    def infer(self, img):
        """Run preprocess -> predict -> postprocess on a single image."""
        img_input, shape = self.preprocess(img)
        preds = self.predict(img_input)
        return self.postprocess(preds, shape)
def draw_results(frame, boxes, texts):
    """Draw every detected polygon and its recognized text onto *frame*.

    Args:
        frame: Image array that is drawn on in place.
        boxes: Iterable of polygons, each reshapeable to ``(-1, 1, 2)``.
        texts: Recognized strings, paired with ``boxes`` by position.

    Returns:
        The same ``frame`` object, after drawing.
    """
    for box, text in zip(boxes, texts):
        # OpenCV drawing functions need 32-bit integer points; a plain
        # ``int`` cast yields int64 on most platforms and is rejected.
        polygon = np.asarray(box).astype(np.int32).reshape(-1, 1, 2)
        if polygon.shape[0] == 0:
            continue  # nothing to outline and no anchor point for the label
        cv2.polylines(frame, [polygon], True, (255, 255, 0), 2)
        # The label is anchored at the polygon's first vertex; pass plain
        # Python ints rather than NumPy scalars.
        origin = tuple(int(v) for v in polygon[0, 0])
        cv2.putText(
            frame,
            text,
            origin,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (0, 255, 0),
            2,
        )
    return frame
def main():
    """Run PGNet OCR over a video and write an annotated copy next to it.

    Command-line arguments: ``--model`` (ONNX model path), ``--video``
    (input video path) and the optional ``--cpu`` flag. The output file is
    ``<input stem>_pgnet_output.mp4`` in the input video's directory.

    Raises:
        SystemExit: If the input video cannot be opened or the output
            writer cannot be created.
    """
    parser = argparse.ArgumentParser(description="PGNet Video OCR")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--video", type=str, required=True)
    parser.add_argument("--cpu", action="store_true")
    args = parser.parse_args()

    predictor = PGNetPredictor(args.model, args.cpu)

    cap = cv2.VideoCapture(args.video)
    writer = None
    try:
        if not cap.isOpened():
            raise SystemExit(f"Cannot open video: {args.video}")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) for FPS; a writer created with
        # that rate does not produce a playable file, so use a default.
        if not fps > 0:
            fps = 30.0

        out_name = (
            os.path.splitext(os.path.basename(args.video))[0] + "_pgnet_output.mp4"
        )
        out_path = os.path.join(os.path.dirname(args.video), out_name)
        writer = cv2.VideoWriter(
            out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )
        if not writer.isOpened():
            raise SystemExit(f"Cannot open output video for writing: {out_path}")

        print(f"▶ Processing video... (Output: {out_path})")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            boxes, texts = predictor.infer(frame)
            frame = draw_results(frame, boxes, texts)
            writer.write(frame)
    finally:
        # Release the capture and the writer even when inference fails, so
        # the partially written file is finalized and handles are freed.
        cap.release()
        if writer is not None:
            writer.release()

    print("🎉 Done! Video saved:", out_path)


if __name__ == "__main__":
    main()