56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
import os
|
||
import cv2
|
||
from ultralytics import YOLO
|
||
|
||
# 这是一个好习惯,将模型加载放在函数外部,这样在多次调用函数时模型只需加载一次。
|
||
# 我们将模型路径作为参数传入,使其更具通用性。
|
||
models = {}
|
||
|
||
|
||
def detect_crop_area(image_path: str, model_path: str):
|
||
"""
|
||
使用YOLOv8模型检测图像中的裁切区域。
|
||
|
||
Args:
|
||
image_path (str): 原始图像的文件路径。
|
||
model_path (str): 用于检测的YOLOv8模型 (.pt) 的路径。
|
||
|
||
Returns:
|
||
tuple or None: 如果检测到物体,返回一个包含整数坐标的元组 (x1, y1, x2, y2)。
|
||
如果没有检测到或发生错误,返回 None。
|
||
"""
|
||
# 检查模型是否已加载,如果没有,则加载并缓存
|
||
if model_path not in models:
|
||
print(f"Loading YOLOv8 model from: {model_path}")
|
||
if not os.path.exists(model_path):
|
||
print(f"Error: Model file not found at {model_path}")
|
||
return None
|
||
models[model_path] = YOLO(model_path)
|
||
|
||
model = models[model_path]
|
||
|
||
# 检查图像文件是否存在
|
||
if not os.path.exists(image_path):
|
||
print(f"Error: Image file not found at {image_path}")
|
||
return None
|
||
|
||
try:
|
||
# 执行预测,verbose=False可以减少不必要的控制台输出
|
||
results = model.predict(source=image_path, conf=0.5, verbose=False)
|
||
|
||
# 检查是否有检测结果
|
||
if not results or not results[0].boxes:
|
||
print(f"Warning: YOLO did not detect any objects in {image_path}")
|
||
return None
|
||
|
||
# 获取置信度最高的那个检测框
|
||
# YOLOv8的results[0].boxes包含所有检测框,我们通常取第一个(置信度最高的)
|
||
box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
|
||
|
||
# 返回整数坐标的元组
|
||
return tuple(box)
|
||
|
||
except Exception as e:
|
||
print(f"An error occurred during prediction for {image_path}: {e}")
|
||
return None
|