三维重构终版
This commit is contained in:
74
3D_construction/script/linknet_segmentor.py
Normal file
74
3D_construction/script/linknet_segmentor.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# 导入我们刚刚创建的模型定义
|
||||
from .linknet_model_def import LinkNet
|
||||
|
||||
# 模型缓存
|
||||
_linknet_models = {}
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"LinkNet will use device: {_device}")
|
||||
|
||||
|
||||
def _get_endpoints_from_mask(mask: np.ndarray):
|
||||
"""内部函数:从二值化mask中提取直线端点。"""
|
||||
points = cv2.findNonZero(mask)
|
||||
if points is None:
|
||||
return None, None
|
||||
line_params = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01)
|
||||
direction_vector = np.array([line_params[0][0], line_params[1][0]])
|
||||
points_flat = points.reshape(-1, 2)
|
||||
projections = points_flat.dot(direction_vector)
|
||||
min_idx, max_idx = np.argmin(projections), np.argmax(projections)
|
||||
start_point, end_point = tuple(points_flat[min_idx]), tuple(points_flat[max_idx])
|
||||
return start_point, end_point
|
||||
|
||||
|
||||
def segment_and_find_endpoints(original_image: np.ndarray,
|
||||
crop_box: tuple,
|
||||
model_path: str,
|
||||
image_size: int = 256):
|
||||
"""
|
||||
在指定的裁切区域内使用LinkNet进行分割,并找出焊缝端点。
|
||||
返回原始图像坐标系下的 (start_point, end_point)。
|
||||
"""
|
||||
if model_path not in _linknet_models:
|
||||
print(f"Loading LinkNet model from: {model_path}")
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: LinkNet model file not found at {model_path}")
|
||||
return None, None
|
||||
model = LinkNet(num_classes=1)
|
||||
model.load_state_dict(torch.load(model_path, map_location=_device))
|
||||
model.to(_device)
|
||||
model.eval()
|
||||
_linknet_models[model_path] = model
|
||||
|
||||
model = _linknet_models[model_path]
|
||||
|
||||
x1, y1, x2, y2 = crop_box
|
||||
cropped_img = original_image[y1:y2, x1:x2]
|
||||
|
||||
img_gray = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)
|
||||
crop_h, crop_w = img_gray.shape
|
||||
img_resized = cv2.resize(img_gray, (image_size, image_size))
|
||||
img_normalized = img_resized / 255.0
|
||||
img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float().to(_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(img_tensor)
|
||||
|
||||
pred_mask_resized = output.cpu().numpy()[0, 0]
|
||||
pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)
|
||||
predicted_mask = cv2.resize(pred_mask_binary, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST) * 255
|
||||
|
||||
start_crop, end_crop = _get_endpoints_from_mask(predicted_mask)
|
||||
|
||||
if start_crop is None:
|
||||
return None, None
|
||||
|
||||
start_orig = (start_crop[0] + x1, start_crop[1] + y1)
|
||||
end_orig = (end_crop[0] + x1, end_crop[1] + y1)
|
||||
|
||||
return start_orig, end_orig
|
||||
Reference in New Issue
Block a user