import os import json import cv2 import numpy as np from PIL import Image def generate_unet_dataset(original_image_dir, annotation_dir, output_dir, line_thickness=12): """ 根据给定的原始图片和LabelMe格式的JSON标注,生成UNet训练数据集。 该函数会首先根据标注中的矩形(label: "3")裁剪图片, 然后根据两条直线(label: "1" 和 "2")生成对应的二值化掩码(mask)。 Args: original_image_dir (str): 包含原始图片 (000.jpg, 001.jpg, ...) 的文件夹路径。 annotation_dir (str): 包含JSON标注文件 (l1.json, r1.json, ...) 的文件夹路径。 output_dir (str): 用于存放处理后数据(图片和掩码)的输出文件夹路径。 line_thickness (int): 在掩码上绘制直线的粗细度。 """ # 1. 在输出目录中创建子文件夹 cropped_img_path = os.path.join(output_dir, 'images') mask1_path = os.path.join(output_dir, 'masks_line1') mask2_path = os.path.join(output_dir, 'masks_line2') os.makedirs(cropped_img_path, exist_ok=True) os.makedirs(mask1_path, exist_ok=True) os.makedirs(mask2_path, exist_ok=True) print(f"输出文件夹已创建于: {output_dir}") # 2. 遍历所有原始图片 image_files = sorted([f for f in os.listdir(original_image_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]) for image_filename in image_files: try: # 提取图片编号,例如从 '007.jpg' 中得到 7 file_index = int(os.path.splitext(image_filename)[0]) # 3. 根据命名规则确定对应的JSON文件名 # 0-49 -> l, 50-99 -> r side = 'l' if file_index < 50 else 'r' # 0-4 -> 1, 5-9 -> 2 ... label_index = (file_index % 50) // 5 * 5 + 1 json_filename = f"{side}{label_index}.json" json_filepath = os.path.join(annotation_dir, json_filename) if not os.path.exists(json_filepath): print(f"警告:找不到图片 '{image_filename}' 对应的标注文件 '{json_filename}',已跳过。") continue # 4. 读取并解析JSON文件 with open(json_filepath, 'r') as f: data = json.load(f) # 提取标注形状 shapes = data['shapes'] rect_shape = next((s for s in shapes if s['label'] == '1' and s['shape_type'] == 'rectangle'), None) line1_shape = next((s for s in shapes if s['label'] == '2' and s['shape_type'] == 'line'), None) line2_shape = next((s for s in shapes if s['label'] == '3' and s['shape_type'] == 'line'), None) if not all([rect_shape, line1_shape, line2_shape]): print(f"警告:标注文件 '{json_filename}' 中缺少必要的形状(矩形'3', 直线'1'或'2'),已跳过。") continue # 5. 图像裁剪 # 读取原始图片 image_path = os.path.join(original_image_dir, image_filename) original_image = Image.open(image_path) # 获取矩形坐标 p1, p2 = rect_shape['points'] # 标准化矩形坐标 (min_x, min_y, max_x, max_y) left = int(min(p1[0], p2[0])) upper = int(min(p1[1], p2[1])) right = int(max(p1[0], p2[0])) lower = int(max(p1[1], p2[1])) # 执行裁剪 cropped_image = original_image.crop((left, upper, right, lower)) # 保存裁剪后的图片 cropped_image_filename = os.path.join(cropped_img_path, os.path.basename(image_filename)) cropped_image.save(cropped_image_filename) # 6. 生成并保存掩码 # 获取裁剪后图片的尺寸 width, height = cropped_image.size # 调整直线坐标系,使其与裁剪后的图片对应 line1_points = np.array(line1_shape['points']) - [left, upper] line2_points = np.array(line2_shape['points']) - [left, upper] # 创建两个空的黑色背景(掩码) mask1 = np.zeros((height, width), dtype=np.uint8) mask2 = np.zeros((height, width), dtype=np.uint8) # 在掩码上绘制白色线条 # OpenCV的line函数需要整数坐标 pt1_l1 = tuple(line1_points[0].astype(int)) pt2_l1 = tuple(line1_points[1].astype(int)) cv2.line(mask1, pt1_l1, pt2_l1, color=255, thickness=line_thickness) pt1_l2 = tuple(line2_points[0].astype(int)) pt2_l2 = tuple(line2_points[1].astype(int)) cv2.line(mask2, pt1_l2, pt2_l2, color=255, thickness=line_thickness) # 保存掩码为PNG格式 # 使用splitext来获取不带扩展名的文件名 base_filename, _ = os.path.splitext(image_filename) png_filename = base_filename + ".png" mask1_savename = os.path.join(mask1_path, png_filename) mask2_savename = os.path.join(mask2_path, png_filename) cv2.imwrite(mask1_savename, mask1) cv2.imwrite(mask2_savename, mask2) print(f"成功处理: {image_filename} -> {json_filename}") except Exception as e: print(f"处理图片 '{image_filename}' 时发生错误: {e}") print("\n所有图片处理完成!") if __name__ =="__main__": generate_unet_dataset("../label/up","../label/up_json", "data_up")