127 lines
5.4 KiB
Python
127 lines
5.4 KiB
Python
import os
|
||
import json
|
||
import cv2
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
|
||
def generate_unet_dataset(original_image_dir, annotation_dir, output_dir, line_thickness=12):
|
||
"""
|
||
根据给定的原始图片和LabelMe格式的JSON标注,生成UNet训练数据集。
|
||
|
||
该函数会首先根据标注中的矩形(label: "3")裁剪图片,
|
||
然后根据两条直线(label: "1" 和 "2")生成对应的二值化掩码(mask)。
|
||
|
||
Args:
|
||
original_image_dir (str): 包含原始图片 (000.jpg, 001.jpg, ...) 的文件夹路径。
|
||
annotation_dir (str): 包含JSON标注文件 (l1.json, r1.json, ...) 的文件夹路径。
|
||
output_dir (str): 用于存放处理后数据(图片和掩码)的输出文件夹路径。
|
||
line_thickness (int): 在掩码上绘制直线的粗细度。
|
||
"""
|
||
# 1. 在输出目录中创建子文件夹
|
||
cropped_img_path = os.path.join(output_dir, 'images')
|
||
mask1_path = os.path.join(output_dir, 'masks_line1')
|
||
mask2_path = os.path.join(output_dir, 'masks_line2')
|
||
|
||
os.makedirs(cropped_img_path, exist_ok=True)
|
||
os.makedirs(mask1_path, exist_ok=True)
|
||
os.makedirs(mask2_path, exist_ok=True)
|
||
|
||
print(f"输出文件夹已创建于: {output_dir}")
|
||
|
||
# 2. 遍历所有原始图片
|
||
image_files = sorted([f for f in os.listdir(original_image_dir) if f.endswith(('.jpg', '.jpeg', '.png'))])
|
||
|
||
for image_filename in image_files:
|
||
try:
|
||
# 提取图片编号,例如从 '007.jpg' 中得到 7
|
||
file_index = int(os.path.splitext(image_filename)[0])
|
||
|
||
# 3. 根据命名规则确定对应的JSON文件名
|
||
# 0-49 -> l, 50-99 -> r
|
||
side = 'l' if file_index < 50 else 'r'
|
||
# 0-4 -> 1, 5-9 -> 2 ...
|
||
label_index = (file_index % 50) // 5 * 5 + 1
|
||
json_filename = f"{side}{label_index}.json"
|
||
json_filepath = os.path.join(annotation_dir, json_filename)
|
||
|
||
if not os.path.exists(json_filepath):
|
||
print(f"警告:找不到图片 '{image_filename}' 对应的标注文件 '{json_filename}',已跳过。")
|
||
continue
|
||
|
||
# 4. 读取并解析JSON文件
|
||
with open(json_filepath, 'r') as f:
|
||
data = json.load(f)
|
||
|
||
# 提取标注形状
|
||
shapes = data['shapes']
|
||
rect_shape = next((s for s in shapes if s['label'] == '1' and s['shape_type'] == 'rectangle'), None)
|
||
line1_shape = next((s for s in shapes if s['label'] == '2' and s['shape_type'] == 'line'), None)
|
||
line2_shape = next((s for s in shapes if s['label'] == '3' and s['shape_type'] == 'line'), None)
|
||
|
||
if not all([rect_shape, line1_shape, line2_shape]):
|
||
print(f"警告:标注文件 '{json_filename}' 中缺少必要的形状(矩形'3', 直线'1'或'2'),已跳过。")
|
||
continue
|
||
|
||
# 5. 图像裁剪
|
||
# 读取原始图片
|
||
image_path = os.path.join(original_image_dir, image_filename)
|
||
original_image = Image.open(image_path)
|
||
|
||
# 获取矩形坐标
|
||
p1, p2 = rect_shape['points']
|
||
# 标准化矩形坐标 (min_x, min_y, max_x, max_y)
|
||
left = int(min(p1[0], p2[0]))
|
||
upper = int(min(p1[1], p2[1]))
|
||
right = int(max(p1[0], p2[0]))
|
||
lower = int(max(p1[1], p2[1]))
|
||
|
||
# 执行裁剪
|
||
cropped_image = original_image.crop((left, upper, right, lower))
|
||
|
||
# 保存裁剪后的图片
|
||
cropped_image_filename = os.path.join(cropped_img_path, os.path.basename(image_filename))
|
||
cropped_image.save(cropped_image_filename)
|
||
|
||
# 6. 生成并保存掩码
|
||
# 获取裁剪后图片的尺寸
|
||
width, height = cropped_image.size
|
||
|
||
# 调整直线坐标系,使其与裁剪后的图片对应
|
||
line1_points = np.array(line1_shape['points']) - [left, upper]
|
||
line2_points = np.array(line2_shape['points']) - [left, upper]
|
||
|
||
# 创建两个空的黑色背景(掩码)
|
||
mask1 = np.zeros((height, width), dtype=np.uint8)
|
||
mask2 = np.zeros((height, width), dtype=np.uint8)
|
||
|
||
# 在掩码上绘制白色线条
|
||
# OpenCV的line函数需要整数坐标
|
||
pt1_l1 = tuple(line1_points[0].astype(int))
|
||
pt2_l1 = tuple(line1_points[1].astype(int))
|
||
cv2.line(mask1, pt1_l1, pt2_l1, color=255, thickness=line_thickness)
|
||
|
||
pt1_l2 = tuple(line2_points[0].astype(int))
|
||
pt2_l2 = tuple(line2_points[1].astype(int))
|
||
cv2.line(mask2, pt1_l2, pt2_l2, color=255, thickness=line_thickness)
|
||
|
||
# 保存掩码为PNG格式
|
||
# 使用splitext来获取不带扩展名的文件名
|
||
base_filename, _ = os.path.splitext(image_filename)
|
||
png_filename = base_filename + ".png"
|
||
|
||
mask1_savename = os.path.join(mask1_path, png_filename)
|
||
mask2_savename = os.path.join(mask2_path, png_filename)
|
||
|
||
cv2.imwrite(mask1_savename, mask1)
|
||
cv2.imwrite(mask2_savename, mask2)
|
||
|
||
print(f"成功处理: {image_filename} -> {json_filename}")
|
||
|
||
except Exception as e:
|
||
print(f"处理图片 '{image_filename}' 时发生错误: {e}")
|
||
|
||
print("\n所有图片处理完成!")
|
||
|
||
if __name__ =="__main__":
|
||
generate_unet_dataset("../label/up","../label/up_json", "data_up") |