# Split a JSON dataset into train/test partitions, grouped by source_id,
# and write the resulting files plus a summary report.
import json
import random
from collections import defaultdict

import pandas as pd
|
def split_dataset_by_source_id(input_file, test_size=150, random_seed=42):
    """Randomly split a dataset into train/test partitions by source_id.

    All records sharing a source_id land in the same partition, so a
    source never leaks across the train/test boundary.

    Args:
        input_file: Path to the input JSON file — a list of dicts, each
            containing at least 'source_id' and 'label' keys.
        test_size: Number of distinct source_ids assigned to the test set.
        random_seed: RNG seed so the split is reproducible.

    Returns:
        Tuple (train_data, test_data, train_source_ids, test_source_ids),
        or (None, None, None, None) on any error.
    """
    # Seed the RNG for a reproducible split.
    random.seed(random_seed)

    print(f"正在读取文件: {input_file}")

    try:
        # Load the full dataset.
        with open(input_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)

        print(f"✓ 成功读取文件,总记录数: {len(all_data)}")

        # Group records by their source_id.
        source_id_groups = defaultdict(list)
        for item in all_data:
            source_id_groups[item['source_id']].append(item)

        # All unique source_ids.
        all_source_ids = list(source_id_groups.keys())
        total_source_ids = len(all_source_ids)

        print(f"✓ 发现 {total_source_ids} 个不同的source_id")

        # Sanity-check the requested test-set size.
        if test_size >= total_source_ids:
            print(f"✗ 错误:测试集大小 ({test_size}) 大于等于总source_id数量 ({total_source_ids})")
            print(f"建议将测试集大小设置为小于 {total_source_ids}")
            # BUGFIX: return four values so callers can unpack uniformly
            # (previously returned only two here).
            return None, None, None, None

        # Randomly pick the test source_ids; the remainder are training.
        test_source_ids = random.sample(all_source_ids, test_size)
        # Set gives O(1) membership tests instead of O(n) list scans.
        test_id_set = set(test_source_ids)
        train_source_ids = [sid for sid in all_source_ids if sid not in test_id_set]

        print(f"✓ 随机选择了 {len(test_source_ids)} 个source_id作为测试集")
        print(f"✓ 剩余 {len(train_source_ids)} 个source_id作为训练集")

        # Materialize the two partitions from the grouped records.
        train_data = []
        test_data = []

        for source_id in train_source_ids:
            train_data.extend(source_id_groups[source_id])

        for source_id in test_source_ids:
            test_data.extend(source_id_groups[source_id])

        print(f"\n=== 数据集划分结果 ===")
        print(f"训练集:")
        print(f" - Source ID数量: {len(train_source_ids)}")
        print(f" - 记录数量: {len(train_data)}")

        print(f"测试集:")
        print(f" - Source ID数量: {len(test_source_ids)}")
        print(f" - 记录数量: {len(test_data)}")

        # Print the label distribution of one partition and return it.
        def get_label_distribution(data, dataset_name):
            label_counts = defaultdict(int)
            for item in data:
                label_counts[item['label']] += 1

            print(f"\n{dataset_name}标签分布:")
            for label, count in sorted(label_counts.items()):
                # BUGFIX: percentage is count/total * 100 (was "* 150").
                percentage = (count / len(data) * 100) if len(data) > 0 else 0
                print(f" 标签 {label}: {count} 条 ({percentage:.2f}%)")

            return label_counts

        train_labels = get_label_distribution(train_data, "训练集")
        test_labels = get_label_distribution(test_data, "测试集")

        # Show which source_ids ended up in each partition.
        print(f"\n=== 测试集Source ID列表 ===")
        print(f"测试集source_id: {sorted(test_source_ids)}")

        print(f"\n=== 训练集Source ID列表 ===")
        print(f"训练集source_id: {sorted(train_source_ids)}")

        return train_data, test_data, train_source_ids, test_source_ids

    except FileNotFoundError:
        print(f"✗ 错误:找不到文件 {input_file}")
        return None, None, None, None
    except json.JSONDecodeError as e:
        print(f"✗ 错误:JSON文件格式错误 - {str(e)}")
        return None, None, None, None
    except Exception as e:
        print(f"✗ 错误:处理文件时出现异常 - {str(e)}")
        return None, None, None, None
