import json
from collections import defaultdict


def create_cross_document_boundaries(input_file, output_file):
    """
    Create sentence-pair data for cross-document boundaries.

    Pairs the last sentence of each source_id with the first sentence of the
    next source_id (in sorted order) and labels the pair 1 (segment break).
    """
    # Load the original data
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Group the records by source_id
    source_groups = defaultdict(list)
    for item in data:
        source_id = item['source_id']
        source_groups[source_id].append(item)

    # Sort the source_ids
    sorted_source_ids = sorted(source_groups.keys())

    # Collect the newly created cross-document boundary records
    cross_boundary_data = []

    print(f"Processing {len(sorted_source_ids)} source_ids...")

    # Walk over adjacent source_ids
    for i in range(len(sorted_source_ids) - 1):
        current_source_id = sorted_source_ids[i]
        next_source_id = sorted_source_ids[i + 1]

        current_group = source_groups[current_source_id]
        next_group = source_groups[next_source_id]

        if len(current_group) == 0 or len(next_group) == 0:
            continue

        # sentence2 of the last pair in the current source_id
        last_item = current_group[-1]
        last_sentence = last_item['sentence2']

        # sentence1 of the first pair in the next source_id
        first_item = next_group[0]
        first_sentence = first_item['sentence1']

        # Build the cross-document boundary sentence pair
        cross_boundary_item = {
            "sentence1": last_sentence,
            "sentence2": first_sentence,
            "label": 1,  # crossing a document boundary always means a segment break
            "reason": (
                f"Cross-document boundary: the end of source_id {current_source_id} "
                f"and the start of source_id {next_source_id} belong to different "
                f"documents, so a segment break is required."
            ),
            "source_id": f"{current_source_id}-{next_source_id}",
            "boundary_type": "cross_document"
        }

        cross_boundary_data.append(cross_boundary_item)
        print(f"Created cross-document boundary: {current_source_id} -> {next_source_id}")
        print(f"  Sentence 1: {last_sentence[:50]}...")
        print(f"  Sentence 2: {first_sentence[:50]}...")

    print(f"\nCreated {len(cross_boundary_data)} cross-document boundary samples in total")

    # Save the cross-document boundary data
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(cross_boundary_data, f, ensure_ascii=False, indent=2)

    print(f"Cross-document boundary data saved to: {output_file}")
    return cross_boundary_data


def merge_with_original_data(original_file, cross_boundary_file, merged_output_file):
    """
    Merge the cross-document boundary data with the original data.
    """
    # Load the original data
    with open(original_file, 'r', encoding='utf-8') as f:
        original_data = json.load(f)

    # Load the cross-document boundary data
    with open(cross_boundary_file, 'r', encoding='utf-8') as f:
        cross_boundary_data = json.load(f)

    # Merge the two datasets
    merged_data = original_data + cross_boundary_data

    print(f"Original data: {len(original_data)} records")
    print(f"Cross-document boundary data: {len(cross_boundary_data)} records")
    print(f"Merged data: {len(merged_data)} records")

    # Count the label distribution
    label_counts = {}
    for item in merged_data:
        label = item['label']
        label_counts[label] = label_counts.get(label, 0) + 1

    print("\nLabel distribution after merging:")
    for label, count in label_counts.items():
        label_name = "no break" if label == 0 else "break"
        percentage = count / len(merged_data) * 100
        print(f"  Label {label} ({label_name}): {count} records ({percentage:.1f}%)")

    # Save the merged data
    with open(merged_output_file, 'w', encoding='utf-8') as f:
        json.dump(merged_data, f, ensure_ascii=False, indent=2)

    print(f"\nMerged data saved to: {merged_output_file}")
    return merged_data


def analyze_source_structure(input_file):
    """
    Analyze the source_id structure to make the data easier to understand.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Group the records by source_id
    source_groups = defaultdict(list)
    for item in data:
        source_id = item['source_id']
        source_groups[source_id].append(item)

    print("Data structure analysis:")
    print(f"{len(data)} sentence pairs in total")
    print(f"Covering {len(source_groups)} source_ids")
    print(f"source_id range: {min(source_groups.keys())} - {max(source_groups.keys())}")

    # Show the number of sentence pairs per source_id
    print("\nSentence pairs per source_id:")
    sorted_source_ids = sorted(source_groups.keys())
    for source_id in sorted_source_ids:
        count = len(source_groups[source_id])
        print(f"  source_id {source_id}: {count} sentence pairs")

    # Show examples from the first few source_ids
    print("\nExamples from the first 3 source_ids:")
    for source_id in sorted_source_ids[:3]:
        items = source_groups[source_id]
        print(f"\nsource_id {source_id}:")
        print(f"  First pair: {items[0]['sentence1'][:30]}... -> {items[0]['sentence2'][:30]}...")
        if len(items) > 1:
            print(f"  Last pair: {items[-1]['sentence1'][:30]}... -> {items[-1]['sentence2'][:30]}...")
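
# Illustrative record shapes, inferred from the field accesses above; kept as
# comments so the script's behavior is unchanged. The sentence text and ids are
# made-up placeholders, not taken from the actual dataset.
#
#   Input record (segmentation_results_from_7_retried.json):
#     {"sentence1": "...", "sentence2": "...", "label": 0, "source_id": 7, "reason": "..."}
#
#   Generated cross-document boundary record:
#     {"sentence1": "<last sentence2 of source_id 7>",
#      "sentence2": "<first sentence1 of source_id 8>",
#      "label": 1,
#      "reason": "Cross-document boundary: ...",
#      "source_id": "7-8",
#      "boundary_type": "cross_document"}
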
def main():
    """
    Main entry point: build the cross-document boundary training data.
    """
    # File paths
    input_file = "segmentation_results_from_7_retried.json"
    cross_boundary_output = "cross_document_boundaries.json"
    merged_output = "enhanced_training_data_with_boundaries.json"

    print("=" * 60)
    print("Cross-document boundary data generation")
    print("=" * 60)

    # 1. Analyze the structure of the original data
    print("1. Analyzing the original data structure...")
    analyze_source_structure(input_file)

    print("\n" + "=" * 60)

    # 2. Create the cross-document boundary data
    print("2. Creating cross-document boundary data...")
    cross_boundary_data = create_cross_document_boundaries(input_file, cross_boundary_output)

    print("\n" + "=" * 60)

    # 3. Merge the datasets
    print("3. Merging the original data with the cross-document boundary data...")
    merged_data = merge_with_original_data(input_file, cross_boundary_output, merged_output)

    print("\n" + "=" * 60)
    print("Done!")
    print("=" * 60)
    print("Generated files:")
    print(f"  - Cross-document boundary data: {cross_boundary_output}")
    print(f"  - Enhanced training data: {merged_output}")
    print(f"\n{merged_output} can now be used for model training")


if __name__ == "__main__":
    main()
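
# Sketch (an assumption, not part of this script's pipeline): one way the merged
# file could be loaded and split for training a sentence-pair classifier. The
# field names match the records built above; the 90/10 split ratio is an
# arbitrary example. Kept as comments so running this script is unaffected.
#
#   import json, random
#   with open("enhanced_training_data_with_boundaries.json", encoding="utf-8") as f:
#       pairs = json.load(f)
#   random.shuffle(pairs)
#   split = int(len(pairs) * 0.9)
#   train_set, eval_set = pairs[:split], pairs[split:]
#   # each example: input (ex["sentence1"], ex["sentence2"]) -> target ex["label"]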