#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Merge paragraphs by source_id and label, and write the result to a txt file.

Consecutive sentences with label=0 are merged into one paragraph;
label=1 marks a paragraph boundary.
"""

import json
import os
from collections import defaultdict
from typing import List, Dict, Any


def load_test_data(file_path: str) -> List[Dict[str, Any]]:
    """Load the input data from a JSON file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Successfully loaded {len(data)} records")
        return data
    except FileNotFoundError:
        print(f"Error: file not found: {file_path}")
        return []
    except json.JSONDecodeError as e:
        print(f"Error: invalid JSON format - {e}")
        return []
    except Exception as e:
        print(f"Error: failed to load file - {e}")
        return []


def group_by_source_id(data: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group the records by source_id."""
    grouped_data = defaultdict(list)
    for item in data:
        source_id = str(item.get('source_id', 'unknown'))
        grouped_data[source_id].append(item)

    # Records within each group keep their original order.
    # If the data carried an explicit index, the groups could be sorted by it;
    # here the input is assumed to already be in the correct order.

    print(f"Grouping by source_id done, {len(grouped_data)} groups in total")
    for source_id, items in grouped_data.items():
        print(f"  Source ID {source_id}: {len(items)} records")

    return dict(grouped_data)


def merge_paragraphs_by_labels(sentence_pairs: List[Dict[str, Any]]) -> List[str]:
    """
    Merge sentences into paragraphs according to the labels.

    label=0: same paragraph, the sentences are merged
    label=1: different paragraph, acts as a boundary
    """
    if not sentence_pairs:
        return []

    paragraphs = []
    # Start the first paragraph with the first sentence of the first pair.
    current_paragraph = [sentence_pairs[0]['sentence1']]

    # Walk through all sentence pairs.
    for pair in sentence_pairs:
        sentence2 = pair['sentence2']
        label = pair['label']

        if label == 0:
            # Same paragraph: keep extending the current paragraph.
            # Only sentence2 is added, because sentence1 was already added
            # in the previous iteration (or as the initial sentence).
            if sentence2 not in current_paragraph:  # avoid adding duplicates
                current_paragraph.append(sentence2)
        elif label == 1:
            # Different paragraph: close the current one and start a new one.
            if current_paragraph:
                paragraph_text = ''.join(current_paragraph)
                if paragraph_text.strip():  # make sure the paragraph is not empty
                    paragraphs.append(paragraph_text.strip())
            # Start a new paragraph with sentence2.
            current_paragraph = [sentence2]

    # Flush the last paragraph.
    if current_paragraph:
        paragraph_text = ''.join(current_paragraph)
        if paragraph_text.strip():
            paragraphs.append(paragraph_text.strip())

    return paragraphs


def process_single_source(source_id: str, sentence_pairs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Process the data of a single source_id."""
    print(f"\nProcessing Source ID: {source_id}")
    print(f"Number of sentence pairs: {len(sentence_pairs)}")

    # Label distribution statistics.
    label_counts = defaultdict(int)
    for pair in sentence_pairs:
        label_counts[pair['label']] += 1

    print(f"Label distribution: Label 0 (same paragraph): {label_counts[0]}, "
          f"Label 1 (paragraph break): {label_counts[1]}")

    # Merge into paragraphs.
    paragraphs = merge_paragraphs_by_labels(sentence_pairs)
    print(f"Number of merged paragraphs: {len(paragraphs)}")

    # Summary statistics.
    total_chars = sum(len(p) for p in paragraphs)
    avg_paragraph_length = total_chars / len(paragraphs) if paragraphs else 0

    return {
        'source_id': source_id,
        'original_pairs_count': len(sentence_pairs),
        'merged_paragraphs_count': len(paragraphs),
        'label_distribution': dict(label_counts),
        'total_characters': total_chars,
        'avg_paragraph_length': avg_paragraph_length,
        'paragraphs': paragraphs
    }


def save_to_txt(results: Dict[str, Dict[str, Any]], output_file: str):
    """Save the merged paragraphs to a txt file."""
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("Paragraph merging results\n")
        f.write("Paragraphs merged by source_id and label\n")
        f.write("=" * 80 + "\n\n")

        for source_id, result in results.items():
            f.write(f"[Source ID: {source_id}]\n")
            f.write(f"Original sentence pairs: {result['original_pairs_count']}\n")
            f.write(f"Merged paragraphs: {result['merged_paragraphs_count']}\n")
            f.write(f"Label distribution: {result['label_distribution']}\n")
            f.write(f"Total characters: {result['total_characters']}\n")
            f.write(f"Average paragraph length: {result['avg_paragraph_length']:.1f} characters\n")
            f.write("-" * 60 + "\n")

            for i, paragraph in enumerate(result['paragraphs'], 1):
                f.write(f"Paragraph {i}:\n{paragraph}\n\n")

            f.write("=" * 80 + "\n\n")


def save_summary_json(results: Dict[str, Dict[str, Any]], output_file: str):
    """Save a statistics summary to a JSON file."""
    summary = {
        'total_source_ids': len(results),
        'total_original_pairs': sum(r['original_pairs_count'] for r in results.values()),
        'total_merged_paragraphs': sum(r['merged_paragraphs_count'] for r in results.values()),
        'total_characters': sum(r['total_characters'] for r in results.values()),
        'source_details': {}
    }

    for source_id, result in results.items():
        summary['source_details'][source_id] = {
            'original_pairs_count': result['original_pairs_count'],
            'merged_paragraphs_count': result['merged_paragraphs_count'],
            'label_distribution': result['label_distribution'],
            'total_characters': result['total_characters'],
            'avg_paragraph_length': result['avg_paragraph_length']
        }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


def main():
    """Main entry point."""
    # File paths.
    input_file = 'segmentation_results_from_new50.json'  # input file
    output_txt = 'merged_paragraphs.txt'                 # output txt file
    output_summary = 'merge_summary.json'                # output statistics summary

    print("=" * 80)
    print("Paragraph merging tool")
    print("Merge paragraphs by source_id and label")
    print("=" * 80)

    # Check the input file.
    if not os.path.exists(input_file):
        print(f"Error: input file {input_file} does not exist!")
        print(f"Please make sure {input_file} is in the current directory")
        return

    try:
        # 1. Load the data.
        data = load_test_data(input_file)
        if not data:
            print("No valid data to process")
            return

        # 2. Group by source_id.
        grouped_data = group_by_source_id(data)

        # 3. Process the data of each source_id.
        results = {}
        total_paragraphs = 0
        for source_id, sentence_pairs in grouped_data.items():
            result = process_single_source(source_id, sentence_pairs)
            results[source_id] = result
            total_paragraphs += result['merged_paragraphs_count']

        # 4. Save the results.
        print("\nSaving results...")
        save_to_txt(results, output_txt)
        save_summary_json(results, output_summary)

        # 5. Print a summary.
        print("=" * 80)
        print("Processing complete!")
        print("=" * 80)
        print("📊 Statistics:")
        print(f"  🔹 Source IDs processed: {len(results)}")
        print(f"  🔹 Total original sentence pairs: {sum(r['original_pairs_count'] for r in results.values())}")
        print(f"  🔹 Total merged paragraphs: {total_paragraphs}")
        print(f"  🔹 Total characters: {sum(r['total_characters'] for r in results.values())}")

        print("\n📁 Output files:")
        print(f"  📄 {output_txt} - merged paragraph text")
        print(f"  📄 {output_summary} - processing statistics summary")

        print("\n📋 Details per Source ID:")
        for source_id, result in results.items():
            print(f"  Source {source_id}: {result['original_pairs_count']} pairs → "
                  f"{result['merged_paragraphs_count']} paragraphs")

        print(f"\n✅ Paragraph merging finished! See {output_txt}")

    except Exception as e:
        print(f"❌ An error occurred during processing: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
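
# ---------------------------------------------------------------------------
# Illustrative input sketch (an assumption inferred from the field names used
# above -- 'source_id', 'sentence1', 'sentence2', 'label' -- not a sample
# taken from the real segmentation_results_from_new50.json):
#
#   [
#     {"source_id": "1", "sentence1": "First sentence.",  "sentence2": "Second sentence.", "label": 0},
#     {"source_id": "1", "sentence1": "Second sentence.", "sentence2": "Third sentence.",  "label": 1},
#     {"source_id": "1", "sentence1": "Third sentence.",  "sentence2": "Fourth sentence.", "label": 0}
#   ]
#
# With label=0 read as "same paragraph" and label=1 as "paragraph break",
# merge_paragraphs_by_labels() would yield two paragraphs for source_id "1":
#   "First sentence.Second sentence." and "Third sentence.Fourth sentence."
# ---------------------------------------------------------------------------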