Files
innovate_project/yolov8/split_dataset.py
2025-11-02 21:36:35 +08:00

116 lines
4.4 KiB
Python

import os
import random
import shutil
def split_dataset(image_dir, label_dir, output_dir, split_ratio=0.8):
"""
自动将图片和标签划分为训练集和验证集。
Args:
image_dir (str): 原始图片文件夹路径。
label_dir (str): 转换后的 YOLO 标签 (.txt) 文件夹路径。
output_dir (str): 整理好的数据集输出根目录 (例如 'weld_dataset')。
split_ratio (float): 训练集所占的比例,例如 0.8 代表 80% 训练, 20% 验证。
"""
print("开始划分数据集...")
# --- 1. 路径设置和文件夹创建 ---
train_img_path = os.path.join(output_dir, 'images', 'train')
val_img_path = os.path.join(output_dir, 'images', 'val')
train_label_path = os.path.join(output_dir, 'labels', 'train')
val_label_path = os.path.join(output_dir, 'labels', 'val')
# 创建所有必要的文件夹
os.makedirs(train_img_path, exist_ok=True)
os.makedirs(val_img_path, exist_ok=True)
os.makedirs(train_label_path, exist_ok=True)
os.makedirs(val_label_path, exist_ok=True)
# --- 2. 文件匹配 ---
# 获取所有标签文件的基础名(不含扩展名)
label_files = [os.path.splitext(f)[0] for f in os.listdir(label_dir) if f.endswith('.txt')]
# 查找对应的图片文件(支持多种格式)
image_files_map = {}
supported_formats = ['.jpg', '.jpeg', '.png', '.bmp']
for f in os.listdir(image_dir):
base_name, ext = os.path.splitext(f)
if ext.lower() in supported_formats:
image_files_map[base_name] = f
# 找出图片和标签都存在的文件对
valid_files = [base_name for base_name in label_files if base_name in image_files_map]
if not valid_files:
print(f"错误:在图片目录 '{image_dir}' 和标签目录 '{label_dir}' 之间未找到任何匹配的文件对。")
print("请确保图片和标签的文件名(除扩展名外)完全一致。")
return
print(f"共找到 {len(valid_files)} 个有效的图片-标签对。")
# --- 3. 随机划分 ---
random.shuffle(valid_files)
split_index = int(len(valid_files) * split_ratio)
train_files = valid_files[:split_index]
val_files = valid_files[split_index:]
# --- 4. 复制文件到目标位置 ---
def copy_files(file_list, dest_img_path, dest_label_path):
for base_name in file_list:
# 复制图片
img_name = image_files_map[base_name]
shutil.copy(os.path.join(image_dir, img_name), dest_img_path)
# 复制标签
label_name = base_name + '.txt'
shutil.copy(os.path.join(label_dir, label_name), dest_label_path)
print(f"正在复制 {len(train_files)} 个文件到训练集...")
copy_files(train_files, train_img_path, train_label_path)
print(f"正在复制 {len(val_files)} 个文件到验证集...")
copy_files(val_files, val_img_path, val_label_path)
print("\n数据集划分完成!")
print(f"训练集: {len(train_files)} 张图片 | 验证集: {len(val_files)} 张图片")
print(f"数据已整理至 '{output_dir}' 文件夹。")
if __name__ == '__main__':
# --- 用户需要配置的参数 ---
# # 1. 原始图片文件夹路径
# # !! 重要 !!: 请将这里的路径修改为您实际存放图片的文件夹
# # 可能是 'faster-rcnn/JPEGImages' 或其他名称
# ORIGINAL_IMAGE_DIR = '../OpenCV/data_bottom/test3'
#
# # 2. 转换后的 YOLO 标签文件夹路径
# YOLO_LABEL_DIR = 'data_bottom'
#
# # 3. 最终输出的数据集文件夹
# OUTPUT_DATASET_DIR = 'train_data'
#
# # 4. 训练集比例 (0.8 表示 80% 训练, 20% 验证)
# SPLIT_RATIO = 0.9
#
# # --- 运行主函数 ---
# split_dataset(ORIGINAL_IMAGE_DIR, YOLO_LABEL_DIR, OUTPUT_DATASET_DIR, SPLIT_RATIO)
#
# --- 用户需要配置的参数 ---
# 1. 原始图片文件夹路径
# !! 重要 !!: 请将这里的路径修改为您实际存放图片的文件夹
# 可能是 'faster-rcnn/JPEGImages' 或其他名称
ORIGINAL_IMAGE_DIR = '../label/up'
# 2. 转换后的 YOLO 标签文件夹路径
YOLO_LABEL_DIR = 'data_up'
# 3. 最终输出的数据集文件夹
OUTPUT_DATASET_DIR = 'train_data_up'
# 4. 训练集比例 (0.8 表示 80% 训练, 20% 验证)
SPLIT_RATIO = 0.9
# --- 运行主函数 ---
split_dataset(ORIGINAL_IMAGE_DIR, YOLO_LABEL_DIR, OUTPUT_DATASET_DIR, SPLIT_RATIO)