You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
4.1 KiB
128 lines
4.1 KiB
import argparse
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
import onnxruntime
|
|
from paddleocr.ppocr.data.imaug.operators import (
|
|
E2EResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
|
|
)
|
|
# from ppocr.data.imaug.operators import (
|
|
# E2EResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
|
|
# )
|
|
from paddleocr.ppocr.postprocess.pg_postprocess import PGPostProcess
|
|
# from ppocr.postprocess.pg_postprocess import PGPostProcess
|
|
from pgnet.chr_dct import chr_dct_list
|
|
|
|
|
|
class PGNetPredictor:
    """Run a PGNet end-to-end text-spotting ONNX model through onnxruntime.

    The predictor chains three steps: PaddleOCR image transforms
    (``preprocess``), an onnxruntime session (``predict``) and PaddleOCR's
    ``PGPostProcess`` (``postprocess``). ``infer`` runs all three.

    Note:
        ``preprocess`` stores a copy of the last input image in
        ``self.ori_im`` and ``postprocess`` reads its shape, so one instance
        must not be shared between concurrent callers.
    """

    def __init__(self, model_path, cpu=False):
        """Create the inference session and the pre-/post-processing stages.

        Args:
            model_path: Path of the ONNX model handed to
                ``onnxruntime.InferenceSession``.
            cpu: If true, request only ``CPUExecutionProvider``; otherwise
                request ``CUDAExecutionProvider`` first.
        """
        self.model_path = model_path
        self.dict_path = "ic15_dict.txt"

        # PGPostProcess takes the character dictionary as a file path, so the
        # bundled list is written to the working directory when missing.
        # Explicit encoding keeps the file independent of the platform locale.
        if not os.path.exists(self.dict_path):
            with open(self.dict_path, "w", encoding="utf-8") as f:
                f.writelines(chr_dct_list)

        if cpu:
            providers = ["CPUExecutionProvider"]
        else:
            # Providers are listed in priority order; CPU is named explicitly
            # as the fallback instead of relying on onnxruntime's default.
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        self.sess = onnxruntime.InferenceSession(model_path, providers=providers)

        self.transforms = [
            E2EResizeForTest(max_side_len=768, valid_set="totaltext"),
            NormalizeImage(
                scale=1 / 255.0,
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                order="hwc",
            ),
            ToCHWImage(),
            KeepKeys(keep_keys=["image", "shape"]),
        ]

        self.pgpostprocess = PGPostProcess(
            character_dict_path=self.dict_path,
            valid_set="totaltext",
            score_thresh=0.5,
            mode="fast",
        )

    def preprocess(self, img):
        """Apply the transform chain to one image and add a batch axis.

        Args:
            img: Input image array (the caller passes an OpenCV frame).

        Returns:
            Tuple ``(image_batch, shape_batch)``: the ``"image"`` and
            ``"shape"`` entries kept by ``KeepKeys``, each expanded with a
            leading axis of length 1.
        """
        # Remember the untouched input; postprocess clips boxes to its shape.
        self.ori_im = img.copy()
        data = {"image": img}

        for transform in self.transforms:
            data = transform(data)

        # After KeepKeys, data is the pair [image, shape].
        img, shape_list = data
        return np.expand_dims(img, axis=0), np.expand_dims(shape_list, axis=0)

    def predict(self, img):
        """Run the ONNX session on a preprocessed batch.

        Args:
            img: Batched input fed to the session's first input.

        Returns:
            Dict mapping ``f_border``, ``f_char``, ``f_direction`` and
            ``f_score`` to the session's first four outputs, in that order.
        """
        ort_inputs = {self.sess.get_inputs()[0].name: img}
        outputs = self.sess.run(None, ort_inputs)

        return {
            "f_border": outputs[0],
            "f_char": outputs[1],
            "f_direction": outputs[2],
            "f_score": outputs[3],
        }

    def clip_boxes(self, boxes, shape):
        """Clamp polygon coordinates to the image bounds.

        Args:
            boxes: Iterable of ``(N, 2)`` point arrays (x in column 0, y in
                column 1). ``N`` may differ between polygons.
            shape: Image shape; only the first two entries (height, width)
                are used.

        Returns:
            ``np.ndarray`` of clipped polygons. When all polygons share one
            shape this is a regular stacked array; otherwise it is a 1-D
            object array holding one point array per polygon.
        """
        h, w = shape[:2]
        clipped = []
        for box in boxes:
            # Work on a copy so the caller's arrays are not modified in place.
            box = np.array(box)
            box[:, 0] = np.clip(box[:, 0], 0, w - 1)
            box[:, 1] = np.clip(box[:, 1], 0, h - 1)
            clipped.append(box)

        if len({box.shape for box in clipped}) <= 1:
            return np.array(clipped)

        # Polygons with different point counts cannot be stacked; recent
        # NumPy raises ValueError for np.array() on such a ragged list, so
        # build the object array explicitly.
        ragged = np.empty(len(clipped), dtype=object)
        for idx, box in enumerate(clipped):
            ragged[idx] = box
        return ragged

    def postprocess(self, preds, shape_list):
        """Decode raw model outputs into clipped polygons and texts.

        Args:
            preds: Output dict produced by ``predict``.
            shape_list: Batched shape info produced by ``preprocess``.

        Returns:
            Tuple ``(boxes, texts)`` where ``boxes`` are the ``"points"``
            from ``PGPostProcess`` clipped to the last preprocessed image and
            ``texts`` are its ``"texts"``.
        """
        result = self.pgpostprocess(preds, shape_list)
        pts, texts = result["points"], result["texts"]
        return self.clip_boxes(pts, self.ori_im.shape), texts

    def infer(self, img):
        """Run preprocess, predict and postprocess on a single image.

        Args:
            img: Input image array.

        Returns:
            Tuple ``(boxes, texts)`` as returned by ``postprocess``.
        """
        img_input, shape = self.preprocess(img)
        preds = self.predict(img_input)
        return self.postprocess(preds, shape)
