116 lines
4.4 KiB
Python
116 lines
4.4 KiB
Python
import os
|
|
import random
|
|
import shutil
|
|
|
|
|
|
def split_dataset(image_dir, label_dir, output_dir, split_ratio=0.8):
|
|
"""
|
|
自动将图片和标签划分为训练集和验证集。
|
|
|
|
Args:
|
|
image_dir (str): 原始图片文件夹路径。
|
|
label_dir (str): 转换后的 YOLO 标签 (.txt) 文件夹路径。
|
|
output_dir (str): 整理好的数据集输出根目录 (例如 'weld_dataset')。
|
|
split_ratio (float): 训练集所占的比例,例如 0.8 代表 80% 训练, 20% 验证。
|
|
"""
|
|
print("开始划分数据集...")
|
|
|
|
# --- 1. 路径设置和文件夹创建 ---
|
|
train_img_path = os.path.join(output_dir, 'images', 'train')
|
|
val_img_path = os.path.join(output_dir, 'images', 'val')
|
|
train_label_path = os.path.join(output_dir, 'labels', 'train')
|
|
val_label_path = os.path.join(output_dir, 'labels', 'val')
|
|
|
|
# 创建所有必要的文件夹
|
|
os.makedirs(train_img_path, exist_ok=True)
|
|
os.makedirs(val_img_path, exist_ok=True)
|
|
os.makedirs(train_label_path, exist_ok=True)
|
|
os.makedirs(val_label_path, exist_ok=True)
|
|
|
|
# --- 2. 文件匹配 ---
|
|
# 获取所有标签文件的基础名(不含扩展名)
|
|
label_files = [os.path.splitext(f)[0] for f in os.listdir(label_dir) if f.endswith('.txt')]
|
|
|
|
# 查找对应的图片文件(支持多种格式)
|
|
image_files_map = {}
|
|
supported_formats = ['.jpg', '.jpeg', '.png', '.bmp']
|
|
for f in os.listdir(image_dir):
|
|
base_name, ext = os.path.splitext(f)
|
|
if ext.lower() in supported_formats:
|
|
image_files_map[base_name] = f
|
|
|
|
# 找出图片和标签都存在的文件对
|
|
valid_files = [base_name for base_name in label_files if base_name in image_files_map]
|
|
|
|
if not valid_files:
|
|
print(f"错误:在图片目录 '{image_dir}' 和标签目录 '{label_dir}' 之间未找到任何匹配的文件对。")
|
|
print("请确保图片和标签的文件名(除扩展名外)完全一致。")
|
|
return
|
|
|
|
print(f"共找到 {len(valid_files)} 个有效的图片-标签对。")
|
|
|
|
# --- 3. 随机划分 ---
|
|
random.shuffle(valid_files)
|
|
split_index = int(len(valid_files) * split_ratio)
|
|
train_files = valid_files[:split_index]
|
|
val_files = valid_files[split_index:]
|
|
|
|
# --- 4. 复制文件到目标位置 ---
|
|
def copy_files(file_list, dest_img_path, dest_label_path):
|
|
for base_name in file_list:
|
|
# 复制图片
|
|
img_name = image_files_map[base_name]
|
|
shutil.copy(os.path.join(image_dir, img_name), dest_img_path)
|
|
# 复制标签
|
|
label_name = base_name + '.txt'
|
|
shutil.copy(os.path.join(label_dir, label_name), dest_label_path)
|
|
|
|
print(f"正在复制 {len(train_files)} 个文件到训练集...")
|
|
copy_files(train_files, train_img_path, train_label_path)
|
|
|
|
print(f"正在复制 {len(val_files)} 个文件到验证集...")
|
|
copy_files(val_files, val_img_path, val_label_path)
|
|
|
|
print("\n数据集划分完成!")
|
|
print(f"训练集: {len(train_files)} 张图片 | 验证集: {len(val_files)} 张图片")
|
|
print(f"数据已整理至 '{output_dir}' 文件夹。")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
# --- 用户需要配置的参数 ---
|
|
|
|
# # 1. 原始图片文件夹路径
|
|
# # !! 重要 !!: 请将这里的路径修改为您实际存放图片的文件夹
|
|
# # 可能是 'faster-rcnn/JPEGImages' 或其他名称
|
|
# ORIGINAL_IMAGE_DIR = '../OpenCV/data_bottom/test3'
|
|
#
|
|
# # 2. 转换后的 YOLO 标签文件夹路径
|
|
# YOLO_LABEL_DIR = 'data_bottom'
|
|
#
|
|
# # 3. 最终输出的数据集文件夹
|
|
# OUTPUT_DATASET_DIR = 'train_data'
|
|
#
|
|
# # 4. 训练集比例 (0.8 表示 80% 训练, 20% 验证)
|
|
# SPLIT_RATIO = 0.9
|
|
#
|
|
# # --- 运行主函数 ---
|
|
# split_dataset(ORIGINAL_IMAGE_DIR, YOLO_LABEL_DIR, OUTPUT_DATASET_DIR, SPLIT_RATIO)
|
|
#
|
|
# --- 用户需要配置的参数 ---
|
|
|
|
# 1. 原始图片文件夹路径
|
|
# !! 重要 !!: 请将这里的路径修改为您实际存放图片的文件夹
|
|
# 可能是 'faster-rcnn/JPEGImages' 或其他名称
|
|
ORIGINAL_IMAGE_DIR = '../label/up'
|
|
|
|
# 2. 转换后的 YOLO 标签文件夹路径
|
|
YOLO_LABEL_DIR = 'data_up'
|
|
|
|
# 3. 最终输出的数据集文件夹
|
|
OUTPUT_DATASET_DIR = 'train_data_up'
|
|
|
|
# 4. 训练集比例 (0.8 表示 80% 训练, 20% 验证)
|
|
SPLIT_RATIO = 0.9
|
|
|
|
# --- 运行主函数 ---
|
|
split_dataset(ORIGINAL_IMAGE_DIR, YOLO_LABEL_DIR, OUTPUT_DATASET_DIR, SPLIT_RATIO) |