三维重构终版

This commit is contained in:
2025-11-02 21:36:35 +08:00
parent f91b09da9d
commit f39009b853
126 changed files with 2870 additions and 2 deletions

View File

@@ -0,0 +1,74 @@
import os
import cv2
import torch
import numpy as np
# 导入我们刚刚创建的模型定义
from .linknet_model_def import LinkNet
# 模型缓存
_linknet_models = {}
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"LinkNet will use device: {_device}")
def _get_endpoints_from_mask(mask: np.ndarray):
"""内部函数从二值化mask中提取直线端点。"""
points = cv2.findNonZero(mask)
if points is None:
return None, None
line_params = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01)
direction_vector = np.array([line_params[0][0], line_params[1][0]])
points_flat = points.reshape(-1, 2)
projections = points_flat.dot(direction_vector)
min_idx, max_idx = np.argmin(projections), np.argmax(projections)
start_point, end_point = tuple(points_flat[min_idx]), tuple(points_flat[max_idx])
return start_point, end_point
def segment_and_find_endpoints(original_image: np.ndarray,
crop_box: tuple,
model_path: str,
image_size: int = 256):
"""
在指定的裁切区域内使用LinkNet进行分割并找出焊缝端点。
返回原始图像坐标系下的 (start_point, end_point)。
"""
if model_path not in _linknet_models:
print(f"Loading LinkNet model from: {model_path}")
if not os.path.exists(model_path):
print(f"Error: LinkNet model file not found at {model_path}")
return None, None
model = LinkNet(num_classes=1)
model.load_state_dict(torch.load(model_path, map_location=_device))
model.to(_device)
model.eval()
_linknet_models[model_path] = model
model = _linknet_models[model_path]
x1, y1, x2, y2 = crop_box
cropped_img = original_image[y1:y2, x1:x2]
img_gray = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)
crop_h, crop_w = img_gray.shape
img_resized = cv2.resize(img_gray, (image_size, image_size))
img_normalized = img_resized / 255.0
img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float().to(_device)
with torch.no_grad():
output = model(img_tensor)
pred_mask_resized = output.cpu().numpy()[0, 0]
pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)
predicted_mask = cv2.resize(pred_mask_binary, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST) * 255
start_crop, end_crop = _get_endpoints_from_mask(predicted_mask)
if start_crop is None:
return None, None
start_orig = (start_crop[0] + x1, start_crop[1] + y1)
end_orig = (end_crop[0] + x1, end_crop[1] + y1)
return start_orig, end_orig