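"""
Build cross-document boundary training samples for sentence segmentation.

The script pairs the last sentence of each source document (source_id n) with
the first sentence of the next one (source_id n+1), labels the pair as a
segment break, and merges these samples with the original training data.
File paths are hardcoded in main().
"""
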
import json
from collections import defaultdict


def create_cross_document_boundaries(input_file, output_file):
    """
    Create sentence-pair samples that cross document boundaries.

    The last sentence of source_id n is paired with the first sentence of
    source_id n+1, and the pair is labelled 1 (segment break).
    """
    # Read the original data
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Group the data by source_id
    source_groups = defaultdict(list)
    for item in data:
        source_id = item['source_id']
        source_groups[source_id].append(item)

    # Sort the source_ids
    sorted_source_ids = sorted(source_groups.keys())

    # Newly created cross-document boundary samples
    cross_boundary_data = []

    print(f"Processing {len(sorted_source_ids)} source_ids...")

    # Walk over adjacent source_ids
    for i in range(len(sorted_source_ids) - 1):
        current_source_id = sorted_source_ids[i]
        next_source_id = sorted_source_ids[i + 1]

        current_group = source_groups[current_source_id]
        next_group = source_groups[next_source_id]

        if not current_group or not next_group:
            continue

        # sentence2 of the last sentence pair of the current source_id
        last_item = current_group[-1]
        last_sentence = last_item['sentence2']

        # sentence1 of the first sentence pair of the next source_id
        first_item = next_group[0]
        first_sentence = first_item['sentence1']

        # Build the cross-document boundary pair
        cross_boundary_item = {
            "sentence1": last_sentence,
            "sentence2": first_sentence,
            "label": 1,  # crossing a document boundary always means a segment break
            "reason": f"Cross-document boundary: the end of source_id {current_source_id} meets the start of source_id {next_source_id}; they belong to different documents and must be segmented.",
            "source_id": f"{current_source_id}-{next_source_id}",
            "boundary_type": "cross_document"
        }

        cross_boundary_data.append(cross_boundary_item)

        print(f"Created cross-document boundary: {current_source_id} -> {next_source_id}")
        print(f"  sentence 1: {last_sentence[:50]}...")
        print(f"  sentence 2: {first_sentence[:50]}...")

    print(f"\nCreated {len(cross_boundary_data)} cross-document boundary samples in total")

    # Save the cross-document boundary data
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(cross_boundary_data, f, ensure_ascii=False, indent=2)

    print(f"Cross-document boundary data saved to: {output_file}")

    return cross_boundary_data
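
# For reference, each input record is assumed to contain at least the fields
# accessed in this script, e.g. (values are illustrative, not taken from the data):
#   {"sentence1": "...", "sentence2": "...", "label": 0, "source_id": 7}
# The generated boundary records reuse these fields with label fixed to 1,
# source_id set to "n-(n+1)", plus "reason" and "boundary_type" annotations.
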
def merge_with_original_data(original_file, cross_boundary_file, merged_output_file):
    """
    Merge the cross-document boundary data with the original data.
    """
    # Read the original data
    with open(original_file, 'r', encoding='utf-8') as f:
        original_data = json.load(f)

    # Read the cross-document boundary data
    with open(cross_boundary_file, 'r', encoding='utf-8') as f:
        cross_boundary_data = json.load(f)

    # Concatenate the two datasets
    merged_data = original_data + cross_boundary_data

    print(f"Original data: {len(original_data)} samples")
    print(f"Cross-document boundary data: {len(cross_boundary_data)} samples")
    print(f"Merged data: {len(merged_data)} samples")

    # Label distribution of the merged data
    label_counts = {}
    for item in merged_data:
        label = item['label']
        label_counts[label] = label_counts.get(label, 0) + 1

    print("\nLabel distribution after merging:")
    for label, count in label_counts.items():
        label_name = "no break" if label == 0 else "break"
        percentage = count / len(merged_data) * 100
        print(f"  label {label} ({label_name}): {count} samples ({percentage:.1f}%)")

    # Save the merged data
    with open(merged_output_file, 'w', encoding='utf-8') as f:
        json.dump(merged_data, f, ensure_ascii=False, indent=2)

    print(f"\nMerged data saved to: {merged_output_file}")

    return merged_data
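
# Usage sketch (the file names here are placeholders; main() below wires up the
# actual paths used in this project):
#   merge_with_original_data("original.json", "cross_document_boundaries.json",
#                            "merged.json")
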
def analyze_source_structure(input_file):
    """
    Analyze how the data is structured per source_id.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Group by source_id
    source_groups = defaultdict(list)
    for item in data:
        source_id = item['source_id']
        source_groups[source_id].append(item)

    print("Data structure analysis:")
    print(f"{len(data)} sentence pairs in total")
    print(f"{len(source_groups)} distinct source_ids")
    print(f"source_id range: {min(source_groups.keys())} - {max(source_groups.keys())}")

    # Number of sentence pairs per source_id
    print("\nSentence pairs per source_id:")
    sorted_source_ids = sorted(source_groups.keys())
    for source_id in sorted_source_ids:
        count = len(source_groups[source_id])
        print(f"  source_id {source_id}: {count} sentence pairs")

    # Show the first and last pair of the first few source_ids
    print("\nExamples from the first 3 source_ids:")
    for source_id in sorted_source_ids[:3]:
        items = source_groups[source_id]
        print(f"\nsource_id {source_id}:")
        print(f"  first pair: {items[0]['sentence1'][:30]}... -> {items[0]['sentence2'][:30]}...")
        if len(items) > 1:
            print(f"  last pair: {items[-1]['sentence1'][:30]}... -> {items[-1]['sentence2'][:30]}...")
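
# Note: sorted(), min() and max() over source_groups.keys() assume that all
# source_id values are mutually comparable (e.g. all ints or all strings);
# mixed key types would raise a TypeError under Python 3.
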
def main():
    """
    Entry point: analyze, generate and merge the cross-document boundary data.
    """
    # File paths
    input_file = "segmentation_results_from_7_retried.json"
    cross_boundary_output = "cross_document_boundaries.json"
    merged_output = "enhanced_training_data_with_boundaries.json"

    print("=" * 60)
    print("Cross-document boundary data generation")
    print("=" * 60)

    # 1. Analyze the structure of the original data
    print("1. Analyzing the original data structure...")
    analyze_source_structure(input_file)

    print("\n" + "=" * 60)

    # 2. Create the cross-document boundary data
    print("2. Creating cross-document boundary data...")
    cross_boundary_data = create_cross_document_boundaries(input_file, cross_boundary_output)

    print("\n" + "=" * 60)

    # 3. Merge the datasets
    print("3. Merging the original data with the cross-document boundary data...")
    merged_data = merge_with_original_data(input_file, cross_boundary_output, merged_output)

    print("\n" + "=" * 60)
    print("Done!")
    print("=" * 60)
    print("Generated files:")
    print(f"  - cross-document boundary data: {cross_boundary_output}")
    print(f"  - enhanced training data: {merged_output}")
    print(f"\n{merged_output} can now be used for model training")


if __name__ == "__main__":
    main()