import json
import random
from collections import defaultdict

import pandas as pd  # NOTE(review): unused in this script; kept to avoid breaking external tooling that may rely on it


def split_dataset_by_source_id(input_file, test_size=150, random_seed=42):
    """Randomly split a dataset into train/test sets, grouped by source_id.

    All records sharing one source_id land in the same split, so no source
    leaks between train and test.

    Args:
        input_file: path to the input JSON file — a list of dicts, each with
            at least 'source_id' and 'label' keys.
        test_size: number of distinct source_ids to assign to the test set.
        random_seed: RNG seed so the split is reproducible.

    Returns:
        (train_data, test_data, train_source_ids, test_source_ids) on
        success, or (None, None, None, None) on any error.
    """
    # Seed the RNG so repeated runs produce the identical split.
    random.seed(random_seed)

    print(f"正在读取文件: {input_file}")

    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        print(f"✓ 成功读取文件,总记录数: {len(all_data)}")

        # Group records by their source_id.
        source_id_groups = defaultdict(list)
        for item in all_data:
            source_id_groups[item['source_id']].append(item)

        all_source_ids = list(source_id_groups.keys())
        total_source_ids = len(all_source_ids)
        print(f"✓ 发现 {total_source_ids} 个不同的source_id")

        # Sanity-check the requested test-set size.
        if test_size >= total_source_ids:
            print(f"✗ 错误:测试集大小 ({test_size}) 大于等于总source_id数量 ({total_source_ids})")
            print(f"建议将测试集大小设置为小于 {total_source_ids}")
            # BUG FIX: this branch previously returned only 2 values while
            # callers unpack 4, raising ValueError instead of failing cleanly.
            return None, None, None, None

        # Randomly pick the test source_ids; the remainder form the train set.
        test_source_ids = random.sample(all_source_ids, test_size)
        # BUG FIX (perf): set membership is O(1) vs. the original O(n) list
        # scan inside the comprehension (O(n^2) overall).
        test_id_set = set(test_source_ids)
        train_source_ids = [sid for sid in all_source_ids if sid not in test_id_set]

        print(f"✓ 随机选择了 {len(test_source_ids)} 个source_id作为测试集")
        print(f"✓ 剩余 {len(train_source_ids)} 个source_id作为训练集")

        # Materialize both splits from the per-source groups.
        train_data = []
        test_data = []
        for source_id in train_source_ids:
            train_data.extend(source_id_groups[source_id])
        for source_id in test_source_ids:
            test_data.extend(source_id_groups[source_id])

        print(f"\n=== 数据集划分结果 ===")
        print(f"训练集:")
        print(f"  - Source ID数量: {len(train_source_ids)}")
        print(f"  - 记录数量: {len(train_data)}")
        print(f"测试集:")
        print(f"  - Source ID数量: {len(test_source_ids)}")
        print(f"  - 记录数量: {len(test_data)}")

        def get_label_distribution(data, dataset_name):
            """Print the label frequency of one split and return the counts."""
            label_counts = defaultdict(int)
            for item in data:
                label_counts[item['label']] += 1
            print(f"\n{dataset_name}标签分布:")
            for label, count in sorted(label_counts.items()):
                # BUG FIX: percentage was computed with "* 150" instead of "* 100".
                percentage = (count / len(data) * 100) if len(data) > 0 else 0
                print(f"  标签 {label}: {count} 条 ({percentage:.2f}%)")
            return label_counts

        # Report label balance for each split (return values not needed here).
        get_label_distribution(train_data, "训练集")
        get_label_distribution(test_data, "测试集")

        # List the chosen source_ids for traceability.
        print(f"\n=== 测试集Source ID列表 ===")
        print(f"测试集source_id: {sorted(test_source_ids)}")
        print(f"\n=== 训练集Source ID列表 ===")
        print(f"训练集source_id: {sorted(train_source_ids)}")

        return train_data, test_data, train_source_ids, test_source_ids

    except FileNotFoundError:
        print(f"✗ 错误:找不到文件 {input_file}")
        return None, None, None, None
    except json.JSONDecodeError as e:
        print(f"✗ 错误:JSON文件格式错误 - {str(e)}")
        return None, None, None, None
    except Exception as e:
        print(f"✗ 错误:处理文件时出现异常 - {str(e)}")
        return None, None, None, None


