#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
根据source_id和label标签合并段落并输出txt文件
将label=0的连续句子合并,label=1作为分界点分段
"""
import json
import os
from collections import defaultdict
from typing import List, Dict, Any
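

# A minimal sketch of the input record shape this script assumes (field names
# are taken from the code below; the values are hypothetical):
#
#   {
#       "source_id": 42,
#       "sentence1": "First sentence.",
#       "sentence2": "Second sentence.",
#       "label": 0
#   }
#
# label=0 means the two sentences belong to the same paragraph; label=1 marks
# a paragraph boundary between them.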
def load_test_data(file_path: str) -> List[Dict[str, Any]]:
    """Load the test data from a JSON file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Loaded {len(data)} records")
        return data
    except FileNotFoundError:
        print(f"Error: file not found: {file_path}")
        return []
    except json.JSONDecodeError as e:
        print(f"Error: invalid JSON - {e}")
        return []
    except Exception as e:
        print(f"Error: failed to load file - {e}")
        return []


def group_by_source_id(data: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group the records by source_id."""
    grouped_data = defaultdict(list)
    for item in data:
        source_id = str(item.get('source_id', 'unknown'))
        grouped_data[source_id].append(item)
    # Records within each group keep their original file order; if the data
    # carried an explicit index field, it could be used to sort here.
    print(f"Grouping by source_id complete: {len(grouped_data)} groups")
    for source_id, items in grouped_data.items():
        print(f"  Source ID {source_id}: {len(items)} records")
    return dict(grouped_data)
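
# For example (hypothetical records), grouping
#   [{'source_id': 1, ...}, {'source_id': 2, ...}, {'source_id': 1, ...}]
# yields {'1': [first, third], '2': [second]}: keys are coerced to str and
# input order is preserved within each group.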


def merge_paragraphs_by_labels(sentence_pairs: List[Dict[str, Any]]) -> List[str]:
    """
    Merge sentences into paragraphs according to label.
    label=0: same paragraph, merge the sentences
    label=1: different paragraphs, treat as a boundary
    """
    if not sentence_pairs:
        return []
    paragraphs = []
    # Seed the first paragraph with the first sentence of the first pair.
    current_paragraph = [sentence_pairs[0]['sentence1']]
    # Walk through all sentence pairs.
    for pair in sentence_pairs:
        sentence2 = pair['sentence2']
        label = pair['label']
        if label == 0:
            # Same paragraph: append only sentence2, since sentence1 was
            # already added on the previous iteration.
            if sentence2 not in current_paragraph:  # avoid duplicates
                current_paragraph.append(sentence2)
        elif label == 1:
            # Paragraph boundary: close the current paragraph, start a new one.
            if current_paragraph:
                paragraph_text = ''.join(current_paragraph)
                if paragraph_text.strip():  # skip empty paragraphs
                    paragraphs.append(paragraph_text.strip())
            current_paragraph = [sentence2]
    # Flush the last paragraph.
    if current_paragraph:
        paragraph_text = ''.join(current_paragraph)
        if paragraph_text.strip():
            paragraphs.append(paragraph_text.strip())
    return paragraphs
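
# A worked sketch of the merge logic above (hypothetical sentences):
#
#   pairs = [
#       {'sentence1': 'A.', 'sentence2': 'B.', 'label': 0},
#       {'sentence1': 'B.', 'sentence2': 'C.', 'label': 1},
#       {'sentence1': 'C.', 'sentence2': 'D.', 'label': 0},
#   ]
#   merge_paragraphs_by_labels(pairs)  # -> ['A.B.', 'C.D.']
#
# The pairs are assumed to form a sliding window, so each sentence1 (after
# the first) repeats the previous sentence2 and only sentence2 is appended.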


def process_single_source(source_id: str, sentence_pairs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Process the data for a single source_id."""
    print(f"\nProcessing Source ID: {source_id}")
    print(f"Sentence pairs: {len(sentence_pairs)}")
    # Tally the label distribution.
    label_counts = defaultdict(int)
    for pair in sentence_pairs:
        label_counts[pair['label']] += 1
    print(f"Label distribution: Label 0 (same paragraph): {label_counts[0]}, "
          f"Label 1 (boundary): {label_counts[1]}")
    # Merge into paragraphs.
    paragraphs = merge_paragraphs_by_labels(sentence_pairs)
    print(f"Paragraphs after merging: {len(paragraphs)}")
    # Summary statistics.
    total_chars = sum(len(p) for p in paragraphs)
    avg_paragraph_length = total_chars / len(paragraphs) if paragraphs else 0
    return {
        'source_id': source_id,
        'original_pairs_count': len(sentence_pairs),
        'merged_paragraphs_count': len(paragraphs),
        'label_distribution': dict(label_counts),
        'total_characters': total_chars,
        'avg_paragraph_length': avg_paragraph_length,
        'paragraphs': paragraphs,
    }


def save_to_txt(results: Dict[str, Dict[str, Any]], output_file: str):
    """Write the merged paragraphs to a txt file."""
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("Paragraph merge results\n")
        f.write("Paragraphs merged by source_id and label\n")
        f.write("=" * 80 + "\n\n")
        for source_id, result in results.items():
            f.write(f"[Source ID: {source_id}]\n")
            f.write(f"Original sentence pairs: {result['original_pairs_count']}\n")
            f.write(f"Merged paragraphs: {result['merged_paragraphs_count']}\n")
            f.write(f"Label distribution: {result['label_distribution']}\n")
            f.write(f"Total characters: {result['total_characters']}\n")
            f.write(f"Average paragraph length: {result['avg_paragraph_length']:.1f} characters\n")
            f.write("-" * 60 + "\n")
            for i, paragraph in enumerate(result['paragraphs'], 1):
                f.write(f"Paragraph {i}:\n{paragraph}\n\n")
            f.write("=" * 80 + "\n\n")


def save_summary_json(results: Dict[str, Dict[str, Any]], output_file: str):
    """Write summary statistics to a JSON file."""
    summary = {
        'total_source_ids': len(results),
        'total_original_pairs': sum(r['original_pairs_count'] for r in results.values()),
        'total_merged_paragraphs': sum(r['merged_paragraphs_count'] for r in results.values()),
        'total_characters': sum(r['total_characters'] for r in results.values()),
        'source_details': {},
    }
    for source_id, result in results.items():
        summary['source_details'][source_id] = {
            'original_pairs_count': result['original_pairs_count'],
            'merged_paragraphs_count': result['merged_paragraphs_count'],
            'label_distribution': result['label_distribution'],
            'total_characters': result['total_characters'],
            'avg_paragraph_length': result['avg_paragraph_length'],
        }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
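
# The resulting summary JSON looks roughly like this (values hypothetical):
#
#   {
#     "total_source_ids": 2,
#     "total_original_pairs": 120,
#     "total_merged_paragraphs": 18,
#     "total_characters": 5432,
#     "source_details": {
#       "1": {"original_pairs_count": 60, "merged_paragraphs_count": 9, ...}
#     }
#   }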


def main():
    """Entry point."""
    # File paths.
    input_file = 'segmentation_results_from_new50.json'  # input JSON file
    output_txt = 'merged_paragraphs.txt'  # output txt file
    output_summary = 'merge_summary.json'  # output summary statistics
    print("=" * 80)
    print("Paragraph merging")
    print("Merges paragraphs by source_id and label")
    print("=" * 80)
    # Check that the input file exists.
    if not os.path.exists(input_file):
        print(f"Error: input file {input_file} does not exist!")
        print(f"Make sure {input_file} is in the current directory")
        return
    try:
        # 1. Load the data.
        data = load_test_data(input_file)
        if not data:
            print("No valid data to process")
            return
        # 2. Group by source_id.
        grouped_data = group_by_source_id(data)
        # 3. Process each source_id.
        results = {}
        total_paragraphs = 0
        for source_id, sentence_pairs in grouped_data.items():
            result = process_single_source(source_id, sentence_pairs)
            results[source_id] = result
            total_paragraphs += result['merged_paragraphs_count']
        # 4. Save the results.
        print("\nSaving results...")
        save_to_txt(results, output_txt)
        save_summary_json(results, output_summary)
        # 5. Print a summary.
        print("=" * 80)
        print("Done!")
        print("=" * 80)
        print("📊 Statistics:")
        print(f"  🔹 Source IDs processed: {len(results)}")
        print(f"  🔹 Total sentence pairs: {sum(r['original_pairs_count'] for r in results.values())}")
        print(f"  🔹 Total merged paragraphs: {total_paragraphs}")
        print(f"  🔹 Total characters: {sum(r['total_characters'] for r in results.values())}")
        print("\n📁 Output files:")
        print(f"  📄 {output_txt} - merged paragraph text")
        print(f"  📄 {output_summary} - summary statistics")
        print("\n📋 Per-source details:")
        for source_id, result in results.items():
            print(f"  Source {source_id}: {result['original_pairs_count']} pairs "
                  f"→ {result['merged_paragraphs_count']} paragraphs")
        print(f"\n✅ Paragraph merging complete! See {output_txt}")
    except Exception as e:
        print(f"❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()