74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
import os
|
||
import cv2
|
||
import torch
|
||
import numpy as np
|
||
|
||
# LinkNet network definition, imported from the sibling module linknet_model_def.
|
||
from .linknet_model_def import LinkNet
|
||
|
||
# Already-loaded LinkNet models, keyed by the checkpoint path they came from,
# so each checkpoint is read from disk only once per process.
_linknet_models = {}

# Inference device: CUDA when PyTorch reports it as available, else the CPU.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"LinkNet will use device: {_device}")
|
||
|
||
|
||
def _get_endpoints_from_mask(mask: np.ndarray):
    """Internal helper: extract the two endpoints of a straight line from a mask.

    Fits one straight line (least squares, ``cv2.DIST_L2``) through every
    non-zero pixel of ``mask``. It then projects those pixels onto the fitted
    direction and picks the pixels with the smallest and largest projection.

    Args:
        mask: single-channel 8-bit mask (as required by ``cv2.findNonZero``);
            any non-zero pixel counts as foreground.

    Returns:
        ``(start_point, end_point)``: each is an ``(x, y)`` tuple of Python
        ints in the mask's own pixel coordinates. Returns ``(None, None)`` when
        the mask has no foreground pixels.

    Note:
        The endpoints are actual mask pixels chosen by extreme projection, so a
        stray foreground pixel far along the line direction becomes an endpoint.
    """
    points = cv2.findNonZero(mask)
    # findNonZero yields None for an all-zero mask; also guard an empty array
    # in case a given OpenCV build returns that instead.
    if points is None or len(points) == 0:
        return None, None

    # fitLine returns [vx, vy, x0, y0] (shape (4, 1)); only the direction
    # (vx, vy) is needed to order the pixels along the line.
    vx, vy = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01).ravel()[:2]

    points_flat = points.reshape(-1, 2)  # (N, 2) array of (x, y)
    projections = points_flat @ np.array([vx, vy])

    # Convert to plain ints so callers get JSON-serializable coordinates
    # instead of NumPy int32 scalars.
    start_point = tuple(int(v) for v in points_flat[np.argmin(projections)])
    end_point = tuple(int(v) for v in points_flat[np.argmax(projections)])
    return start_point, end_point
|
||
|
||
|
||
def _load_linknet_model(model_path: str):
    """Return the cached LinkNet for ``model_path``, loading it on first use.

    Returns None (after printing an error) when the checkpoint file does not
    exist. A missing file is not cached, so a later call retries the load.
    """
    model = _linknet_models.get(model_path)
    if model is not None:
        return model

    print(f"Loading LinkNet model from: {model_path}")
    if not os.path.exists(model_path):
        print(f"Error: LinkNet model file not found at {model_path}")
        return None

    model = LinkNet(num_classes=1)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. weights_only=True would restrict loading to tensors, but
    # that keyword is missing on older torch releases; confirm the pinned
    # version before adding it.
    model.load_state_dict(torch.load(model_path, map_location=_device))
    model.to(_device)
    model.eval()
    _linknet_models[model_path] = model
    return model


def _crop_to_gray(original_image: np.ndarray, crop_box: tuple):
    """Clip ``crop_box`` to the image and return ``(gray_crop, x1, y1)``.

    ``x1``/``y1`` are the clipped top-left offsets of the crop in the original
    image. Returns None when the clipped box is empty.
    """
    img_h, img_w = original_image.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in crop_box)
    # Clamp to the image: negative indices would otherwise wrap around in NumPy
    # slicing and silently select the wrong region.
    x1, x2 = max(0, x1), min(img_w, x2)
    y1, y2 = max(0, y1), min(img_h, y2)
    if x2 <= x1 or y2 <= y1:
        return None

    cropped = original_image[y1:y2, x1:x2]
    if cropped.ndim == 2:
        gray = cropped  # already single-channel
    elif cropped.shape[2] == 1:
        gray = cropped[:, :, 0]
    elif cropped.shape[2] == 4:
        gray = cv2.cvtColor(cropped, cv2.COLOR_BGRA2GRAY)
    else:
        gray = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)
    return gray, x1, y1


def segment_and_find_endpoints(original_image: np.ndarray,
                               crop_box: tuple,
                               model_path: str,
                               image_size: int = 256):
    """Segment the region in ``crop_box`` with LinkNet and find the weld-seam endpoints.

    Args:
        original_image: full image. A 3-channel image is treated as BGR
            (``COLOR_BGR2GRAY``); 2-D grayscale, 1-channel and BGRA images are
            also accepted. Pixel values are divided by 255, which presumably
            assumes uint8 data; confirm for other dtypes.
        crop_box: ``(x1, y1, x2, y2)`` pixel box. ``x2``/``y2`` are exclusive
            (slice bounds). The box is clipped to the image bounds, and float
            coordinates are truncated with ``int()``.
        model_path: path to a LinkNet ``state_dict`` checkpoint. It is loaded
            once and cached per path.
        image_size: side length of the square input fed to the model.

    Returns:
        ``(start_point, end_point)`` as ``(x, y)`` int tuples in the original
        image's coordinates. Returns ``(None, None)`` in three cases: the model
        file is missing, the clipped crop is empty, or no foreground pixel is
        predicted.

    Errors from loading a corrupt or incompatible checkpoint are not caught and
    propagate to the caller.
    """
    model = _load_linknet_model(model_path)
    if model is None:
        return None, None

    crop = _crop_to_gray(original_image, crop_box)
    if crop is None:
        return None, None
    img_gray, x1, y1 = crop
    crop_h, crop_w = img_gray.shape

    img_resized = cv2.resize(img_gray, (image_size, image_size))
    img_tensor = (torch.from_numpy(img_resized / 255.0)
                  .float()
                  .unsqueeze(0)
                  .unsqueeze(0)  # -> (1, 1, image_size, image_size)
                  .to(_device))

    with torch.no_grad():
        output = model(img_tensor)

    # NOTE(review): the 0.5 threshold assumes the model already outputs
    # probabilities (a sigmoid inside LinkNet); confirm in linknet_model_def.
    # If it outputs logits, the threshold should be 0.
    pred_mask_binary = (output.cpu().numpy()[0, 0] > 0.5).astype(np.uint8)
    # Nearest-neighbour resize keeps the mask binary when mapping it back to
    # the crop's size.
    predicted_mask = cv2.resize(pred_mask_binary, (crop_w, crop_h),
                                interpolation=cv2.INTER_NEAREST) * 255

    start_crop, end_crop = _get_endpoints_from_mask(predicted_mask)
    if start_crop is None:
        return None, None

    # Shift from crop coordinates back to the original image by the clipped offset.
    start_orig = (int(start_crop[0]) + x1, int(start_crop[1]) + y1)
    end_orig = (int(end_crop[0]) + x1, int(end_crop[1]) + y1)
    return start_orig, end_orig