import json
import random
from collections import defaultdict

def split_dataset_by_source_id(input_file, test_size=150, random_seed=42):
    """
    Randomly split the dataset into train and test sets by source_id.

    Args:
        input_file: path to the input JSON file
        test_size: number of source_ids assigned to the test set
        random_seed: random seed, for reproducible results

    Returns:
        train_data, test_data, train_source_ids, test_source_ids
        (four Nones on failure)
    """
    # Fix the random seed so the split is reproducible
    random.seed(random_seed)
    print(f"Reading file: {input_file}")
    try:
        # Load the JSON file
        with open(input_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        print(f"✓ File read successfully, total records: {len(all_data)}")
        # Group the records by source_id
        source_id_groups = defaultdict(list)
        for item in all_data:
            source_id_groups[item['source_id']].append(item)

        # Collect all unique source_ids
        all_source_ids = list(source_id_groups.keys())
        total_source_ids = len(all_source_ids)
        print(f"✓ Found {total_source_ids} distinct source_ids")
        # Make sure the requested test-set size is feasible
        if test_size >= total_source_ids:
            print(f"✗ Error: test size ({test_size}) is greater than or equal to the total number of source_ids ({total_source_ids})")
            print(f"Consider setting the test size to less than {total_source_ids}")
            return None, None, None, None
        # Randomly pick the source_ids for the test set
        test_source_ids = random.sample(all_source_ids, test_size)
        test_id_set = set(test_source_ids)  # set gives O(1) membership checks
        train_source_ids = [sid for sid in all_source_ids if sid not in test_id_set]
        print(f"✓ Randomly selected {len(test_source_ids)} source_ids for the test set")
        print(f"✓ The remaining {len(train_source_ids)} source_ids form the training set")
        # Build the train and test sets
        train_data = []
        test_data = []
        for source_id in train_source_ids:
            train_data.extend(source_id_groups[source_id])
        for source_id in test_source_ids:
            test_data.extend(source_id_groups[source_id])

        print(f"\n=== Dataset split result ===")
        print(f"Training set:")
        print(f"  - Source IDs: {len(train_source_ids)}")
        print(f"  - Records: {len(train_data)}")
        print(f"Test set:")
        print(f"  - Source IDs: {len(test_source_ids)}")
        print(f"  - Records: {len(test_data)}")
        # Report the label distribution of each split
        def get_label_distribution(data, dataset_name):
            label_counts = defaultdict(int)
            for item in data:
                label_counts[item['label']] += 1
            print(f"\n{dataset_name} label distribution:")
            for label, count in sorted(label_counts.items()):
                # Convert the count to a percentage of the split
                percentage = (count / len(data) * 100) if len(data) > 0 else 0
                print(f"  Label {label}: {count} records ({percentage:.2f}%)")
            return label_counts

        train_labels = get_label_distribution(train_data, "Training set")
        test_labels = get_label_distribution(test_data, "Test set")
        # Show which source_ids were selected
        print(f"\n=== Test set source_id list ===")
        print(f"Test source_ids: {sorted(test_source_ids)}")
        print(f"\n=== Training set source_id list ===")
        print(f"Train source_ids: {sorted(train_source_ids)}")
        return train_data, test_data, train_source_ids, test_source_ids
    except FileNotFoundError:
        print(f"✗ Error: file not found: {input_file}")
        return None, None, None, None
    except json.JSONDecodeError as e:
        print(f"✗ Error: malformed JSON - {str(e)}")
        return None, None, None, None
    except Exception as e:
        print(f"✗ Error: exception while processing the file - {str(e)}")
        return None, None, None, None

def save_dataset(data, filename, description):
    """
    Save a dataset to a JSON file.

    Args:
        data: the data to save
        filename: output file name
        description: human-readable description of the dataset
    """
    try:
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"{description} saved to: {filename}")
        return True
    except Exception as e:
        print(f"✗ Error while saving {description}: {str(e)}")
        return False

def create_summary_report(train_data, test_data, train_source_ids, test_source_ids):
    """
    Write a detailed report of the dataset split.
    """
    summary = {
        "split_info": {
            "total_source_ids": len(train_source_ids) + len(test_source_ids),
            "train_source_ids": len(train_source_ids),
            "test_source_ids": len(test_source_ids),
            "total_records": len(train_data) + len(test_data),
            "train_records": len(train_data),
            "test_records": len(test_data)
        },
        "train_source_id_list": sorted(train_source_ids),
        "test_source_id_list": sorted(test_source_ids),
        "label_distribution": {
            "train": {},
            "test": {}
        }
    }
    # Compute the label distribution of each split
    for dataset_name, data in [("train", train_data), ("test", test_data)]:
        label_counts = defaultdict(int)
        for item in data:
            label_counts[item['label']] += 1
        summary["label_distribution"][dataset_name] = dict(label_counts)
    # Save the report
    with open('dataset_split_summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"✓ Dataset split report saved to: dataset_split_summary.json")

# Main program
if __name__ == "__main__":
    print("=" * 60)
    print("Dataset split")
    print("=" * 60)

    # Input file
    input_file = "segmentation_results_from_7_retried.json"

    # Run the split
    train_data, test_data, train_source_ids, test_source_ids = split_dataset_by_source_id(
        input_file=input_file,
        test_size=150,
        random_seed=42
    )
    if train_data is not None and test_data is not None:
        print(f"\n{'=' * 60}")
        print("Saving dataset files")
        print(f"{'=' * 60}")
        # Save the training set
        train_success = save_dataset(train_data, "train_dataset.json", "Training set")
        # Save the test set
        test_success = save_dataset(test_data, "test_dataset.json", "Test set")
        if train_success and test_success:
            # Write the detailed report
            create_summary_report(train_data, test_data, train_source_ids, test_source_ids)
            print(f"\n{'=' * 60}")
            print("Dataset split complete!")
            print(f"{'=' * 60}")
            print("Generated files:")
            print("1. train_dataset.json - training data")
            print("2. test_dataset.json - test data")
            print("3. dataset_split_summary.json - split report")
            # Verify data integrity: re-read the input file so the check
            # compares the split against the true original record count
            print(f"\n=== Data integrity check ===")
            with open(input_file, 'r', encoding='utf-8') as f:
                original_count = len(json.load(f))
            print(f"Original record count: {original_count}")
            print(f"Train + test: {len(train_data)} + {len(test_data)} = {len(train_data) + len(test_data)}")
            if len(train_data) + len(test_data) == original_count:
                print("✓ Data integrity check passed")
            else:
                print("✗ Data integrity check failed")
        else:
            print("✗ Error while saving files")
    else:
        print("✗ Dataset split failed")