def save_dataset(data, filename, description):
    """Serialize *data* to *filename* as pretty-printed UTF-8 JSON.

    Args:
        data: JSON-serializable object to save.
        filename: output file path.
        description: human-readable dataset name used in log messages.

    Returns:
        True on success, False on failure.
    """
    try:
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        # BUG FIX: the success message printed a literal "(unknown)" instead
        # of the actual output path.
        print(f"✓ {description}已保存到: {filename}")
        return True
    except Exception as e:
        print(f"✗ 保存{description}时出错: {str(e)}")
        return False


def create_summary_report(train_data, test_data, train_source_ids, test_source_ids):
    """Write a detailed split report to dataset_split_summary.json.

    The report records counts, the full source_id assignment of each split,
    and the per-split label distribution.
    """
    summary = {
        "split_info": {
            "total_source_ids": len(train_source_ids) + len(test_source_ids),
            "train_source_ids": len(train_source_ids),
            "test_source_ids": len(test_source_ids),
            "total_records": len(train_data) + len(test_data),
            "train_records": len(train_data),
            "test_records": len(test_data)
        },
        "train_source_id_list": sorted(train_source_ids),
        "test_source_id_list": sorted(test_source_ids),
        "label_distribution": {
            "train": {},
            "test": {}
        }
    }

    # Tally labels separately for each split.
    for dataset_name, data in [("train", train_data), ("test", test_data)]:
        label_counts = defaultdict(int)
        for item in data:
            label_counts[item['label']] += 1
        summary["label_distribution"][dataset_name] = dict(label_counts)

    # Persist the report next to the datasets.
    with open('dataset_split_summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"✓ 数据集划分报告已保存到: dataset_split_summary.json")


# Script entry point.
if __name__ == "__main__":
    print("=" * 60)
    print("数据集划分程序")
    print("=" * 60)

    # Input file name.
    input_file = "segmentation_results_from_7_retried.json"

    # Perform the split.
    train_data, test_data, train_source_ids, test_source_ids = split_dataset_by_source_id(
        input_file=input_file,
        test_size=150,
        random_seed=42
    )

    if train_data is not None and test_data is not None:
        print(f"\n{'=' * 60}")
        print("开始保存数据集文件")
        print(f"{'=' * 60}")

        # Save the training set.
        train_success = save_dataset(train_data, "train_dataset.json", "训练集")
        # Save the test set.
        test_success = save_dataset(test_data, "test_dataset.json", "测试集")

        if train_success and test_success:
            # Produce the detailed report.
            create_summary_report(train_data, test_data, train_source_ids, test_source_ids)

            print(f"\n{'=' * 60}")
            print("数据集划分完成!")
            print(f"{'=' * 60}")
            print("生成的文件:")
            print("1. train_dataset.json - 训练集数据")
            print("2. test_dataset.json - 测试集数据")
            print("3. dataset_split_summary.json - 划分报告")

            # Verify no records were lost in the split.
            print(f"\n=== 数据完整性验证 ===")
            # BUG FIX: the original compared len(train)+len(test) against
            # itself, so the check always passed. Re-read the input file to
            # get the true original record count.
            with open(input_file, 'r', encoding='utf-8') as f:
                original_count = len(json.load(f))
            print(f"原始数据总数: {original_count}")
            print(f"训练集 + 测试集: {len(train_data)} + {len(test_data)} = {len(train_data) + len(test_data)}")

            if len(train_data) + len(test_data) == original_count:
                print("✓ 数据完整性验证通过")
            else:
                print("✗ 数据完整性验证失败")
        else:
            print("✗ 保存文件时出现错误")
    else:
        print("✗ 数据集划分失败")