|
|
|
|
|
def save_dataset(data, filename, description):
    """Serialize a dataset to a JSON file.

    Args:
        data: JSON-serializable data to save.
        filename: Output file path.
        description: Human-readable dataset name used in log messages.

    Returns:
        True on success, False if writing failed.
    """
    try:
        with open(filename, 'w', encoding='utf-8') as f:
            # Keep non-ASCII text readable in the output and pretty-print.
            json.dump(data, f, ensure_ascii=False, indent=2)
        # BUGFIX: report the actual output path (message previously printed
        # a literal placeholder instead of the filename).
        print(f"✓ {description}已保存到: {filename}")
        return True
    except Exception as e:
        print(f"✗ 保存{description}时出错: {str(e)}")
        return False
|
|
|
|
|
def create_summary_report(train_data, test_data, train_source_ids, test_source_ids):
    """Write a detailed report of the dataset split to dataset_split_summary.json.

    The report contains source_id/record counts, the sorted source_id
    lists for each partition, and per-partition label distributions.
    """
    def _tally_labels(records):
        # Count label frequencies without assuming any fixed label set.
        tally = {}
        for rec in records:
            lbl = rec['label']
            tally[lbl] = tally.get(lbl, 0) + 1
        return tally

    summary = {
        "split_info": {
            "total_source_ids": len(train_source_ids) + len(test_source_ids),
            "train_source_ids": len(train_source_ids),
            "test_source_ids": len(test_source_ids),
            "total_records": len(train_data) + len(test_data),
            "train_records": len(train_data),
            "test_records": len(test_data)
        },
        "train_source_id_list": sorted(train_source_ids),
        "test_source_id_list": sorted(test_source_ids),
        "label_distribution": {
            "train": _tally_labels(train_data),
            "test": _tally_labels(test_data)
        }
    }

    # Persist the report next to the split outputs.
    with open('dataset_split_summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"✓ 数据集划分报告已保存到: dataset_split_summary.json")
|
|
|
|
|
# Script entry point: split the dataset, save both partitions, and emit
# a summary report plus an integrity check.
if __name__ == "__main__":
    print("=" * 60)
    print("数据集划分程序")
    print("=" * 60)

    # Input file to split.
    input_file = "segmentation_results_from_7_retried.json"

    # Perform the source_id-grouped split.
    train_data, test_data, train_source_ids, test_source_ids = split_dataset_by_source_id(
        input_file=input_file,
        test_size=150,
        random_seed=42
    )

    if train_data is not None and test_data is not None:
        print(f"\n{'=' * 60}")
        print("开始保存数据集文件")
        print(f"{'=' * 60}")

        # Save each partition; each call reports its own success/failure.
        train_success = save_dataset(train_data, "train_dataset.json", "训练集")
        test_success = save_dataset(test_data, "test_dataset.json", "测试集")

        if train_success and test_success:
            # Emit the detailed split report.
            create_summary_report(train_data, test_data, train_source_ids, test_source_ids)

            print(f"\n{'=' * 60}")
            print("数据集划分完成!")
            print(f"{'=' * 60}")
            print("生成的文件:")
            print("1. train_dataset.json - 训练集数据")
            print("2. test_dataset.json - 测试集数据")
            print("3. dataset_split_summary.json - 划分报告")

            # Integrity check: train + test must account for every record
            # in the source file. BUGFIX: the original compared the
            # train+test sum against itself, which was always true — the
            # true baseline comes from re-reading the input file.
            print(f"\n=== 数据完整性验证 ===")
            with open(input_file, 'r', encoding='utf-8') as f:
                original_count = len(json.load(f))
            print(f"原始数据总数: {original_count}")
            print(f"训练集 + 测试集: {len(train_data)} + {len(test_data)} = {len(train_data) + len(test_data)}")

            if len(train_data) + len(test_data) == original_count:
                print("✓ 数据完整性验证通过")
            else:
                print("✗ 数据完整性验证失败")

        else:
            print("✗ 保存文件时出现错误")
    else:
        print("✗ 数据集划分失败")