# yolov8-line-cross/predictor.py
# (Repository-viewer metadata captured with this file: 171 lines, 4.1 KiB,
#  Python; commit timestamp shown: 2023-09-30 11:20:02 +08:00.)
# YOLOv8
from ultralytics import YOLO
# supervision
import supervision as sv
from supervision.draw.color import Color
# OpenCV
import cv2
# Line Selector
from line_selector import LineSelector
# time date
from datetime import datetime
import time
# Video source. Only the last assignment takes effect; reorder or comment out
# these lines to switch between the local test clip and the RTMP camera stream.
source = "line.mp4"
source = "rtmp://10.0.0.21:1935/live/picam3"
# Open the source to read its properties (the background frame is read next).
cap = cv2.VideoCapture(source)
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# OpenCV returns 0 for CAP_PROP_FPS when the backend cannot report it, which
# is common for live network streams. A VideoWriter given a non-positive rate
# may fail to open or produce an unusable file, so fall back to a nominal rate.
FALLBACK_FPS = 30.0
if not fps > 0:  # also catches NaN
    fps = FALLBACK_FPS
# Output video
########## Toggle recording ##########
record = True
if record:
    output = "output2.mp4"
    # NOTE(review): frames are written at half the source rate; the reason is
    # not visible here (possibly to match processing throughput) -- confirm.
    out = cv2.VideoWriter(output, cv2.VideoWriter_fourcc(*"avc1"), fps / 2, (width, height))
    if not out.isOpened():
        # e.g. the "avc1" (H.264) encoder is missing from this OpenCV build.
        # An unopened writer silently discards every frame, so report it and
        # turn recording off instead of producing an empty file unnoticed.
        print(f"无法打开视频写入器 {output},不进行录制")
        record = False
    else:
        out.set(cv2.VIDEOWRITER_PROP_QUALITY, 80)
# Grab the first frame as the background image on which the user draws the
# counting line.
ret, frame = cap.read()
if not ret:
    # Without a frame there is nothing to draw on, and `coordinates` would be
    # undefined further down (NameError); stop with a clear message instead.
    print(f"无法读取视频源画面: {source}")
    cap.release()
    if record:
        out.release()
    raise SystemExit(1)
cv2.imwrite("background.png", frame)
selector = LineSelector("background.png")
selector.draw_image_and_get_coordinates()
if not selector.success:
    # The user closed the selector without choosing a line: nothing to count.
    print("未选择线段坐标")
    cap.release()
    if record:
        out.release()
    raise SystemExit(0)
coordinates = selector.get_coordinates()
print("线段坐标: ", coordinates)
# model.track() below opens `source` itself, so this capture is no longer needed.
cap.release()
# Load the object-detection model.
print("正在加载模型...")
model = YOLO("yolov8x.pt")

# Line-crossing zone. Endpoints come from the user's selection, read as
# (start_x, start_y, end_x, end_y).
LINE_START = sv.Point(coordinates[0], coordinates[1])
LINE_END = sv.Point(coordinates[2], coordinates[3])
line_counter = sv.LineZone(start=LINE_START, end=LINE_END)

# Drawing style for the counting line and its in/out counters.
line_color = Color(r=224, g=57, b=151)
line_annotator = sv.LineZoneAnnotator(
    thickness=2,
    text_thickness=2,
    text_scale=1,
    color=line_color,
    text_offset=2.0,
    custom_in_text="North",
    custom_out_text="South",
)

# Drawing style for the detection boxes and their labels.
box_annotator = sv.BoxAnnotator(
    thickness=2,
    text_thickness=2,
    text_scale=1,
)

# Reference timestamp for the main loop's FPS read-out.
start_time = time.perf_counter()

# Optional cap on the number of frames processed by the main loop.
########## Toggle frame limit ##########
frame_limit = False
frame_limit_upper = 43000
frame_count = 0
# Main loop: detect and track objects on every frame of `source`.
# `stream=True` makes model.track yield results one frame at a time.
# try/finally guarantees the recording is finalized and the window closed
# even if the loop is interrupted (Ctrl+C, stream error, ...).
try:
    for result in model.track(source, device=0, verbose=False, stream=True):
        # Optional frame cap: stop after exactly `frame_limit_upper` frames.
        frame_count += 1
        if frame_limit and frame_count > frame_limit_upper:
            print(f"已经到达{frame_limit_upper}帧,停止运行")
            break
        # Unannotated image for this frame.
        frame = result.orig_img
        # Convert the Ultralytics result into supervision detections.
        detections = sv.Detections.from_ultralytics(result)
        ## Optionally drop some classes:
        # detections = detections[(detections.class_id != 60) & (detections.class_id != 0)]
        if result.boxes.id is None:
            # No track IDs in this frame. Keep the frame (line overlay,
            # timestamp, recording, display, FPS timer reset) rather than
            # skipping it, so output stays continuous; just draw no boxes.
            detections = sv.Detections.empty()
            labels = []
        else:
            # `.cpu()` so the conversion also works if the IDs live on the GPU.
            detections.tracker_id = result.boxes.id.cpu().numpy().astype(int)
            # Label format: "#<track id> <class name> <confidence in %>".
            labels = [
                "#{} {} {:.1f}".format(tracker_id, model.names[class_id], confidence * 100)
                for tracker_id, class_id, confidence in zip(
                    detections.tracker_id, detections.class_id, detections.confidence
                )
            ]
        # Draw detection boxes and labels.
        frame = box_annotator.annotate(scene=frame, detections=detections, labels=labels)
        # Update the line-crossing counts and draw the line with its counters.
        line_counter.trigger(detections=detections)
        line_annotator.annotate(frame=frame, line_counter=line_counter)
        # Wall-clock timestamp overlay.
        cv2.putText(
            frame,
            datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
        # Processing rate: time since `start_time` was last reset (end of the
        # previous iteration), i.e. fetching/tracking this result + annotation.
        # A separate name keeps the capture `fps` value intact.
        elapsed = time.perf_counter() - start_time
        loop_fps = 1.0 / elapsed if elapsed > 0 else 0.0
        cv2.putText(
            frame,
            "FPS: {:.2f}".format(loop_fps),
            (10, 60),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
        # Save the annotated frame.
        if record:
            out.write(frame)
        # Show the result; press "q" to quit.
        cv2.imshow("Frame", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
        # Restart the FPS timer for the next frame.
        start_time = time.perf_counter()
finally:
    # Release resources.
    if record:
        out.release()
    cv2.destroyAllWindows()