import os

import cv2
import numpy as np
import torch

# Import the LinkNet model definition from the sibling module.
from .linknet_model_def import LinkNet

# Cache of loaded models, keyed by the model file path, so each checkpoint
# is read from disk at most once per process.
_linknet_models = {}
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"LinkNet will use device: {_device}")


def _get_endpoints_from_mask(mask: np.ndarray):
    """Extract the two extreme points of a line-shaped region in a mask.

    A straight line is fitted (least squares) through all non-zero pixels;
    every pixel is then projected onto the line direction, and the pixels
    with the smallest and largest projection are taken as the endpoints.

    Args:
        mask: Single-channel mask; any non-zero pixel counts as foreground.

    Returns:
        ``(start_point, end_point)``, each an ``(x, y)`` tuple of plain ints
        in mask coordinates, or ``(None, None)`` when the mask has no
        non-zero pixel.
    """
    points = cv2.findNonZero(mask)
    if points is None:
        return None, None

    # fitLine yields (vx, vy, x0, y0); only the direction (vx, vy) is needed.
    line_params = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01)
    direction_vector = np.ravel(line_params)[:2]

    # findNonZero returns an (N, 1, 2) array of (x, y); flatten to (N, 2).
    points_flat = points.reshape(-1, 2)
    projections = points_flat.dot(direction_vector)
    min_idx = np.argmin(projections)
    max_idx = np.argmax(projections)

    # Convert NumPy scalars to plain ints so callers get ordinary tuples.
    start_point = tuple(int(v) for v in points_flat[min_idx])
    end_point = tuple(int(v) for v in points_flat[max_idx])
    return start_point, end_point


def _load_linknet_model(model_path: str):
    """Return the cached LinkNet for *model_path*, loading it on first use.

    Returns ``None`` (after printing an error) when the file does not exist.
    """
    if model_path not in _linknet_models:
        print(f"Loading LinkNet model from: {model_path}")
        if not os.path.exists(model_path):
            print(f"Error: LinkNet model file not found at {model_path}")
            return None

        model = LinkNet(num_classes=1)
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        model.load_state_dict(torch.load(model_path, map_location=_device))
        model.to(_device)
        model.eval()
        _linknet_models[model_path] = model

    return _linknet_models[model_path]


def segment_and_find_endpoints(original_image: np.ndarray, crop_box: tuple, model_path: str, image_size: int = 256):
    """Segment a cropped region with LinkNet and locate the weld-seam endpoints.

    The crop is converted to grayscale, resized to ``image_size`` x
    ``image_size``, scaled by 1/255 and passed through the model. The
    thresholded prediction is resized back to the crop size and the line
    endpoints are extracted from it.

    Args:
        original_image: Source image, either BGR (H, W, 3) or already
            grayscale (H, W).
        crop_box: ``(x1, y1, x2, y2)`` region of interest in image pixels.
            Values are truncated to int and clamped to the image bounds.
        model_path: Path to the LinkNet ``state_dict`` checkpoint.
        image_size: Side length of the square network input.

    Returns:
        ``(start_point, end_point)`` as ``(x, y)`` tuples in the coordinate
        system of ``original_image``, or ``(None, None)`` when the model
        file is missing, the crop box is empty after clamping, or the
        predicted mask contains no foreground pixel.
    """
    model = _load_linknet_model(model_path)
    if model is None:
        return None, None

    # Clamp the crop box to the image: negative coordinates would otherwise
    # wrap around through Python's negative indexing, and an empty crop
    # would make OpenCV raise instead of following the (None, None)
    # failure convention used elsewhere in this function.
    img_h, img_w = original_image.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in crop_box)
    x1, x2 = max(0, x1), min(img_w, x2)
    y1, y2 = max(0, y1), min(img_h, y2)
    if x2 <= x1 or y2 <= y1:
        print(f"Error: crop box {crop_box} is empty or outside the image")
        return None, None

    cropped_img = original_image[y1:y2, x1:x2]
    if cropped_img.ndim == 2:
        # Already single-channel; cvtColor(BGR2GRAY) would reject it.
        img_gray = np.ascontiguousarray(cropped_img)
    else:
        img_gray = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)
    crop_h, crop_w = img_gray.shape

    img_resized = cv2.resize(img_gray, (image_size, image_size))
    # Assumes 8-bit pixel data (0..255) - TODO confirm against callers.
    img_normalized = img_resized / 255.0
    # Add batch and channel axes -> (1, 1, image_size, image_size).
    img_tensor = (
        torch.from_numpy(img_normalized)
        .unsqueeze(0)
        .unsqueeze(0)
        .float()
        .to(_device)
    )

    with torch.no_grad():
        output = model(img_tensor)

    pred_mask_resized = output.cpu().numpy()[0, 0]
    # NOTE(review): the 0.5 threshold presumes the model already outputs
    # probabilities (e.g. a sigmoid inside LinkNet) - confirm in the model.
    pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)
    # Nearest-neighbour keeps the mask strictly binary when scaling back.
    predicted_mask = cv2.resize(
        pred_mask_binary, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST
    ) * 255

    start_crop, end_crop = _get_endpoints_from_mask(predicted_mask)
    if start_crop is None:
        return None, None

    # Shift from crop coordinates back to the original image frame.
    start_orig = (start_crop[0] + x1, start_crop[1] + y1)
    end_orig = (end_crop[0] + x1, end_crop[1] + y1)
    return start_orig, end_orig