From 395245c6661d5736013ee4a58155c6e745fb6745 Mon Sep 17 00:00:00 2001 From: TaurusXin Date: Sat, 30 Sep 2023 11:20:02 +0800 Subject: [PATCH] init repo --- .gitignore | 10 +++ README.md | 20 ++++++ line_selector.py | 66 ++++++++++++++++++ predictor.py | 170 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | Bin 0 -> 1616 bytes 5 files changed, 266 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 line_selector.py create mode 100644 predictor.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..879fea2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__ + +runs + +*.mp4 +*.dll + +*.png + +*.pt \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..73a6b75 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +# YOLOv8 过线检测 + +基于 YOLOv8 的过线检测,支持自定义划线。 + +## 开始 + +建议的 Python 版本为 3.10.x,首先创建一个虚拟环境(推荐),然后按顺序安装以下包 + +```shell +# GPU 环境 +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +pip install ultralytics +pip install supervision opencv-contrib-python pillow lapx +``` + +然后可以开始运行 + +```shell +python predictor.py +```shell diff --git a/line_selector.py b/line_selector.py new file mode 100644 index 0000000..b873b30 --- /dev/null +++ b/line_selector.py @@ -0,0 +1,66 @@ +import tkinter as tk + +class LineSelector: + def __init__(self, image_path): + self.image_path = image_path + self.start_x, self.start_y = None, None + self.end_x, self.end_y = None, None + self.line_id = None + self.coordinates = None + self.success = False + + def on_mouse_press(self, event): + self.start_x, self.start_y = event.x, event.y + + def on_mouse_drag(self, event): + if self.line_id: + self.canvas.delete(self.line_id) + self.end_x, self.end_y = event.x, event.y + self.line_id = self.canvas.create_line(self.start_x, self.start_y, self.end_x, self.end_y, fill="white") + + def on_mouse_release(self, event): + self.end_x, self.end_y = event.x, event.y + self.update_coordinates_label() + + def update_coordinates_label(self): + coordinates = f"起点坐标: ({self.start_x}, {self.start_y}),终点坐标: ({self.end_x}, {self.end_y})" + self.coordinates_label.config(text=coordinates) + + def on_confirm_button_click(self): + self.success = True + self.root.destroy() + self.root.quit() + + def draw_image_and_get_coordinates(self): + self.root = tk.Tk() + self.root.title("绘制检测线") + self.canvas = tk.Canvas(self.root, width=1920, height=1080) + self.canvas.pack() + image = tk.PhotoImage(file=self.image_path) + self.canvas.create_image(0, 0, anchor=tk.NW, image=image) + + self.coordinates_label = tk.Label(self.root, text="", fg="red") + self.coordinates_label.pack() + + confirm_button = tk.Button(self.root, text="确认", command=self.on_confirm_button_click) + confirm_button.pack() + + self.canvas.bind("", self.on_mouse_press) + self.canvas.bind("", self.on_mouse_drag) + self.canvas.bind("", self.on_mouse_release) + + self.root.mainloop() + + def get_coordinates(self): + self.coordinates = [self.start_x, self.start_y, self.end_x, self.end_y] + return self.coordinates + +def main(): + image_path = "background.png" + selector = LineSelector(image_path) + selector.draw_image_and_get_coordinates() + coordinates = selector.get_coordinates() + print("坐标:", coordinates) + +if __name__ == "__main__": + main() diff --git a/predictor.py b/predictor.py new file mode 100644 index 0000000..d6cce99 --- /dev/null +++ b/predictor.py @@ -0,0 +1,170 @@ +# YOLOv8 +from ultralytics import YOLO + +# supervision +import supervision as sv +from supervision.draw.color import Color + +# OpenCV +import cv2 + +# Line Selector +from line_selector import LineSelector + +# time date +from datetime import datetime +import time + +# 视频源 +source = "line.mp4" +source = "rtmp://10.0.0.21:1935/live/picam3" + +# 读取背景图像 +cap = cv2.VideoCapture(source) + +fps = cap.get(cv2.CAP_PROP_FPS) +width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) +height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + +# 输出视频 + +########## 配置是否录制 ########## +record = True + +if record: + output = "output2.mp4" + out = cv2.VideoWriter(output, cv2.VideoWriter_fourcc(*"avc1"), fps / 2, (width, height)) + out.set(cv2.VIDEOWRITER_PROP_QUALITY, 80) + +# 读取背景图像 +ret, frame = cap.read() +if ret: + cv2.imwrite("background.png", frame) + selector = LineSelector("background.png") + selector.draw_image_and_get_coordinates() + if selector.success: + coordinates = selector.get_coordinates() + print("线段坐标: ", coordinates) + else: + print("未选择线段坐标") + cap.release() + exit(0) + +cap.release() + +# 加载目标检测模型 +print("正在加载模型...") +model = YOLO("yolov8x.pt") + +# 越线检测位置 +LINE_START = sv.Point(coordinates[0], coordinates[1]) +LINE_END = sv.Point(coordinates[2], coordinates[3]) +line_counter = sv.LineZone(start=LINE_START, end=LINE_END) + +# 线的可视化配置 +line_color = Color(r=224, g=57, b=151) +line_annotator = sv.LineZoneAnnotator( + thickness=2, text_thickness=2, text_scale=1, color=line_color, text_offset=2.0, custom_in_text="North", custom_out_text="South" +) + +# 目标检测可视化配置 +box_annotator = sv.BoxAnnotator(thickness=2, text_thickness=2, text_scale=1) + +# 起始时间 +start_time = time.perf_counter() + +# 限定帧数 +########## 配置是否限定帧数 ########## +frame_limit = False +frame_limit_upper = 43000 +frame_count = 0 + +# 逐帧跟踪 +for result in model.track(source, device=0, verbose=False, stream=True): + # 限定帧数 + frame_count += 1 + if frame_count == frame_limit_upper and frame_limit: + print(f"已经到达{frame_limit_upper}帧,停止运行") + break + + # 获取原始图像 + frame = result.orig_img + + # 用 supervision 解析预测结果 + detections = sv.Detections.from_ultralytics(result) + + ## 过滤掉某些类别 + # detections = detections[(detections.class_id != 60) & (detections.class_id != 0)] + + # 解析追踪ID + if result.boxes.id is None: + continue + + detections.tracker_id = result.boxes.id.numpy().astype(int) + + # 获取每个目标的:追踪ID、类别名称、置信度 + class_ids = detections.class_id # 类别ID + confidences = detections.confidence # 置信度 + tracker_ids = detections.tracker_id # 多目标追踪ID + labels = [ + "#{} {} {:.1f}".format( + tracker_ids[i], model.names[class_ids[i]], confidences[i] * 100 + ) + for i in range(len(class_ids)) + ] + + # 绘制目标检测可视化结果 + frame = box_annotator.annotate(scene=frame, detections=detections, labels=labels) + + # 越线检测 + line_counter.trigger(detections=detections) + line_annotator.annotate(frame=frame, line_counter=line_counter) + + # 显示日期 + cv2.putText( + frame, + datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (255, 255, 255), + 2, + cv2.LINE_AA, + ) + + # 结束时间 + end_time = time.perf_counter() + + # 计算帧率 + fps = 1.0 / (end_time - start_time) + + # 显示帧率 + + cv2.putText( + frame, + "FPS: {:.2f}".format(fps), + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (255, 255, 255), + 2, + cv2.LINE_AA, + ) + + # 保存视频 + if record: + out.write(frame) + + # 显示结果 + cv2.imshow("Frame", frame) + + if cv2.waitKey(1) & 0xFF == ord("q"): + break + + # 起始时间 + start_time = time.perf_counter() + +# 释放资源 +if record: + out.release() +cv2.destroyAllWindows() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..39aabaef693a0d589f9aac18f857e28872b71141 GIT binary patch literal 1616 zcmbW1OK;Oq5QS%r#4kaq5+_Yz(FJ0MK&r%w6*75HH;&`R*R<)62fi~CKZ?o*Mb~QY zoH=vmb$|a_*~X5RtgtQTnqOv9PG%Q2^Y4|twpX^YTgyQ<;FC?PvIGC^xhwH|a+dIx zpH6rkunUg*c$%4_g^^sNA4@T<38&xQXjE*SRXhI z5lZ{nXCQ?5FP)>ryS7L7m+niqD&}V_M>MzCzwq}KhH;f6in@?MD8t0Ju()@&)wqH=p;f)wowYTY!q+6<=n45o}zUgnf=@2WrdoCm%Q#_-+R54!;)Xu z=R_toku|5^$o&yj+PW_LVa|u_6ECx-H*rF*-uCFX@@kn)ZK_^z?OZ|L!PD)EDiAN7 T-e*m_wtDVt(aTUc_d5Lnr8nu# literal 0 HcmV?d00001