Files
solidstate-tools/contrast learning/split.py
2025-11-19 12:23:17 +08:00

112 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
import random
from pathlib import Path
def split_dataset(source_dir: str, output_dir: str, test_ratio: float = 0.2):
"""
将源文件夹中的文件按比例划分到输出文件夹下的 train 和 test 子目录中。
Args:
source_dir (str): 包含所有数据文件的源文件夹路径。
output_dir (str): 用于存放'train''test'文件夹的目标文件夹路径。
test_ratio (float, optional): 测试集所占的比例。默认为 0.2。
"""
print("--- 开始执行数据集划分 ---")
# 1. 路径处理和验证
source_path = Path(source_dir)
output_path = Path(output_dir)
if not source_path.is_dir():
print(f"错误:源文件夹 '{source_dir}' 不存在或不是一个目录。")
return
# 2. 创建输出文件夹 (train 和 test)
train_dir = output_path / 'train'
test_dir = output_path / 'test'
try:
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
print(f"输出目录已准备好: \n 训练集 -> {train_dir}\n 测试集 -> {test_dir}")
except OSError as e:
print(f"错误:创建输出目录时发生错误: {e}")
return
# 3. 获取所有文件并随机打乱
all_files = [f for f in source_path.iterdir() if f.is_file()]
if not all_files:
print(f"警告:源文件夹 '{source_dir}' 中没有文件可供划分。")
return
random.shuffle(all_files)
total_files = len(all_files)
print(f"在源文件夹中找到 {total_files} 个文件。")
# 4. 计算分割数量
num_test = int(total_files * test_ratio)
num_train = total_files - num_test
print(f"划分计划 -> 训练集: {num_train} 个文件 | 测试集: {num_test} 个文件")
# 5. 分割文件列表
test_files = all_files[:num_test]
train_files = all_files[num_test:]
# 6. 定义一个复制/移动文件的辅助函数
def copy_files(files_to_copy, destination_dir):
copied_count = 0
for file_path in files_to_copy:
try:
# 注意:这里使用的是复制(copy),更安全。
# 如果你确认要移动(move)并且清空源文件夹,请将 shutil.copy 改为 shutil.move
shutil.copy(file_path, destination_dir)
copied_count += 1
except Exception as e:
print(f"处理文件 '{file_path.name}' 时出错: {e}")
return copied_count
# 7. 复制文件到对应的文件夹
print(f"\n正在复制文件到 'train' 文件夹...")
copied_train = copy_files(train_files, train_dir)
print(f"成功复制 {copied_train} 个文件到训练集。")
print(f"\n正在复制文件到 'test' 文件夹...")
copied_test = copy_files(test_files, test_dir)
print(f"成功复制 {copied_test} 个文件到测试集。")
print("\n--- 数据集划分完成! ---")
# --- 如何使用这个函数 ---
if __name__ == '__main__':
# --- 请在这里配置你的文件夹路径 ---
# 你的原始数据集所在的文件夹
# 例如: 'C:/Users/YourUser/Desktop/my_dataset' (Windows)
# 或: '/home/user/project/raw/all_images' (Linux/macOS)
SOURCE_DATA_DIR = 'D:/download/2025-10/data_OS/input/S'
# 你希望将'train'和'test'文件夹创建在哪里
# 例如: 'C:/Users/YourUser/Desktop/split_output' (Windows)
# 或: '/home/user/project/raw/processed' (Linux/macOS)
# 如果使用 '.', 表示在当前脚本所在的目录下创建
OUTPUT_DIR = 'D:/download/2025-10/data_OS/train/S'
# --- 配置完成,下面是调用函数 ---
# 检查示例路径是否存在,如果不存在则创建并填充一些假文件用于演示
if not os.path.exists(SOURCE_DATA_DIR):
print(f"演示目录 '{SOURCE_DATA_DIR}' 不存在正在创建并生成100个示例文件...")
os.makedirs(SOURCE_DATA_DIR)
for i in range(100):
with open(os.path.join(SOURCE_DATA_DIR, f'file_{i + 1:03d}.txt'), 'w') as f:
f.write(f'This is file {i + 1}.')
print("示例文件创建完毕。")
# 调用函数执行划分
split_dataset(SOURCE_DATA_DIR, OUTPUT_DIR, test_ratio=0.2)