119 lines
4.0 KiB
Python
119 lines
4.0 KiB
Python
import xml.etree.ElementTree as ET
|
|
import os
|
|
import glob
|
|
|
|
|
|
def voc_to_yolo(xml_file_path, output_dir, class_mapping):
|
|
"""
|
|
将单个 PASCAL VOC anntation (.xml) 文件转换为 YOLO (.txt) 格式。
|
|
|
|
Args:
|
|
xml_file_path (str): 输入的 .xml 文件路径。
|
|
output_dir (str): 输出 .txt 文件的目标文件夹。
|
|
class_mapping (dict): 类别名称到类别ID的映射字典。
|
|
"""
|
|
try:
|
|
# 解析 XML 文件
|
|
tree = ET.parse(xml_file_path)
|
|
root = tree.getroot()
|
|
|
|
# 获取图像尺寸
|
|
size = root.find('size')
|
|
if size is None:
|
|
print(f"警告: 在 {xml_file_path} 中未找到 <size> 标签,跳过此文件。")
|
|
return
|
|
|
|
img_width = int(size.find('width').text)
|
|
img_height = int(size.find('height').text)
|
|
|
|
# 准备用于写入的YOLO标注列表
|
|
yolo_annotations = []
|
|
|
|
# 遍历所有 object
|
|
for obj in root.findall('object'):
|
|
# 获取类别名称
|
|
class_name = obj.find('name').text
|
|
if class_name not in class_mapping:
|
|
print(f"警告: 类别 '{class_name}' 不在预定义的 class_mapping 中,跳过此物体。")
|
|
continue
|
|
|
|
class_id = class_mapping[class_name]
|
|
|
|
# 获取边界框坐标
|
|
bndbox = obj.find('bndbox')
|
|
xmin = float(bndbox.find('xmin').text)
|
|
ymin = float(bndbox.find('ymin').text)
|
|
xmax = float(bndbox.find('xmax').text)
|
|
ymax = float(bndbox.find('ymax').text)
|
|
|
|
# --- 核心转换公式 ---
|
|
x_center = (xmin + xmax) / 2.0 / img_width
|
|
y_center = (ymin + ymax) / 2.0 / img_height
|
|
width = (xmax - xmin) / img_width
|
|
height = (ymax - ymin) / img_height
|
|
|
|
# 将结果添加到列表
|
|
yolo_annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
|
|
|
|
# 如果文件中有有效的物体,则写入 .txt 文件
|
|
if yolo_annotations:
|
|
# 构建输出文件名
|
|
base_filename = os.path.basename(xml_file_path)
|
|
txt_filename = os.path.splitext(base_filename)[0] + '.txt'
|
|
output_path = os.path.join(output_dir, txt_filename)
|
|
|
|
# 写入文件
|
|
with open(output_path, 'w') as f:
|
|
f.write('\n'.join(yolo_annotations))
|
|
|
|
# print(f"成功转换: {xml_file_path} -> {output_path}")
|
|
|
|
except Exception as e:
|
|
print(f"处理文件 {xml_file_path} 时发生错误: {e}")
|
|
|
|
|
|
def main():
|
|
# --- 用户需要配置的参数 ---
|
|
|
|
# 1. 定义你的类别和对应的ID (从0开始)
|
|
# 根据你的截图,你只有一个类别 "Space weld workpiece"
|
|
# 请确保这里的名称与你XML文件中的<name>标签完全一致!
|
|
CLASS_MAPPING = {
|
|
'Space weld workpiece': 0,
|
|
# 如果有其他类别,在这里添加,例如:
|
|
# 'weld_seam': 1,
|
|
}
|
|
|
|
# 2. 定义输入和输出文件夹
|
|
# 输入文件夹: 存放 .xml 文件的目录
|
|
input_xml_dir = '../label/up_xml'
|
|
|
|
# 输出文件夹: 存放转换后的 .txt 文件的目录
|
|
output_txt_dir = 'data_up'
|
|
|
|
# --- 脚本执行部分 ---
|
|
|
|
# 自动创建输出文件夹(如果不存在)
|
|
if not os.path.exists(output_txt_dir):
|
|
os.makedirs(output_txt_dir)
|
|
print(f"已创建输出文件夹: {output_txt_dir}")
|
|
|
|
# 查找所有 .xml 文件
|
|
xml_files = glob.glob(os.path.join(input_xml_dir, '*.xml'))
|
|
|
|
if not xml_files:
|
|
print(f"错误: 在目录 '{input_xml_dir}' 中没有找到任何 .xml 文件。请检查路径。")
|
|
return
|
|
|
|
print(f"找到 {len(xml_files)} 个 .xml 文件。开始转换...")
|
|
|
|
# 遍历并转换每个文件
|
|
for xml_file in xml_files:
|
|
voc_to_yolo(xml_file, output_txt_dir, CLASS_MAPPING)
|
|
|
|
print("\n转换完成!")
|
|
print(f"所有 YOLO 格式的标签文件已保存在: {output_txt_dir}")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main() |