Files
innovate_project/unet/generate_unet.py
2025-11-02 21:36:35 +08:00

127 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import cv2
import numpy as np
from PIL import Image
def _find_shape(shapes, label, shape_type):
    """Return the first shape matching ``label`` and ``shape_type``, or None."""
    return next(
        (s for s in shapes if s['label'] == label and s['shape_type'] == shape_type),
        None,
    )


def _render_line_mask(points, offset, size, thickness):
    """Draw one annotated line as a white stroke on a black uint8 mask.

    ``points`` are the two line endpoints in original-image coordinates,
    ``offset`` is the (left, upper) corner of the crop, and ``size`` is the
    (width, height) of the cropped image.
    """
    width, height = size
    # Shift the endpoints into the cropped image's coordinate system.
    shifted = np.array(points) - offset
    mask = np.zeros((height, width), dtype=np.uint8)
    # Convert to plain Python ints: some OpenCV builds refuse NumPy integer
    # scalars inside a point tuple. int() truncates like astype(int) did.
    pt1 = tuple(int(v) for v in shifted[0])
    pt2 = tuple(int(v) for v in shifted[1])
    cv2.line(mask, pt1, pt2, color=255, thickness=thickness)
    return mask


def _write_mask(path, mask):
    """Write ``mask`` to ``path``; raise OSError if OpenCV reports failure."""
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(path, mask):
        raise OSError(f"无法写入掩码文件: {path}")


def generate_unet_dataset(original_image_dir, annotation_dir, output_dir, line_thickness=12):
    """
    Build a UNet training dataset from images and LabelMe-style JSON annotations.

    For every image, the matching annotation file is looked up, the image is
    cropped to the annotated rectangle (label "1"), and one binary mask is
    rendered for each of the two annotated lines (labels "2" and "3").
    Cropped images are written to ``<output_dir>/images`` and the masks to
    ``<output_dir>/masks_line1`` and ``<output_dir>/masks_line2`` as PNG files.

    Images whose annotation file is missing, or whose annotation lacks one of
    the three shapes, are skipped with a warning. Any other per-image error is
    printed and processing continues with the next image.

    Args:
        original_image_dir (str): Folder with the source images, named by
            number (000.jpg, 001.jpg, ...).
        annotation_dir (str): Folder with the JSON annotation files
            (l1.json, r1.json, ...).
        output_dir (str): Folder that receives the cropped images and masks.
        line_thickness (int): Stroke thickness used when drawing the lines
            on the masks.
    """
    # 1. Create the output sub-folders.
    cropped_img_path = os.path.join(output_dir, 'images')
    mask1_path = os.path.join(output_dir, 'masks_line1')
    mask2_path = os.path.join(output_dir, 'masks_line2')
    for folder in (cropped_img_path, mask1_path, mask2_path):
        os.makedirs(folder, exist_ok=True)
    print(f"输出文件夹已创建于: {output_dir}")

    # 2. Walk over all source images in sorted order.
    image_files = sorted(
        f for f in os.listdir(original_image_dir)
        if f.endswith(('.jpg', '.jpeg', '.png'))
    )
    for image_filename in image_files:
        try:
            # Numeric index of the image, e.g. '007.jpg' -> 7.
            base_filename, _ = os.path.splitext(image_filename)
            file_index = int(base_filename)

            # 3. Derive the annotation file name from the index.
            # Indices 0-49 map to side 'l', everything else to side 'r'.
            side = 'l' if file_index < 50 else 'r'
            # Groups of five images share one annotation, numbered by the
            # formula below: 0-4 -> 1, 5-9 -> 6, 10-14 -> 11, ...
            # NOTE(review): an earlier comment claimed "5-9 -> 2"; confirm the
            # annotation files really are named l1, l6, l11, ...
            label_index = (file_index % 50) // 5 * 5 + 1
            json_filename = f"{side}{label_index}.json"
            json_filepath = os.path.join(annotation_dir, json_filename)
            if not os.path.exists(json_filepath):
                print(f"警告:找不到图片 '{image_filename}' 对应的标注文件 '{json_filename}',已跳过。")
                continue

            # 4. Load the annotation and pick out the three required shapes.
            with open(json_filepath, 'r') as f:
                data = json.load(f)
            shapes = data['shapes']
            rect_shape = _find_shape(shapes, '1', 'rectangle')
            line1_shape = _find_shape(shapes, '2', 'line')
            line2_shape = _find_shape(shapes, '3', 'line')
            if rect_shape is None or line1_shape is None or line2_shape is None:
                # The message names the labels the lookup above actually uses.
                print(f"警告:标注文件 '{json_filename}' 中缺少必要的形状(矩形'1', 直线'2''3'),已跳过。")
                continue

            # 5. Crop the image to the annotated rectangle.
            # Normalise the two corner points to (left, upper, right, lower).
            p1, p2 = rect_shape['points']
            left = int(min(p1[0], p2[0]))
            upper = int(min(p1[1], p2[1]))
            right = int(max(p1[0], p2[0]))
            lower = int(max(p1[1], p2[1]))

            image_path = os.path.join(original_image_dir, image_filename)
            # The context manager releases the source file handle; the crop
            # is an independent image that stays valid afterwards.
            with Image.open(image_path) as original_image:
                cropped_image = original_image.crop((left, upper, right, lower))
            cropped_image.save(os.path.join(cropped_img_path, os.path.basename(image_filename)))

            # 6. Render and save one mask per line, sized like the crop.
            crop_size = cropped_image.size
            crop_origin = [left, upper]
            mask1 = _render_line_mask(line1_shape['points'], crop_origin, crop_size, line_thickness)
            mask2 = _render_line_mask(line2_shape['points'], crop_origin, crop_size, line_thickness)

            # Masks are always stored as PNG, whatever the source extension.
            png_filename = base_filename + ".png"
            _write_mask(os.path.join(mask1_path, png_filename), mask1)
            _write_mask(os.path.join(mask2_path, png_filename), mask2)
            print(f"成功处理: {image_filename} -> {json_filename}")
        except Exception as e:
            # Deliberate best-effort: one bad image must not abort the batch.
            print(f"处理图片 '{image_filename}' 时发生错误: {e}")
    print("\n所有图片处理完成!")
if __name__ == "__main__":
    # Script entry point: build the dataset from the "up" images and their
    # annotations, writing the results into the local "data_up" folder.
    generate_unet_dataset(
        original_image_dir="../label/up",
        annotation_dir="../label/up_json",
        output_dir="data_up",
    )