三维重构终版

This commit is contained in:
2025-11-02 21:36:35 +08:00
parent f91b09da9d
commit f39009b853
126 changed files with 2870 additions and 2 deletions

127
unet/generate_unet.py Normal file
View File

@@ -0,0 +1,127 @@
import os
import json
import cv2
import numpy as np
from PIL import Image
def generate_unet_dataset(original_image_dir, annotation_dir, output_dir, line_thickness=12):
"""
根据给定的原始图片和LabelMe格式的JSON标注生成UNet训练数据集。
该函数会首先根据标注中的矩形label: "3")裁剪图片,
然后根据两条直线label: "1""2"生成对应的二值化掩码mask
Args:
original_image_dir (str): 包含原始图片 (000.jpg, 001.jpg, ...) 的文件夹路径。
annotation_dir (str): 包含JSON标注文件 (l1.json, r1.json, ...) 的文件夹路径。
output_dir (str): 用于存放处理后数据(图片和掩码)的输出文件夹路径。
line_thickness (int): 在掩码上绘制直线的粗细度。
"""
# 1. 在输出目录中创建子文件夹
cropped_img_path = os.path.join(output_dir, 'images')
mask1_path = os.path.join(output_dir, 'masks_line1')
mask2_path = os.path.join(output_dir, 'masks_line2')
os.makedirs(cropped_img_path, exist_ok=True)
os.makedirs(mask1_path, exist_ok=True)
os.makedirs(mask2_path, exist_ok=True)
print(f"输出文件夹已创建于: {output_dir}")
# 2. 遍历所有原始图片
image_files = sorted([f for f in os.listdir(original_image_dir) if f.endswith(('.jpg', '.jpeg', '.png'))])
for image_filename in image_files:
try:
# 提取图片编号,例如从 '007.jpg' 中得到 7
file_index = int(os.path.splitext(image_filename)[0])
# 3. 根据命名规则确定对应的JSON文件名
# 0-49 -> l, 50-99 -> r
side = 'l' if file_index < 50 else 'r'
# 0-4 -> 1, 5-9 -> 2 ...
label_index = (file_index % 50) // 5 * 5 + 1
json_filename = f"{side}{label_index}.json"
json_filepath = os.path.join(annotation_dir, json_filename)
if not os.path.exists(json_filepath):
print(f"警告:找不到图片 '{image_filename}' 对应的标注文件 '{json_filename}',已跳过。")
continue
# 4. 读取并解析JSON文件
with open(json_filepath, 'r') as f:
data = json.load(f)
# 提取标注形状
shapes = data['shapes']
rect_shape = next((s for s in shapes if s['label'] == '1' and s['shape_type'] == 'rectangle'), None)
line1_shape = next((s for s in shapes if s['label'] == '2' and s['shape_type'] == 'line'), None)
line2_shape = next((s for s in shapes if s['label'] == '3' and s['shape_type'] == 'line'), None)
if not all([rect_shape, line1_shape, line2_shape]):
print(f"警告:标注文件 '{json_filename}' 中缺少必要的形状(矩形'3', 直线'1''2'),已跳过。")
continue
# 5. 图像裁剪
# 读取原始图片
image_path = os.path.join(original_image_dir, image_filename)
original_image = Image.open(image_path)
# 获取矩形坐标
p1, p2 = rect_shape['points']
# 标准化矩形坐标 (min_x, min_y, max_x, max_y)
left = int(min(p1[0], p2[0]))
upper = int(min(p1[1], p2[1]))
right = int(max(p1[0], p2[0]))
lower = int(max(p1[1], p2[1]))
# 执行裁剪
cropped_image = original_image.crop((left, upper, right, lower))
# 保存裁剪后的图片
cropped_image_filename = os.path.join(cropped_img_path, os.path.basename(image_filename))
cropped_image.save(cropped_image_filename)
# 6. 生成并保存掩码
# 获取裁剪后图片的尺寸
width, height = cropped_image.size
# 调整直线坐标系,使其与裁剪后的图片对应
line1_points = np.array(line1_shape['points']) - [left, upper]
line2_points = np.array(line2_shape['points']) - [left, upper]
# 创建两个空的黑色背景(掩码)
mask1 = np.zeros((height, width), dtype=np.uint8)
mask2 = np.zeros((height, width), dtype=np.uint8)
# 在掩码上绘制白色线条
# OpenCV的line函数需要整数坐标
pt1_l1 = tuple(line1_points[0].astype(int))
pt2_l1 = tuple(line1_points[1].astype(int))
cv2.line(mask1, pt1_l1, pt2_l1, color=255, thickness=line_thickness)
pt1_l2 = tuple(line2_points[0].astype(int))
pt2_l2 = tuple(line2_points[1].astype(int))
cv2.line(mask2, pt1_l2, pt2_l2, color=255, thickness=line_thickness)
# 保存掩码为PNG格式
# 使用splitext来获取不带扩展名的文件名
base_filename, _ = os.path.splitext(image_filename)
png_filename = base_filename + ".png"
mask1_savename = os.path.join(mask1_path, png_filename)
mask2_savename = os.path.join(mask2_path, png_filename)
cv2.imwrite(mask1_savename, mask1)
cv2.imwrite(mask2_savename, mask2)
print(f"成功处理: {image_filename} -> {json_filename}")
except Exception as e:
print(f"处理图片 '{image_filename}' 时发生错误: {e}")
print("\n所有图片处理完成!")
if __name__ =="__main__":
generate_unet_dataset("../label/up","../label/up_json", "data_up")