三维重构终版
This commit is contained in:
		
							
								
								
									
										127
									
								
								unet/generate_unet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								unet/generate_unet.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,127 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
import cv2
 | 
			
		||||
import numpy as np
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _annotation_filename(file_index):
    """Return the name of the LabelMe JSON file that annotates image *file_index*.

    Indices below 50 use the ``l`` prefix, all other indices use ``r``.
    Within each run of 50 images, every group of five consecutive images
    shares one annotation file whose number is ``group_start + 1``
    (``l1.json``, ``l6.json``, ``l11.json``, ...).
    """
    side = 'l' if file_index < 50 else 'r'
    label_index = (file_index % 50) // 5 * 5 + 1
    return f"{side}{label_index}.json"


def _find_shape(shapes, label, shape_type):
    """Return the first shape matching *label* and *shape_type*, or ``None``."""
    return next(
        (s for s in shapes if s['label'] == label and s['shape_type'] == shape_type),
        None,
    )


def _draw_line_mask(width, height, points, offset, thickness):
    """Return a ``(height, width)`` uint8 mask with one white line on black.

    *points* holds the line's end points in original-image coordinates (only
    the first two are used); *offset* is the crop's ``[left, upper]`` corner,
    subtracted to move the points into the cropped image's coordinate system.
    """
    shifted = np.array(points) - offset
    # cv2.line wants integer pixel coordinates; astype(int) truncates and
    # tolist() turns the NumPy scalars into plain Python ints.
    start = tuple(shifted[0].astype(int).tolist())
    end = tuple(shifted[1].astype(int).tolist())
    mask = np.zeros((height, width), dtype=np.uint8)
    cv2.line(mask, start, end, color=255, thickness=thickness)
    return mask


def generate_unet_dataset(original_image_dir, annotation_dir, output_dir, line_thickness=12):
    """Build a UNet training set from raw images and LabelMe JSON annotations.

    For every image whose file name is a plain number, the matching
    annotation file is looked up, the image is cropped to the annotated
    rectangle (label "1"), and one binary mask is rendered for each of the
    two annotated lines (labels "2" and "3"). Results are written to
    ``output_dir/images``, ``output_dir/masks_line1`` and
    ``output_dir/masks_line2``. Images that cannot be processed are reported
    on stdout and skipped; the function never raises for a single bad image.

    NOTE(review): earlier documentation described the labels as rectangle
    "3" and lines "1"/"2". The labels documented here are the ones the code
    actually looks up -- confirm against the annotation data.

    Args:
        original_image_dir (str): Folder holding the source images
            (000.jpg, 001.jpg, ...). Extensions .jpg/.jpeg/.png are accepted
            in any letter case.
        annotation_dir (str): Folder holding the LabelMe JSON files
            (l1.json, r1.json, ...).
        output_dir (str): Folder that receives the cropped images and masks;
            created if missing.
        line_thickness (int): Thickness, in pixels, of the lines drawn on
            the masks.
    """
    # Labels used in the annotation files for the three required shapes.
    rect_label, line1_label, line2_label = '1', '2', '3'

    # 1. Create the output sub-folders.
    cropped_img_path = os.path.join(output_dir, 'images')
    mask1_path = os.path.join(output_dir, 'masks_line1')
    mask2_path = os.path.join(output_dir, 'masks_line2')
    for folder in (cropped_img_path, mask1_path, mask2_path):
        os.makedirs(folder, exist_ok=True)

    print(f"输出文件夹已创建于: {output_dir}")

    # 2. Collect the source images; compare extensions case-insensitively so
    #    files such as 007.JPG are not silently ignored.
    image_files = sorted(
        f for f in os.listdir(original_image_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    for image_filename in image_files:
        try:
            # The numeric file stem (e.g. '007' -> 7) selects the annotation.
            base_filename = os.path.splitext(image_filename)[0]
            file_index = int(base_filename)

            # 3. Resolve the annotation file for this image.
            json_filename = _annotation_filename(file_index)
            json_filepath = os.path.join(annotation_dir, json_filename)

            if not os.path.exists(json_filepath):
                print(f"警告:找不到图片 '{image_filename}' 对应的标注文件 '{json_filename}',已跳过。")
                continue

            # 4. Load the annotation. JSON is read as UTF-8 explicitly so the
            #    result does not depend on the platform's default encoding.
            with open(json_filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)

            shapes = data['shapes']
            rect_shape = _find_shape(shapes, rect_label, 'rectangle')
            line1_shape = _find_shape(shapes, line1_label, 'line')
            line2_shape = _find_shape(shapes, line2_label, 'line')

            if rect_shape is None or line1_shape is None or line2_shape is None:
                print(f"警告:标注文件 '{json_filename}' 中缺少必要的形状(矩形'{rect_label}', 直线'{line1_label}'或'{line2_label}'),已跳过。")
                continue

            # 5. Crop the image to the annotated rectangle. The two corner
            #    points may come in any order, so normalise them first.
            p1, p2 = rect_shape['points']
            left = int(min(p1[0], p2[0]))
            upper = int(min(p1[1], p2[1]))
            right = int(max(p1[0], p2[0]))
            lower = int(max(p1[1], p2[1]))

            image_path = os.path.join(original_image_dir, image_filename)
            # The context manager releases the file handle even if cropping
            # or saving fails.
            with Image.open(image_path) as original_image:
                cropped_image = original_image.crop((left, upper, right, lower))
                cropped_image.save(os.path.join(cropped_img_path, image_filename))
                width, height = cropped_image.size

            # 6. Render one mask per line, in cropped-image coordinates.
            offset = [left, upper]
            mask1 = _draw_line_mask(width, height, line1_shape['points'], offset, line_thickness)
            mask2 = _draw_line_mask(width, height, line2_shape['points'], offset, line_thickness)

            # Masks are always stored as PNG, named after the image stem.
            png_filename = base_filename + ".png"
            for folder, mask in ((mask1_path, mask1), (mask2_path, mask2)):
                mask_savename = os.path.join(folder, png_filename)
                # cv2.imwrite reports failure through its return value; turn
                # that into an error so it is not logged as a success.
                if not cv2.imwrite(mask_savename, mask):
                    raise OSError(f"cv2.imwrite failed for '{mask_savename}'")

            print(f"成功处理: {image_filename} -> {json_filename}")

        except Exception as e:
            # Deliberate per-image boundary: report the problem and carry on
            # with the remaining images instead of aborting the whole run.
            print(f"处理图片 '{image_filename}' 时发生错误: {e}")

    print("\n所有图片处理完成!")
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
    # Script entry point: build the dataset for the "up" image set using
    # paths relative to the current working directory.
    generate_unet_dataset(
        original_image_dir="../label/up",
        annotation_dir="../label/up_json",
        output_dir="data_up",
    )
 | 
			
		||||
		Reference in New Issue
	
	Block a user