Files
innovate_project/OpenCV/convert_to_voc.py
2025-11-02 21:36:35 +08:00

192 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import xml.etree.ElementTree as ET
from xml.dom import minidom
from tqdm import tqdm
import re # 引入正则表达式库
# --- Configuration ---
# 1. Folder holding the source labelme JSON files (example path; change it).
json_folder: str = '../label/up_json'
# 2. Folder holding the source images (example path; change it). Each target
#    image must exist here before its XML is written; the image size itself is
#    read from the JSON file, not from the image.
image_folder: str = '../label/up'
# 3. Folder the generated Pascal VOC XML files are written to.
output_xml_folder: str = '../label/up_xml'
# 4. Class name written into every <object><name> element.
#    NOTE(review): the variable name refers to label "3", but the conversion
#    loop exports shapes whose label is '1' -- confirm which label is intended.
class_name_for_label_3: str = "Space weld workpiece"
# 5. Number of consecutively numbered frames (e.g. l1..l5) that share the
#    annotations of the group's first JSON file.
group_size: int = 5
# --- End of configuration ---
def create_xml_annotation(image_info, objects_info):
    """Build the Pascal VOC ``<annotation>`` element for a single image.

    :param image_info: dict with ``'filename'``, ``'width'`` and ``'height'``
        keys, plus an optional ``'depth'`` (defaults to 3).
    :param objects_info: list of dicts, each with ``'name'``, ``'xmin'``,
        ``'ymin'``, ``'xmax'`` and ``'ymax'``. Coordinates are converted with
        ``int()``, i.e. truncated toward zero.
    :return: the root ``xml.etree.ElementTree.Element`` (not an ElementTree).
    """
    def add(parent, tag, text=None):
        # Append <tag> under parent, optionally setting its text, and return it.
        child = ET.SubElement(parent, tag)
        if text is not None:
            child.text = text
        return child

    root = ET.Element('annotation')
    add(root, 'folder', 'JPEGImages')
    add(root, 'filename', image_info['filename'])
    # The path is written relative to a sibling JPEGImages folder.
    add(root, 'path', os.path.join('..', 'JPEGImages', image_info['filename']))
    add(add(root, 'source'), 'database', 'Unknown')

    size_node = add(root, 'size')
    for dim in ('width', 'height'):
        add(size_node, dim, str(image_info[dim]))
    add(size_node, 'depth', str(image_info.get('depth', 3)))

    add(root, 'segmented', '0')

    # One <object> element per box, with fixed pose/truncated/difficult values.
    for item in objects_info:
        obj_node = add(root, 'object')
        add(obj_node, 'name', item['name'])
        add(obj_node, 'pose', 'Unspecified')
        add(obj_node, 'truncated', '0')
        add(obj_node, 'difficult', '0')
        box = add(obj_node, 'bndbox')
        for corner in ('xmin', 'ymin', 'xmax', 'ymax'):
            add(box, corner, str(int(item[corner])))
    return root
def prettify_xml(elem):
    """Return *elem* serialized as an indented XML string.

    The element is encoded to UTF-8 bytes, re-parsed with ``minidom`` and
    pretty-printed; the result is a ``str`` that begins with an XML declaration.
    """
    return minidom.parseString(ET.tostring(elem, 'utf-8')).toprettyxml(indent=" ")
# Group file names look like "<letters><digits>", e.g. l1, r6, img006.
# Compiled once; matched with fullmatch so that names such as "cam1_3" or
# "l1 (2)" are not mistaken for the start of group cam1.. / l1..
_GROUP_NAME_RE = re.compile(r'([a-zA-Z]+)(\d+)')


def _group_image_stems(base_name, size):
    """Return the image stems a JSON file's boxes apply to ([] if it is skipped).

    * ``<letters><digits>`` names belong to groups of ``size`` consecutive
      numbers starting at 1 (1, 1+size, 1+2*size, ...). The group's first file
      yields ``size`` stems; any other member yields ``[]`` because its images
      are covered by the group's first file.
    * Any other name is a group of its own and yields ``[base_name]``.
    """
    match = _GROUP_NAME_RE.fullmatch(base_name)
    if match is None:
        return [base_name]
    prefix, digits = match.groups()
    start = int(digits)
    if (start - 1) % size != 0:
        return []
    # Keep the original zero padding: img006 -> img006 .. img010. For
    # unpadded names (l1, l6) zfill never changes the number.
    width = len(digits)
    return [f"{prefix}{str(start + i).zfill(width)}" for i in range(size)]


def _extract_boxes(data, class_name):
    """Collect the exported rectangles of a labelme document as VOC object dicts.

    Only shapes with ``label == '1'`` and ``shape_type == 'rectangle'`` that
    carry exactly two corner points are kept; corners are sorted so that
    xmin <= xmax and ymin <= ymax regardless of drawing direction.
    """
    boxes = []
    for shape in data.get('shapes', []):
        # NOTE(review): the class-name constant is named for label "3", but the
        # filter below uses label '1' -- confirm which label is intended.
        if shape.get('label') != '1' or shape.get('shape_type') != 'rectangle':
            continue
        points = shape.get('points', [])
        if len(points) != 2:
            continue
        xs = sorted(p[0] for p in points)
        ys = sorted(p[1] for p in points)
        boxes.append({
            'name': class_name,
            'xmin': xs[0], 'ymin': ys[0],
            'xmax': xs[1], 'ymax': ys[1],
        })
    return boxes


def main():
    """Convert labelme JSON rectangles in ``json_folder`` to Pascal VOC XML.

    Each group start file (see ``_group_image_stems``) is read once. Its
    exported boxes (see ``_extract_boxes``) are then written as
    ``<stem>.xml`` into ``output_xml_folder`` for every target image
    ``<stem>.jpg`` that exists in ``image_folder``.

    The image size for every image in a group comes from the start file's
    ``imageWidth`` / ``imageHeight``. Files with no exported boxes produce no
    XML, and missing images are reported and skipped.
    """
    if not os.path.exists(output_xml_folder):
        os.makedirs(output_xml_folder)
        print(f"创建输出文件夹: {output_xml_folder}")

    json_files = sorted(f for f in os.listdir(json_folder) if f.endswith('.json'))
    print(f"找到 {len(json_files)} 个JSON文件开始转换...")

    for json_file in tqdm(json_files, desc="处理JSON文件"):
        base_name = os.path.splitext(json_file)[0]
        image_stems = _group_image_stems(base_name, group_size)
        if not image_stems:
            # Non-first group member (e.g. l2, l3): handled by the group start.
            continue

        json_path = os.path.join(json_folder, json_file)
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        objects_to_write = _extract_boxes(data, class_name_for_label_3)
        if not objects_to_write:
            continue

        for stem in image_stems:
            # Images are assumed to be .jpg.
            image_name = f"{stem}.jpg"
            image_path = os.path.join(image_folder, image_name)
            if not os.path.exists(image_path):
                print(f"\n警告:找不到图片 '{image_name}'跳过生成其XML文件。")
                continue
            image_info = {
                'filename': image_name,
                'width': data['imageWidth'],
                'height': data['imageHeight'],
            }
            xml_string = prettify_xml(create_xml_annotation(image_info, objects_to_write))
            output_path = os.path.join(output_xml_folder, f"{stem}.xml")
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(xml_string)

    print("转换完成所有XML文件已保存在: ", output_xml_folder)


if __name__ == '__main__':
    main()