yolov8-line-cross/predictor.py

171 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# YOLOv8
from ultralytics import YOLO
# supervision
import supervision as sv
from supervision.draw.color import Color
# OpenCV
import cv2
# Line Selector
from line_selector import LineSelector
# time date
from datetime import datetime
import time
# Video source: a local file or a live stream URL. The RTMP stream is the
# active choice; the local file is kept as a commented-out alternative.
# source = "line.mp4"
source = "rtmp://10.0.0.21:1935/live/picam3"
# Open the source to read its properties (used to size the output video) and,
# further below, to grab a background frame for selecting the counting line.
cap = cv2.VideoCapture(source)
if not cap.isOpened():
    # Without this check an unreachable stream silently yields fps/size of 0
    # and the script only fails later when reading the first frame.
    print(f"无法打开视频源: {source}")
    raise SystemExit(1)
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Output video. Set ``record`` to False to disable recording.
record = True
if record:
    output = "output2.mp4"
    # NOTE(review): the file is written at half the source FPS — presumably
    # because tracking runs slower than real time; confirm this is intended.
    out = cv2.VideoWriter(output, cv2.VideoWriter_fourcc(*"avc1"), fps / 2, (width, height))
    if out.isOpened():
        # Quality hint (0-100); only some codecs honour it.
        out.set(cv2.VIDEOWRITER_PROP_QUALITY, 80)
    else:
        # e.g. no "avc1" (H.264) encoder in this OpenCV build, or an invalid
        # FPS/size reported by the source. An unopened writer would silently
        # drop every frame, so disable recording and say so instead.
        print(f"无法创建输出视频 {output}, 不进行录制")
        record = False
# Grab one frame from the source and save it as the background image on
# which the user draws the counting line.
ret, frame = cap.read()
# Only this single frame is needed; model.track() opens the source itself later.
cap.release()
if not ret:
    # Without a frame no line can be selected. Previously this fell through
    # and crashed later with a NameError on ``coordinates``.
    print("无法读取视频帧")
    if record:
        out.release()
    raise SystemExit(1)
cv2.imwrite("background.png", frame)
selector = LineSelector("background.png")
selector.draw_image_and_get_coordinates()
if not selector.success:
    print("未选择线段坐标")
    # Close the output file before quitting instead of leaving it to teardown.
    if record:
        out.release()
    raise SystemExit(0)
coordinates = selector.get_coordinates()
print("线段坐标: ", coordinates)
# ---- Frame-limit configuration ----
# Set frame_limit to True to stop tracking after frame_limit_upper frames.
frame_limit = False
frame_limit_upper = 43000
frame_count = 0  # frames seen so far by the tracking loop

# ---- Detection model ----
print("正在加载模型...")
model = YOLO("yolov8x.pt")

# ---- Line-crossing zone ----
# Endpoints come from the user's selection; assumed to be laid out as
# [x1, y1, x2, y2] by LineSelector.get_coordinates() — TODO confirm.
LINE_START = sv.Point(x=coordinates[0], y=coordinates[1])
LINE_END = sv.Point(x=coordinates[2], y=coordinates[3])
line_counter = sv.LineZone(start=LINE_START, end=LINE_END)

# ---- Visualisation ----
# Counting line drawn in magenta, with custom labels for the two directions.
line_color = Color(r=224, g=57, b=151)
line_annotator = sv.LineZoneAnnotator(
    color=line_color,
    thickness=2,
    text_thickness=2,
    text_scale=1,
    text_offset=2.0,
    custom_in_text="North",
    custom_out_text="South",
)
# Bounding boxes and their labels for detected objects.
box_annotator = sv.BoxAnnotator(text_scale=1, text_thickness=2, thickness=2)

# Reference time for the first per-frame FPS measurement.
start_time = time.perf_counter()
def _draw_overlay_text(image, text, origin):
    """Draw *text* onto *image* in place at pixel position *origin*.

    Uses the overlay style shared by the date and FPS read-outs: Hershey
    simplex font, scale 1, white, thickness 2, anti-aliased.
    """
    cv2.putText(
        image,
        text,
        origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (255, 255, 255),
        2,
        cv2.LINE_AA,
    )


# Track objects frame by frame; stream=True yields one result per frame.
for result in model.track(source, device=0, verbose=False, stream=True):
    # Frame limit: stop once frame_limit_upper frames have been processed.
    # (The previous ``==`` test ran before processing and stopped one frame early.)
    frame_count += 1
    if frame_limit and frame_count > frame_limit_upper:
        print(f"已经到达{frame_limit_upper}帧,停止运行")
        break
    # Original frame; all annotations below are drawn onto it.
    frame = result.orig_img
    # result.boxes.id is None when the tracker produced no IDs for this frame.
    # Such frames are still displayed and recorded. Previously they were
    # skipped, which dropped them from the recording, stopped servicing the
    # preview window / "q" key and inflated the next FPS reading. Only
    # labelling and line counting need IDs.
    if result.boxes.id is not None:
        # Parse the predictions with supervision.
        detections = sv.Detections.from_ultralytics(result)
        # .cpu() guards against IDs left on the GPU (device=0); it is a no-op
        # for CPU tensors.
        detections.tracker_id = result.boxes.id.cpu().numpy().astype(int)
        # Optional class filter. Apply it only AFTER tracker_id is set so the
        # IDs stay aligned with the boxes, e.g.:
        # detections = detections[(detections.class_id != 60) & (detections.class_id != 0)]
        # Label each box as "#<tracker id> <class name> <confidence %>".
        labels = [
            "#{} {} {:.1f}".format(tracker_id, model.names[class_id], confidence * 100)
            for tracker_id, class_id, confidence in zip(
                detections.tracker_id, detections.class_id, detections.confidence
            )
        ]
        frame = box_annotator.annotate(scene=frame, detections=detections, labels=labels)
        # Update the line-crossing counters with this frame's tracked objects.
        line_counter.trigger(detections=detections)
    # Draw the counting line and its counters on every frame.
    line_annotator.annotate(frame=frame, line_counter=line_counter)
    # Timestamp overlay (local wall-clock time).
    _draw_overlay_text(frame, datetime.now().strftime("%Y-%m-%d %H:%M:%S"), (10, 30))
    # Per-frame FPS: time since the previous iteration finished, which includes
    # producing this result. Named display_fps so it no longer overwrites the
    # source FPS read from the capture.
    end_time = time.perf_counter()
    display_fps = 1.0 / (end_time - start_time)
    _draw_overlay_text(frame, "FPS: {:.2f}".format(display_fps), (10, 60))
    # Save the annotated frame.
    if record:
        out.write(frame)
    # Show the result; press "q" to quit.
    cv2.imshow("Frame", frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break
    # Restart the FPS timer for the next frame.
    start_time = time.perf_counter()
# Cleanup: close any OpenCV preview windows, then close the recorded video
# file if recording was enabled.
cv2.destroyAllWindows()
if record:
    out.release()