#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Merge sentences into paragraphs by source_id and label, and write the result to a txt file.

Consecutive sentences with label=0 belong to the same paragraph and are merged;
label=1 marks a paragraph boundary and starts a new paragraph.
"""

import json
import os
from collections import defaultdict
from typing import List, Dict, Any

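# Assumed input record format, inferred from the fields this script reads
# (the exact schema of the input JSON is an assumption):
#   {"source_id": ..., "sentence1": "...", "sentence2": "...", "label": 0 or 1}
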
def load_test_data(file_path: str) -> List[Dict[str, Any]]:
    """Load the test data."""
|
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Successfully loaded {len(data)} records")
        return data
    except FileNotFoundError:
        print(f"Error: file not found: {file_path}")
        return []
    except json.JSONDecodeError as e:
        print(f"Error: invalid JSON - {e}")
        return []
    except Exception as e:
        print(f"Error: failed to load the file - {e}")
        return []


def group_by_source_id(data: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group the data by source_id."""
|
    grouped_data = defaultdict(list)

    for item in data:
        source_id = str(item.get('source_id', 'unknown'))
        grouped_data[source_id].append(item)

    # Items inside each group keep their original order. If the records carried an
    # explicit index field they could be sorted here; the input is assumed to
    # already be in the correct order.

    print(f"Grouping by source_id finished: {len(grouped_data)} groups")
    for source_id, items in grouped_data.items():
        print(f"  Source ID {source_id}: {len(items)} records")

    return dict(grouped_data)


def merge_paragraphs_by_labels(sentence_pairs: List[Dict[str, Any]]) -> List[str]:
    """
    Merge sentences into paragraphs according to the labels.
    label=0: same paragraph, the sentences are merged
    label=1: different paragraphs, treated as a boundary
    """
|
    if not sentence_pairs:
        return []

    paragraphs = []
    current_paragraph = []

    # Seed the current paragraph with the very first sentence
    current_paragraph.append(sentence_pairs[0]['sentence1'])

    # Walk through all sentence pairs
    for pair in sentence_pairs:
        sentence2 = pair['sentence2']
        label = pair['label']

        if label == 0:
            # Same paragraph: keep adding to the current paragraph.
            # Only sentence2 is added, since sentence1 was already added in the previous round.
            if sentence2 not in current_paragraph:  # avoid duplicates
                current_paragraph.append(sentence2)

        elif label == 1:
            # Different paragraph: close the current one and start a new paragraph
            if current_paragraph:
                paragraph_text = ''.join(current_paragraph)
                if paragraph_text.strip():  # make sure the paragraph is not empty
                    paragraphs.append(paragraph_text.strip())

            # Start a new paragraph
            current_paragraph = [sentence2]

    # Flush the last paragraph
    if current_paragraph:
        paragraph_text = ''.join(current_paragraph)
        if paragraph_text.strip():
            paragraphs.append(paragraph_text.strip())

    return paragraphs


def process_single_source(source_id: str, sentence_pairs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Process the data for a single source_id."""
    print(f"\nProcessing Source ID: {source_id}")
    print(f"Number of sentence pairs: {len(sentence_pairs)}")

    # Count the label distribution
    label_counts = defaultdict(int)
    for pair in sentence_pairs:
        label_counts[pair['label']] += 1

    print(f"Label distribution: Label 0 (same paragraph): {label_counts[0]}, Label 1 (new paragraph): {label_counts[1]}")

    # Merge the paragraphs
    paragraphs = merge_paragraphs_by_labels(sentence_pairs)

    print(f"Number of merged paragraphs: {len(paragraphs)}")

    # Statistics
    total_chars = sum(len(p) for p in paragraphs)
    avg_paragraph_length = total_chars / len(paragraphs) if paragraphs else 0

    return {
        'source_id': source_id,
        'original_pairs_count': len(sentence_pairs),
        'merged_paragraphs_count': len(paragraphs),
        'label_distribution': dict(label_counts),
        'total_characters': total_chars,
        'avg_paragraph_length': avg_paragraph_length,
        'paragraphs': paragraphs
    }


def save_to_txt(results: Dict[str, Dict[str, Any]], output_file: str):
    """Save the merged paragraphs to a txt file."""
|
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("Paragraph merge results\n")
        f.write("Paragraph text merged by source_id and label\n")
        f.write("=" * 80 + "\n\n")

        for source_id, result in results.items():
            f.write(f"[Source ID: {source_id}]\n")
            f.write(f"Original sentence pairs: {result['original_pairs_count']}\n")
            f.write(f"Merged paragraphs: {result['merged_paragraphs_count']}\n")
            f.write(f"Label distribution: {result['label_distribution']}\n")
            f.write(f"Total characters: {result['total_characters']}\n")
            f.write(f"Average paragraph length: {result['avg_paragraph_length']:.1f} characters\n")
            f.write("-" * 60 + "\n")

            for i, paragraph in enumerate(result['paragraphs'], 1):
                f.write(f"Paragraph {i}:\n{paragraph}\n\n")

            f.write("=" * 80 + "\n\n")


def save_summary_json(results: Dict[str, Dict[str, Any]], output_file: str):
    """Save a statistical summary to a JSON file."""
    summary = {
        'total_source_ids': len(results),
        'total_original_pairs': sum(r['original_pairs_count'] for r in results.values()),
        'total_merged_paragraphs': sum(r['merged_paragraphs_count'] for r in results.values()),
        'total_characters': sum(r['total_characters'] for r in results.values()),
        'source_details': {}
    }

    for source_id, result in results.items():
        summary['source_details'][source_id] = {
            'original_pairs_count': result['original_pairs_count'],
            'merged_paragraphs_count': result['merged_paragraphs_count'],
            'label_distribution': result['label_distribution'],
            'total_characters': result['total_characters'],
            'avg_paragraph_length': result['avg_paragraph_length']
        }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


def main():
    """Main entry point."""
    # File paths
    input_file = 'segmentation_results_from_new50.json'  # input file
    output_txt = 'merged_paragraphs.txt'  # output txt file
    output_summary = 'merge_summary.json'  # output summary JSON

    print("=" * 80)
    print("Paragraph merging script")
    print("Merges paragraphs by source_id and label")
    print("=" * 80)

    # Check the input file
    if not os.path.exists(input_file):
        print(f"Error: input file {input_file} does not exist!")
        print(f"Please make sure {input_file} is in the current directory")
        return

    try:
        # 1. Load the data
        data = load_test_data(input_file)
        if not data:
            print("No valid data to process")
            return

        # 2. Group by source_id
        grouped_data = group_by_source_id(data)

        # 3. Process each source_id
        results = {}
        total_paragraphs = 0

        for source_id, sentence_pairs in grouped_data.items():
            result = process_single_source(source_id, sentence_pairs)
            results[source_id] = result
            total_paragraphs += result['merged_paragraphs_count']

        # 4. Save the results
        print("\nSaving results...")
        save_to_txt(results, output_txt)
        save_summary_json(results, output_summary)

        # 5. Print a summary
        print("=" * 80)
        print("Done!")
        print("=" * 80)
        print("📊 Processing statistics:")
        print(f"  🔹 Source IDs processed: {len(results)}")
        print(f"  🔹 Total original sentence pairs: {sum(r['original_pairs_count'] for r in results.values())}")
        print(f"  🔹 Total merged paragraphs: {total_paragraphs}")
        print(f"  🔹 Total characters: {sum(r['total_characters'] for r in results.values())}")

        print("\n📁 Output files:")
        print(f"  📄 {output_txt} - merged paragraph text")
        print(f"  📄 {output_summary} - processing summary")

        print("\n📋 Per-source details:")
        for source_id, result in results.items():
            print(f"  Source {source_id}: {result['original_pairs_count']} pairs → "
                  f"{result['merged_paragraphs_count']} paragraphs")

        print(f"\n✅ Paragraph merging finished! See {output_txt}")

    except Exception as e:
        print(f"❌ An error occurred during processing: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
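
# Example run (the script filename is an assumption; the output filenames come from main()):
#   python3 merge_paragraphs.py
# Writes merged_paragraphs.txt and merge_summary.json to the current directory.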