import json
from collections import defaultdict
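
# Assumed shape of one record in the input JSON (values are illustrative,
# not taken from the real dataset). The functions below only rely on the
# "sentence1", "sentence2", "label" and "source_id" keys:
#
# {
#     "sentence1": "Last sentence of the previous segment ...",
#     "sentence2": "First sentence of the next segment ...",
#     "label": 0,
#     "source_id": 7
# }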
def create_cross_document_boundaries(input_file, output_file):
    """
    Create cross-document boundary sentence pairs.

    The last sentence of source_id n is paired with the first sentence of
    source_id n+1 and labelled 1 (segment break). An illustrative example
    follows this function.
    """
    # Load the original data
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Group records by source_id
    source_groups = defaultdict(list)
    for item in data:
        source_id = item['source_id']
        source_groups[source_id].append(item)

    # Sort the source_ids
    sorted_source_ids = sorted(source_groups.keys())

    # Collect the newly created cross-document boundary samples
    cross_boundary_data = []
    print(f"Processing {len(sorted_source_ids)} source_ids...")

    # Walk over adjacent source_ids
    for i in range(len(sorted_source_ids) - 1):
        current_source_id = sorted_source_ids[i]
        next_source_id = sorted_source_ids[i + 1]
        current_group = source_groups[current_source_id]
        next_group = source_groups[next_source_id]
        if len(current_group) == 0 or len(next_group) == 0:
            continue

        # sentence2 of the last pair in the current source_id
        last_item = current_group[-1]
        last_sentence = last_item['sentence2']

        # sentence1 of the first pair in the next source_id
        first_item = next_group[0]
        first_sentence = first_item['sentence1']

        # Build the cross-document boundary pair
        cross_boundary_item = {
            "sentence1": last_sentence,
            "sentence2": first_sentence,
            "label": 1,  # crossing documents always requires a segment break
            "reason": f"Cross-document boundary: the end of source_id {current_source_id} and the beginning of source_id {next_source_id} belong to different documents and must be segmented.",
            "source_id": f"{current_source_id}-{next_source_id}",
            "boundary_type": "cross_document"
        }
        cross_boundary_data.append(cross_boundary_item)

        print(f"Created cross-document boundary: {current_source_id} -> {next_source_id}")
        print(f"  Sentence 1: {last_sentence[:50]}...")
        print(f"  Sentence 2: {first_sentence[:50]}...")

    print(f"\nCreated {len(cross_boundary_data)} cross-document boundary samples in total")

    # Save the cross-document boundary data
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(cross_boundary_data, f, ensure_ascii=False, indent=2)
    print(f"Cross-document boundary data saved to: {output_file}")

    return cross_boundary_data
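
# Illustrative sketch of the pairing rule above (both sentences are
# hypothetical, not taken from the dataset):
#
#   source_id 3 ends with   sentence2 = "... and the experiment was stopped."
#   source_id 4 starts with sentence1 = "The second study used a larger corpus."
#
# which yields one new training sample:
#
#   {"sentence1": "... and the experiment was stopped.",
#    "sentence2": "The second study used a larger corpus.",
#    "label": 1, "source_id": "3-4", "boundary_type": "cross_document"}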
def merge_with_original_data(original_file, cross_boundary_file, merged_output_file):
    """
    Merge the cross-document boundary data with the original data.
    """
    # Load the original data
    with open(original_file, 'r', encoding='utf-8') as f:
        original_data = json.load(f)

    # Load the cross-document boundary data
    with open(cross_boundary_file, 'r', encoding='utf-8') as f:
        cross_boundary_data = json.load(f)

    # Merge the two datasets
    merged_data = original_data + cross_boundary_data
    print(f"Original data: {len(original_data)}")
    print(f"Cross-document boundary data: {len(cross_boundary_data)}")
    print(f"Merged data: {len(merged_data)}")

    # Label distribution statistics (a more compact alternative follows this function)
    label_counts = {}
    for item in merged_data:
        label = item['label']
        label_counts[label] = label_counts.get(label, 0) + 1
    print("\nLabel distribution after merging:")
    for label, count in label_counts.items():
        label_name = "no break" if label == 0 else "break"
        percentage = count / len(merged_data) * 100
        print(f"  Label {label} ({label_name}): {count} samples ({percentage:.1f}%)")

    # Save the merged data
    with open(merged_output_file, 'w', encoding='utf-8') as f:
        json.dump(merged_data, f, ensure_ascii=False, indent=2)
    print(f"\nMerged data saved to: {merged_output_file}")

    return merged_data
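
# A more compact way to compute the same label distribution, noted here only
# as a sketch of an equivalent design choice (not used by the pipeline):
#
#   from collections import Counter
#   label_counts = Counter(item['label'] for item in merged_data)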
def analyze_source_structure(input_file):
    """
    Analyze the source_id structure to make the data easier to understand.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Group records by source_id
    source_groups = defaultdict(list)
    for item in data:
        source_id = item['source_id']
        source_groups[source_id].append(item)

    print("Data structure analysis:")
    print(f"{len(data)} sentence pairs in total")
    print(f"{len(source_groups)} distinct source_ids")
    print(f"source_id range: {min(source_groups.keys())} - {max(source_groups.keys())}")

    # Number of sentence pairs per source_id
    print("\nSentence pairs per source_id:")
    sorted_source_ids = sorted(source_groups.keys())
    for source_id in sorted_source_ids:
        count = len(source_groups[source_id])
        print(f"  source_id {source_id}: {count} sentence pairs")

    # Show examples for the first few source_ids
    print("\nExamples for the first 3 source_ids:")
    for source_id in sorted_source_ids[:3]:
        items = source_groups[source_id]
        print(f"\nsource_id {source_id}:")
        print(f"  First pair: {items[0]['sentence1'][:30]}... -> {items[0]['sentence2'][:30]}...")
        if len(items) > 1:
            print(f"  Last pair: {items[-1]['sentence1'][:30]}... -> {items[-1]['sentence2'][:30]}...")
def main():
    """
    Main entry point: build the cross-document boundary data.
    """
    # File paths
    input_file = "segmentation_results_from_7_retried.json"
    cross_boundary_output = "cross_document_boundaries.json"
    merged_output = "enhanced_training_data_with_boundaries.json"

    print("=" * 60)
    print("Cross-document boundary data generation")
    print("=" * 60)

    # 1. Analyze the structure of the original data
    print("1. Analyzing the original data structure...")
    analyze_source_structure(input_file)
    print("\n" + "=" * 60)

    # 2. Create the cross-document boundary data
    print("2. Creating cross-document boundary data...")
    cross_boundary_data = create_cross_document_boundaries(input_file, cross_boundary_output)
    print("\n" + "=" * 60)

    # 3. Merge the datasets
    print("3. Merging the original data with the cross-document boundary data...")
    merged_data = merge_with_original_data(input_file, cross_boundary_output, merged_output)
    print("\n" + "=" * 60)

    print("Done!")
    print("=" * 60)
    print("Generated files:")
    print(f"  - Cross-document boundary data: {cross_boundary_output}")
    print(f"  - Enhanced training data: {merged_output}")
    print(f"\n{merged_output} can now be used for model training")


if __name__ == "__main__":
    main()
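
# Typical invocation (the script filename is an assumption; adjust it to the
# actual file name). The three file paths are hard-coded in main(), so edit
# them there if the input JSON lives elsewhere:
#
#   python create_cross_document_boundaries.py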