|
|
|
|
|
|
def draw_results(frame, boxes, texts):
    """Draw detected polygons and their recognized texts onto ``frame``.

    Args:
        frame: Image to draw on; it is modified in place.
        boxes: Iterable of polygons, each convertible to an ``(N, 2)`` array
            of x/y points.
        texts: Iterable of strings paired with ``boxes`` by position.

    Returns:
        The same ``frame`` object, after drawing.
    """
    for box, text in zip(boxes, texts):
        # cv2.polylines takes int32 points shaped (N, 1, 2).
        pts = np.asarray(box).astype(np.int32).reshape(-1, 1, 2)
        if pts.shape[0] == 0:
            # A polygon without points has no outline and no text anchor.
            continue

        cv2.polylines(frame, [pts], True, (255, 255, 0), 2)

        # Anchor the label at the first vertex; pass plain Python ints so the
        # point parses regardless of the NumPy scalar type.
        x, y = pts[0, 0]
        cv2.putText(
            frame,
            text,
            (int(x), int(y)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (0, 255, 0),
            2,
        )
    return frame
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: run PGNet on every frame of a video and write
    # an annotated copy next to the input as "<name>_pgnet_output.mp4".
    parser = argparse.ArgumentParser(description="PGNet Video OCR")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--video", type=str, required=True)
    parser.add_argument("--cpu", action="store_true")
    args = parser.parse_args()

    predictor = PGNetPredictor(args.model, args.cpu)

    cap = cv2.VideoCapture(args.video)
    if not cap.isOpened():
        # Without this check the loop below would exit immediately and the
        # script would report success for an unreadable input.
        cap.release()
        raise SystemExit(f"Cannot open video: {args.video}")

    writer = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Covers 0, negative and NaN: the frame rate is unknown, so use a
            # fixed value rather than hand an invalid rate to the writer.
            fps = 30.0
            print(f"FPS not reported by the source; falling back to {fps}")

        out_name = os.path.splitext(os.path.basename(args.video))[0] + "_pgnet_output.mp4"
        out_path = os.path.join(os.path.dirname(args.video), out_name)

        writer = cv2.VideoWriter(
            out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )
        if not writer.isOpened():
            raise SystemExit(f"Cannot open output video for writing: {out_path}")

        print(f"▶ Processing video... (Output: {out_path})")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            boxes, texts = predictor.infer(frame)
            frame = draw_results(frame, boxes, texts)
            writer.write(frame)
    finally:
        # Release both handles even when inference or writing raises.
        cap.release()
        if writer is not None:
            writer.release()

    print("🎉 Done! Video saved:", out_path)
|