innovate_project/3D_construction/script/linknet_segmentor.py

import os
import cv2
import torch
import numpy as np
# Import the model definition that lives alongside this module
from .linknet_model_def import LinkNet

# Cache of loaded models, keyed by checkpoint path
_linknet_models = {}
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"LinkNet will use device: {_device}")
def _get_endpoints_from_mask(mask: np.ndarray):
    """Internal helper: extract the two endpoints of a line from a binary mask."""
    points = cv2.findNonZero(mask)
    if points is None:
        return None, None
    # Fit a line to all foreground pixels, project them onto its direction
    # vector, and take the extreme projections as the two endpoints.
    line_params = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01)
    direction_vector = np.array([line_params[0][0], line_params[1][0]])
    points_flat = points.reshape(-1, 2)
    projections = points_flat.dot(direction_vector)
    min_idx, max_idx = np.argmin(projections), np.argmax(projections)
    start_point, end_point = tuple(points_flat[min_idx]), tuple(points_flat[max_idx])
    return start_point, end_point
def segment_and_find_endpoints(original_image: np.ndarray,
                               crop_box: tuple,
                               model_path: str,
                               image_size: int = 256):
    """
    Run LinkNet segmentation inside the given crop region and locate the weld
    seam endpoints.

    Returns (start_point, end_point) in the coordinate system of the original
    image, or (None, None) if the model file is missing or no seam is found.
    """
    # Load the model once per checkpoint path and cache it for later calls.
    if model_path not in _linknet_models:
        print(f"Loading LinkNet model from: {model_path}")
        if not os.path.exists(model_path):
            print(f"Error: LinkNet model file not found at {model_path}")
            return None, None
        model = LinkNet(num_classes=1)
        model.load_state_dict(torch.load(model_path, map_location=_device))
        model.to(_device)
        model.eval()
        _linknet_models[model_path] = model
    model = _linknet_models[model_path]

    # Crop, convert to grayscale, and resize to the network's input size.
    x1, y1, x2, y2 = crop_box
    cropped_img = original_image[y1:y2, x1:x2]
    img_gray = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)
    crop_h, crop_w = img_gray.shape
    img_resized = cv2.resize(img_gray, (image_size, image_size))
    img_normalized = img_resized / 255.0
    img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float().to(_device)

    # Forward pass, then binarize the predicted mask at 0.5.
    with torch.no_grad():
        output = model(img_tensor)
    pred_mask_resized = output.cpu().numpy()[0, 0]
    pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)
    # Resize the mask back to the crop size before extracting endpoints.
    predicted_mask = cv2.resize(pred_mask_binary, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST) * 255

    start_crop, end_crop = _get_endpoints_from_mask(predicted_mask)
    if start_crop is None:
        return None, None

    # Shift endpoints from crop coordinates back to original-image coordinates.
    start_orig = (start_crop[0] + x1, start_crop[1] + y1)
    end_orig = (end_crop[0] + x1, end_crop[1] + y1)
    return start_orig, end_orig
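

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# The image path, checkpoint path, and crop box below are placeholders, not
# values taken from the project; replace them with your own data. Because of
# the relative import above, this module must be run as part of its package
# (e.g. python -m <package>.linknet_segmentor), not as a standalone script.
if __name__ == "__main__":
    test_image = cv2.imread("example_frame.png")  # hypothetical input image
    if test_image is None:
        raise SystemExit("Replace 'example_frame.png' with a real image path.")
    start, end = segment_and_find_endpoints(
        original_image=test_image,
        crop_box=(100, 100, 400, 400),          # hypothetical ROI as (x1, y1, x2, y2)
        model_path="weights/linknet_seam.pth",  # hypothetical checkpoint path
    )
    print(f"Seam endpoints in original-image coordinates: {start} -> {end}")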